1/*
2 * SSL3 Protocol
3 *
4 * This Source Code Form is subject to the terms of the Mozilla Public
5 * License, v. 2.0. If a copy of the MPL was not distributed with this
6 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
7
8/* ECC code moved here from ssl3con.c */
9
10#include "nss.h"
11#include "cert.h"
12#include "ssl.h"
13#include "cryptohi.h"	/* for DSAU_ stuff */
14#include "keyhi.h"
15#include "secder.h"
16#include "secitem.h"
17
18#include "sslimpl.h"
19#include "sslproto.h"
20#include "sslerr.h"
21#include "prtime.h"
22#include "prinrval.h"
23#include "prerror.h"
24#include "pratom.h"
25#include "prthread.h"
26#include "prinit.h"
27
28#include "pk11func.h"
29#include "secmod.h"
30
31#include <stdio.h>
32
/* Compatibility shim so this code can be compiled against older NSS
 * headers that predate the TLS 1.2 changes: if the headers do not define
 * the TLS 1.2 DH master-key-derive mechanism used by the ECDH key
 * exchange code below, define it here as the vendor value CKM_NSS + 24. */
#ifndef CKM_NSS_TLS_MASTER_KEY_DERIVE_DH_SHA256
#define CKM_NSS_TLS_MASTER_KEY_DERIVE_DH_SHA256 (CKM_NSS + 24)
#endif
38
39#ifdef NSS_ENABLE_ECC
40
/* Fallback for headers that do not provide PK11_SETATTRS: fill in one
 * attribute entry through pointer x, setting its type, pValue and
 * ulValueLen fields.
 * NOTE(review): the expansion is three separate statements (ending in ';',
 * not wrapped in do { } while (0)), so it must not be used as the body of
 * an unbraced if/else. */
#ifndef PK11_SETATTRS
#define PK11_SETATTRS(x,id,v,l) (x)->type = (id); \
		(x)->pValue=(v); (x)->ulValueLen = (l);
#endif
45
/* Public key of the server key pair configured on socket |sock| for
 * key-exchange type |type|, or NULL when no key pair is configured.
 * The expansion uses the |sock| argument (not a caller-scope variable)
 * and parenthesizes both arguments.  |sock| and |type| may be evaluated
 * twice, so they must not have side effects. */
#define SSL_GET_SERVER_PUBLIC_KEY(sock, type) \
    ((sock)->serverCerts[(type)].serverKeyPair ? \
     (sock)->serverCerts[(type)].serverKeyPair->pubKey : NULL)
49
/* Nonzero when |curveName| is a real named curve (strictly between the
 * ec_noName and ec_pastLastName sentinels) whose bit is set in the
 * negotiated-curve mask |curvemsk|.  Both arguments are parenthesized so
 * that expressions such as "a | b" are treated as a single operand.
 * |curveName| is evaluated up to three times; it must not have side
 * effects. */
#define SSL_IS_CURVE_NEGOTIATED(curvemsk, curveName) \
    (((curveName) > ec_noName) && \
     ((curveName) < ec_pastLastName) && \
     ((1UL << (curveName)) & (curvemsk)) != 0)
54
55
56
/* Forward declaration; defined later in this file.  Attaches the cached
 * ephemeral ECDH key pair for curve ec_curve to socket ss, creating the
 * cached pair on first use. */
static SECStatus ssl3_CreateECDHEphemeralKeys(sslSocket *ss, ECName ec_curve);
58
/* Nonzero iff x names a real curve, i.e. lies strictly between the
 * ec_noName and ec_pastLastName sentinels.  x is evaluated twice. */
#define supportedCurve(x) (ec_noName < (x) && (x) < ec_pastLastName)
60
61/* Table containing OID tags for elliptic curves named in the
62 * ECC-TLS IETF draft.
63 */
64static const SECOidTag ecName2OIDTag[] = {
65	0,
66	SEC_OID_SECG_EC_SECT163K1,  /*  1 */
67	SEC_OID_SECG_EC_SECT163R1,  /*  2 */
68	SEC_OID_SECG_EC_SECT163R2,  /*  3 */
69	SEC_OID_SECG_EC_SECT193R1,  /*  4 */
70	SEC_OID_SECG_EC_SECT193R2,  /*  5 */
71	SEC_OID_SECG_EC_SECT233K1,  /*  6 */
72	SEC_OID_SECG_EC_SECT233R1,  /*  7 */
73	SEC_OID_SECG_EC_SECT239K1,  /*  8 */
74	SEC_OID_SECG_EC_SECT283K1,  /*  9 */
75	SEC_OID_SECG_EC_SECT283R1,  /* 10 */
76	SEC_OID_SECG_EC_SECT409K1,  /* 11 */
77	SEC_OID_SECG_EC_SECT409R1,  /* 12 */
78	SEC_OID_SECG_EC_SECT571K1,  /* 13 */
79	SEC_OID_SECG_EC_SECT571R1,  /* 14 */
80	SEC_OID_SECG_EC_SECP160K1,  /* 15 */
81	SEC_OID_SECG_EC_SECP160R1,  /* 16 */
82	SEC_OID_SECG_EC_SECP160R2,  /* 17 */
83	SEC_OID_SECG_EC_SECP192K1,  /* 18 */
84	SEC_OID_SECG_EC_SECP192R1,  /* 19 */
85	SEC_OID_SECG_EC_SECP224K1,  /* 20 */
86	SEC_OID_SECG_EC_SECP224R1,  /* 21 */
87	SEC_OID_SECG_EC_SECP256K1,  /* 22 */
88	SEC_OID_SECG_EC_SECP256R1,  /* 23 */
89	SEC_OID_SECG_EC_SECP384R1,  /* 24 */
90	SEC_OID_SECG_EC_SECP521R1,  /* 25 */
91};
92
93static const PRUint16 curve2bits[] = {
94	  0, /*  ec_noName     = 0,   */
95	163, /*  ec_sect163k1  = 1,   */
96	163, /*  ec_sect163r1  = 2,   */
97	163, /*  ec_sect163r2  = 3,   */
98	193, /*  ec_sect193r1  = 4,   */
99	193, /*  ec_sect193r2  = 5,   */
100	233, /*  ec_sect233k1  = 6,   */
101	233, /*  ec_sect233r1  = 7,   */
102	239, /*  ec_sect239k1  = 8,   */
103	283, /*  ec_sect283k1  = 9,   */
104	283, /*  ec_sect283r1  = 10,  */
105	409, /*  ec_sect409k1  = 11,  */
106	409, /*  ec_sect409r1  = 12,  */
107	571, /*  ec_sect571k1  = 13,  */
108	571, /*  ec_sect571r1  = 14,  */
109	160, /*  ec_secp160k1  = 15,  */
110	160, /*  ec_secp160r1  = 16,  */
111	160, /*  ec_secp160r2  = 17,  */
112	192, /*  ec_secp192k1  = 18,  */
113	192, /*  ec_secp192r1  = 19,  */
114	224, /*  ec_secp224k1  = 20,  */
115	224, /*  ec_secp224r1  = 21,  */
116	256, /*  ec_secp256k1  = 22,  */
117	256, /*  ec_secp256r1  = 23,  */
118	384, /*  ec_secp384r1  = 24,  */
119	521, /*  ec_secp521r1  = 25,  */
120      65535  /*  ec_pastLastName      */
121};
122
/* One row of the curve-preference table: a named curve and its key size
 * in bits. */
typedef struct Bits2CurveStr {
    PRUint16    bits;   /* key size of |curve| in bits */
    ECName      curve;  /* named curve; ec_noName terminates the table */
} Bits2Curve;
127
/* Curve-preference table, scanned in order by
 * ssl3_GetCurveWithECKeyStrength, which returns the first entry whose
 * curve was negotiated and whose bits is >= the required strength.  The
 * order therefore encodes preference: e.g. secp192r1 is listed first,
 * ahead of the 160-bit curves, so it is chosen whenever it was negotiated
 * and the requirement is at most 192 bits.  The {65535, ec_noName} row
 * terminates the scan. */
static const Bits2Curve bits2curve [] = {
   {	192,     ec_secp192r1    /*  = 19,  fast */  },
   {	160,     ec_secp160r2    /*  = 17,  fast */  },
   {	160,     ec_secp160k1    /*  = 15,  */       },
   {	160,     ec_secp160r1    /*  = 16,  */       },
   {	163,     ec_sect163k1    /*  = 1,   */       },
   {	163,     ec_sect163r1    /*  = 2,   */       },
   {	163,     ec_sect163r2    /*  = 3,   */       },
   {	192,     ec_secp192k1    /*  = 18,  */       },
   {	193,     ec_sect193r1    /*  = 4,   */       },
   {	193,     ec_sect193r2    /*  = 5,   */       },
   {	224,     ec_secp224r1    /*  = 21,  fast */  },
   {	224,     ec_secp224k1    /*  = 20,  */       },
   {	233,     ec_sect233k1    /*  = 6,   */       },
   {	233,     ec_sect233r1    /*  = 7,   */       },
   {	239,     ec_sect239k1    /*  = 8,   */       },
   {	256,     ec_secp256r1    /*  = 23,  fast */  },
   {	256,     ec_secp256k1    /*  = 22,  */       },
   {	283,     ec_sect283k1    /*  = 9,   */       },
   {	283,     ec_sect283r1    /*  = 10,  */       },
   {	384,     ec_secp384r1    /*  = 24,  fast */  },
   {	409,     ec_sect409k1    /*  = 11,  */       },
   {	409,     ec_sect409r1    /*  = 12,  */       },
   {	521,     ec_secp521r1    /*  = 25,  fast */  },
   {	571,     ec_sect571k1    /*  = 13,  */       },
   {	571,     ec_sect571r1    /*  = 14,  */       },
   {  65535,     ec_noName    }   /* end-of-table marker */
};
156
/* Cache slot for the server's ephemeral ECDHE key pair on one curve. */
typedef struct ECDHEKeyPairStr {
    ssl3KeyPair *  pair;   /* cached key pair, NULL until created */
    int            error;  /* error code of the call-once function */
    PRCallOnceType once;   /* guards the one-time creation for this slot */
} ECDHEKeyPair;

/* arrays of ECDHE KeyPairs, indexed by ECName.  Slot ec_noName holds no
 * key pair; its |once| and |error| are used for the one-time registration
 * of the NSS shutdown hook (ssl3_ECRegister).  Everything, including the
 * |once| guards, is zeroed again by ssl3_ShutdownECDHECurves. */
static ECDHEKeyPair gECDHEKeyPairs[ec_pastLastName];
165
166SECStatus
167ssl3_ECName2Params(PLArenaPool * arena, ECName curve, SECKEYECParams * params)
168{
169    SECOidData *oidData = NULL;
170
171    if ((curve <= ec_noName) || (curve >= ec_pastLastName) ||
172	((oidData = SECOID_FindOIDByTag(ecName2OIDTag[curve])) == NULL)) {
173        PORT_SetError(SEC_ERROR_UNSUPPORTED_ELLIPTIC_CURVE);
174	return SECFailure;
175    }
176
177    SECITEM_AllocItem(arena, params, (2 + oidData->oid.len));
178    /*
179     * params->data needs to contain the ASN encoding of an object ID (OID)
180     * representing the named curve. The actual OID is in
181     * oidData->oid.data so we simply prepend 0x06 and OID length
182     */
183    params->data[0] = SEC_ASN1_OBJECT_ID;
184    params->data[1] = oidData->oid.len;
185    memcpy(params->data + 2, oidData->oid.data, oidData->oid.len);
186
187    return SECSuccess;
188}
189
190static ECName
191params2ecName(SECKEYECParams * params)
192{
193    SECItem oid = { siBuffer, NULL, 0};
194    SECOidData *oidData = NULL;
195    ECName i;
196
197    /*
198     * params->data needs to contain the ASN encoding of an object ID (OID)
199     * representing a named curve. Here, we strip away everything
200     * before the actual OID and use the OID to look up a named curve.
201     */
202    if (params->data[0] != SEC_ASN1_OBJECT_ID) return ec_noName;
203    oid.len = params->len - 2;
204    oid.data = params->data + 2;
205    if ((oidData = SECOID_FindOID(&oid)) == NULL) return ec_noName;
206    for (i = ec_noName + 1; i < ec_pastLastName; i++) {
207	if (ecName2OIDTag[i] == oidData->offset)
208	    return i;
209    }
210
211    return ec_noName;
212}
213
214/* Caller must set hiLevel error code. */
215static SECStatus
216ssl3_ComputeECDHKeyHash(SECOidTag hashAlg,
217			SECItem ec_params, SECItem server_ecpoint,
218			SSL3Random *client_rand, SSL3Random *server_rand,
219			SSL3Hashes *hashes, PRBool bypassPKCS11)
220{
221    PRUint8     * hashBuf;
222    PRUint8     * pBuf;
223    SECStatus     rv 		= SECSuccess;
224    unsigned int  bufLen;
225    /*
226     * XXX For now, we only support named curves (the appropriate
227     * checks are made before this method is called) so ec_params
228     * takes up only two bytes. ECPoint needs to fit in 256 bytes
229     * (because the spec says the length must fit in one byte)
230     */
231    PRUint8       buf[2*SSL3_RANDOM_LENGTH + 2 + 1 + 256];
232
233    bufLen = 2*SSL3_RANDOM_LENGTH + ec_params.len + 1 + server_ecpoint.len;
234    if (bufLen <= sizeof buf) {
235    	hashBuf = buf;
236    } else {
237    	hashBuf = PORT_Alloc(bufLen);
238	if (!hashBuf) {
239	    return SECFailure;
240	}
241    }
242
243    memcpy(hashBuf, client_rand, SSL3_RANDOM_LENGTH);
244    	pBuf = hashBuf + SSL3_RANDOM_LENGTH;
245    memcpy(pBuf, server_rand, SSL3_RANDOM_LENGTH);
246    	pBuf += SSL3_RANDOM_LENGTH;
247    memcpy(pBuf, ec_params.data, ec_params.len);
248    	pBuf += ec_params.len;
249    pBuf[0] = (PRUint8)(server_ecpoint.len);
250    pBuf += 1;
251    memcpy(pBuf, server_ecpoint.data, server_ecpoint.len);
252    	pBuf += server_ecpoint.len;
253    PORT_Assert((unsigned int)(pBuf - hashBuf) == bufLen);
254
255    rv = ssl3_ComputeCommonKeyHash(hashAlg, hashBuf, bufLen, hashes,
256				   bypassPKCS11);
257
258    PRINT_BUF(95, (NULL, "ECDHkey hash: ", hashBuf, bufLen));
259    PRINT_BUF(95, (NULL, "ECDHkey hash: MD5 result",
260	      hashes->u.s.md5, MD5_LENGTH));
261    PRINT_BUF(95, (NULL, "ECDHkey hash: SHA1 result",
262	      hashes->u.s.sha, SHA1_LENGTH));
263
264    if (hashBuf != buf)
265    	PORT_Free(hashBuf);
266    return rv;
267}
268
269
270/* Called from ssl3_SendClientKeyExchange(). */
271SECStatus
272ssl3_SendECDHClientKeyExchange(sslSocket * ss, SECKEYPublicKey * svrPubKey)
273{
274    PK11SymKey *	pms 		= NULL;
275    SECStatus           rv    		= SECFailure;
276    PRBool              isTLS, isTLS12;
277    CK_MECHANISM_TYPE	target;
278    SECKEYPublicKey	*pubKey = NULL;		/* Ephemeral ECDH key */
279    SECKEYPrivateKey	*privKey = NULL;	/* Ephemeral ECDH key */
280
281    PORT_Assert( ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss) );
282    PORT_Assert( ss->opt.noLocks || ssl_HaveXmitBufLock(ss));
283
284    isTLS = (PRBool)(ss->ssl3.pwSpec->version > SSL_LIBRARY_VERSION_3_0);
285    isTLS12 = (PRBool)(ss->ssl3.pwSpec->version >= SSL_LIBRARY_VERSION_TLS_1_2);
286
287    /* Generate ephemeral EC keypair */
288    if (svrPubKey->keyType != ecKey) {
289	PORT_SetError(SEC_ERROR_BAD_KEY);
290	goto loser;
291    }
292    /* XXX SHOULD CALL ssl3_CreateECDHEphemeralKeys here, instead! */
293    privKey = SECKEY_CreateECPrivateKey(&svrPubKey->u.ec.DEREncodedParams,
294	                                &pubKey, ss->pkcs11PinArg);
295    if (!privKey || !pubKey) {
296	    ssl_MapLowLevelError(SEC_ERROR_KEYGEN_FAIL);
297	    rv = SECFailure;
298	    goto loser;
299    }
300    PRINT_BUF(50, (ss, "ECDH public value:",
301					pubKey->u.ec.publicValue.data,
302					pubKey->u.ec.publicValue.len));
303
304    if (isTLS12) {
305	target = CKM_NSS_TLS_MASTER_KEY_DERIVE_DH_SHA256;
306    } else if (isTLS) {
307	target = CKM_TLS_MASTER_KEY_DERIVE_DH;
308    } else {
309	target = CKM_SSL3_MASTER_KEY_DERIVE_DH;
310    }
311
312    /*  Determine the PMS */
313    pms = PK11_PubDeriveWithKDF(privKey, svrPubKey, PR_FALSE, NULL, NULL,
314			    CKM_ECDH1_DERIVE, target, CKA_DERIVE, 0,
315			    CKD_NULL, NULL, NULL);
316
317    if (pms == NULL) {
318	SSL3AlertDescription desc  = illegal_parameter;
319	(void)SSL3_SendAlert(ss, alert_fatal, desc);
320	ssl_MapLowLevelError(SSL_ERROR_CLIENT_KEY_EXCHANGE_FAILURE);
321	goto loser;
322    }
323
324    SECKEY_DestroyPrivateKey(privKey);
325    privKey = NULL;
326
327    rv = ssl3_InitPendingCipherSpec(ss,  pms);
328    PK11_FreeSymKey(pms); pms = NULL;
329
330    if (rv != SECSuccess) {
331	ssl_MapLowLevelError(SSL_ERROR_CLIENT_KEY_EXCHANGE_FAILURE);
332	goto loser;
333    }
334
335    rv = ssl3_AppendHandshakeHeader(ss, client_key_exchange,
336					pubKey->u.ec.publicValue.len + 1);
337    if (rv != SECSuccess) {
338        goto loser;	/* err set by ssl3_AppendHandshake* */
339    }
340
341    rv = ssl3_AppendHandshakeVariable(ss,
342					pubKey->u.ec.publicValue.data,
343					pubKey->u.ec.publicValue.len, 1);
344    SECKEY_DestroyPublicKey(pubKey);
345    pubKey = NULL;
346
347    if (rv != SECSuccess) {
348        goto loser;	/* err set by ssl3_AppendHandshake* */
349    }
350
351    rv = SECSuccess;
352
353loser:
354    if(pms) PK11_FreeSymKey(pms);
355    if(privKey) SECKEY_DestroyPrivateKey(privKey);
356    if(pubKey) SECKEY_DestroyPublicKey(pubKey);
357    return rv;
358}
359
360
361/*
362** Called from ssl3_HandleClientKeyExchange()
363*/
364SECStatus
365ssl3_HandleECDHClientKeyExchange(sslSocket *ss, SSL3Opaque *b,
366				     PRUint32 length,
367                                     SECKEYPublicKey *srvrPubKey,
368                                     SECKEYPrivateKey *srvrPrivKey)
369{
370    PK11SymKey *      pms;
371    SECStatus         rv;
372    SECKEYPublicKey   clntPubKey;
373    CK_MECHANISM_TYPE	target;
374    PRBool isTLS, isTLS12;
375
376    PORT_Assert( ss->opt.noLocks || ssl_HaveRecvBufLock(ss) );
377    PORT_Assert( ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss) );
378
379    clntPubKey.keyType = ecKey;
380    clntPubKey.u.ec.DEREncodedParams.len =
381	srvrPubKey->u.ec.DEREncodedParams.len;
382    clntPubKey.u.ec.DEREncodedParams.data =
383	srvrPubKey->u.ec.DEREncodedParams.data;
384
385    rv = ssl3_ConsumeHandshakeVariable(ss, &clntPubKey.u.ec.publicValue,
386	                               1, &b, &length);
387    if (rv != SECSuccess) {
388	SEND_ALERT
389	return SECFailure;	/* XXX Who sets the error code?? */
390    }
391
392    isTLS = (PRBool)(ss->ssl3.prSpec->version > SSL_LIBRARY_VERSION_3_0);
393    isTLS12 = (PRBool)(ss->ssl3.prSpec->version >= SSL_LIBRARY_VERSION_TLS_1_2);
394
395    if (isTLS12) {
396	target = CKM_NSS_TLS_MASTER_KEY_DERIVE_DH_SHA256;
397    } else if (isTLS) {
398	target = CKM_TLS_MASTER_KEY_DERIVE_DH;
399    } else {
400	target = CKM_SSL3_MASTER_KEY_DERIVE_DH;
401    }
402
403    /*  Determine the PMS */
404    pms = PK11_PubDeriveWithKDF(srvrPrivKey, &clntPubKey, PR_FALSE, NULL, NULL,
405			    CKM_ECDH1_DERIVE, target, CKA_DERIVE, 0,
406			    CKD_NULL, NULL, NULL);
407
408    if (pms == NULL) {
409	/* last gasp.  */
410	ssl_MapLowLevelError(SSL_ERROR_CLIENT_KEY_EXCHANGE_FAILURE);
411	return SECFailure;
412    }
413
414    rv = ssl3_InitPendingCipherSpec(ss,  pms);
415    PK11_FreeSymKey(pms);
416    if (rv != SECSuccess) {
417	SEND_ALERT
418	return SECFailure; /* error code set by ssl3_InitPendingCipherSpec */
419    }
420    return SECSuccess;
421}
422
423ECName
424ssl3_GetCurveWithECKeyStrength(PRUint32 curvemsk, int requiredECCbits)
425{
426    int    i;
427
428    for ( i = 0; bits2curve[i].curve != ec_noName; i++) {
429	if (bits2curve[i].bits < requiredECCbits)
430	    continue;
431    	if (SSL_IS_CURVE_NEGOTIATED(curvemsk, bits2curve[i].curve)) {
432	    return bits2curve[i].curve;
433	}
434    }
435    PORT_SetError(SSL_ERROR_NO_CYPHER_OVERLAP);
436    return ec_noName;
437}
438
439/* find the "weakest link".  Get strength of signature key and of sym key.
440 * choose curve for the weakest of those two.
441 */
442ECName
443ssl3_GetCurveNameForServerSocket(sslSocket *ss)
444{
445    SECKEYPublicKey * svrPublicKey = NULL;
446    ECName ec_curve = ec_noName;
447    int    signatureKeyStrength = 521;
448    int    requiredECCbits = ss->sec.secretKeyBits * 2;
449
450    if (ss->ssl3.hs.kea_def->kea == kea_ecdhe_ecdsa) {
451	svrPublicKey = SSL_GET_SERVER_PUBLIC_KEY(ss, kt_ecdh);
452	if (svrPublicKey)
453	    ec_curve = params2ecName(&svrPublicKey->u.ec.DEREncodedParams);
454	if (!SSL_IS_CURVE_NEGOTIATED(ss->ssl3.hs.negotiatedECCurves, ec_curve)) {
455	    PORT_SetError(SSL_ERROR_NO_CYPHER_OVERLAP);
456	    return ec_noName;
457	}
458	signatureKeyStrength = curve2bits[ ec_curve ];
459    } else {
460        /* RSA is our signing cert */
461        int serverKeyStrengthInBits;
462
463        svrPublicKey = SSL_GET_SERVER_PUBLIC_KEY(ss, kt_rsa);
464        if (!svrPublicKey) {
465            PORT_SetError(SSL_ERROR_NO_CYPHER_OVERLAP);
466            return ec_noName;
467        }
468
469        /* currently strength in bytes */
470        serverKeyStrengthInBits = svrPublicKey->u.rsa.modulus.len;
471        if (svrPublicKey->u.rsa.modulus.data[0] == 0) {
472            serverKeyStrengthInBits--;
473        }
474        /* convert to strength in bits */
475        serverKeyStrengthInBits *= BPB;
476
477        signatureKeyStrength =
478	    SSL_RSASTRENGTH_TO_ECSTRENGTH(serverKeyStrengthInBits);
479    }
480    if ( requiredECCbits > signatureKeyStrength )
481         requiredECCbits = signatureKeyStrength;
482
483    return ssl3_GetCurveWithECKeyStrength(ss->ssl3.hs.negotiatedECCurves,
484					  requiredECCbits);
485}
486
487/* function to clear out the lists */
488static SECStatus
489ssl3_ShutdownECDHECurves(void *appData, void *nssData)
490{
491    int i;
492    ECDHEKeyPair *keyPair = &gECDHEKeyPairs[0];
493
494    for (i=0; i < ec_pastLastName; i++, keyPair++) {
495	if (keyPair->pair) {
496	    ssl3_FreeKeyPair(keyPair->pair);
497	}
498    }
499    memset(gECDHEKeyPairs, 0, sizeof gECDHEKeyPairs);
500    return SECSuccess;
501}
502
503static PRStatus
504ssl3_ECRegister(void)
505{
506    SECStatus rv;
507    rv = NSS_RegisterShutdown(ssl3_ShutdownECDHECurves, gECDHEKeyPairs);
508    if (rv != SECSuccess) {
509	gECDHEKeyPairs[ec_noName].error = PORT_GetError();
510    }
511    return (PRStatus)rv;
512}
513
514/* CallOnce function, called once for each named curve. */
515static PRStatus
516ssl3_CreateECDHEphemeralKeyPair(void * arg)
517{
518    SECKEYPrivateKey *    privKey  = NULL;
519    SECKEYPublicKey *     pubKey   = NULL;
520    ssl3KeyPair *	  keyPair  = NULL;
521    ECName                ec_curve = (ECName)arg;
522    SECKEYECParams        ecParams = { siBuffer, NULL, 0 };
523
524    PORT_Assert(gECDHEKeyPairs[ec_curve].pair == NULL);
525
526    /* ok, no one has generated a global key for this curve yet, do so */
527    if (ssl3_ECName2Params(NULL, ec_curve, &ecParams) != SECSuccess) {
528	gECDHEKeyPairs[ec_curve].error = PORT_GetError();
529	return PR_FAILURE;
530    }
531
532    privKey = SECKEY_CreateECPrivateKey(&ecParams, &pubKey, NULL);
533    SECITEM_FreeItem(&ecParams, PR_FALSE);
534
535    if (!privKey || !pubKey || !(keyPair = ssl3_NewKeyPair(privKey, pubKey))) {
536	if (privKey) {
537	    SECKEY_DestroyPrivateKey(privKey);
538	}
539	if (pubKey) {
540	    SECKEY_DestroyPublicKey(pubKey);
541	}
542	ssl_MapLowLevelError(SEC_ERROR_KEYGEN_FAIL);
543	gECDHEKeyPairs[ec_curve].error = PORT_GetError();
544	return PR_FAILURE;
545    }
546
547    gECDHEKeyPairs[ec_curve].pair = keyPair;
548    return PR_SUCCESS;
549}
550
551/*
552 * Creates the ephemeral public and private ECDH keys used by
553 * server in ECDHE_RSA and ECDHE_ECDSA handshakes.
554 * For now, the elliptic curve is chosen to be the same
555 * strength as the signing certificate (ECC or RSA).
556 * We need an API to specify the curve. This won't be a real
557 * issue until we further develop server-side support for ECC
558 * cipher suites.
559 */
560static SECStatus
561ssl3_CreateECDHEphemeralKeys(sslSocket *ss, ECName ec_curve)
562{
563    ssl3KeyPair *	  keyPair        = NULL;
564
565    /* if there's no global key for this curve, make one. */
566    if (gECDHEKeyPairs[ec_curve].pair == NULL) {
567	PRStatus status;
568
569	status = PR_CallOnce(&gECDHEKeyPairs[ec_noName].once, ssl3_ECRegister);
570        if (status != PR_SUCCESS) {
571	    PORT_SetError(gECDHEKeyPairs[ec_noName].error);
572	    return SECFailure;
573    	}
574	status = PR_CallOnceWithArg(&gECDHEKeyPairs[ec_curve].once,
575	                            ssl3_CreateECDHEphemeralKeyPair,
576				    (void *)ec_curve);
577        if (status != PR_SUCCESS) {
578	    PORT_SetError(gECDHEKeyPairs[ec_curve].error);
579	    return SECFailure;
580    	}
581    }
582
583    keyPair = gECDHEKeyPairs[ec_curve].pair;
584    PORT_Assert(keyPair != NULL);
585    if (!keyPair)
586    	return SECFailure;
587    ss->ephemeralECDHKeyPair = ssl3_GetKeyPairRef(keyPair);
588
589    return SECSuccess;
590}
591
592SECStatus
593ssl3_HandleECDHServerKeyExchange(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
594{
595    PLArenaPool *    arena     = NULL;
596    SECKEYPublicKey *peerKey   = NULL;
597    PRBool           isTLS, isTLS12;
598    SECStatus        rv;
599    int              errCode   = SSL_ERROR_RX_MALFORMED_SERVER_KEY_EXCH;
600    SSL3AlertDescription desc  = illegal_parameter;
601    SSL3Hashes       hashes;
602    SECItem          signature = {siBuffer, NULL, 0};
603
604    SECItem          ec_params = {siBuffer, NULL, 0};
605    SECItem          ec_point  = {siBuffer, NULL, 0};
606    unsigned char    paramBuf[3]; /* only for curve_type == named_curve */
607    SSL3SignatureAndHashAlgorithm sigAndHash;
608
609    sigAndHash.hashAlg = SEC_OID_UNKNOWN;
610
611    isTLS = (PRBool)(ss->ssl3.prSpec->version > SSL_LIBRARY_VERSION_3_0);
612    isTLS12 = (PRBool)(ss->ssl3.prSpec->version >= SSL_LIBRARY_VERSION_TLS_1_2);
613
614    /* XXX This works only for named curves, revisit this when
615     * we support generic curves.
616     */
617    ec_params.len  = sizeof paramBuf;
618    ec_params.data = paramBuf;
619    rv = ssl3_ConsumeHandshake(ss, ec_params.data, ec_params.len, &b, &length);
620    if (rv != SECSuccess) {
621	goto loser;		/* malformed. */
622    }
623
624    /* Fail if the curve is not a named curve */
625    if ((ec_params.data[0] != ec_type_named) ||
626	(ec_params.data[1] != 0) ||
627	!supportedCurve(ec_params.data[2])) {
628	    errCode = SEC_ERROR_UNSUPPORTED_ELLIPTIC_CURVE;
629	    desc = handshake_failure;
630	    goto alert_loser;
631    }
632
633    rv = ssl3_ConsumeHandshakeVariable(ss, &ec_point, 1, &b, &length);
634    if (rv != SECSuccess) {
635	goto loser;		/* malformed. */
636    }
637    /* Fail if the ec point uses compressed representation */
638    if (ec_point.data[0] != EC_POINT_FORM_UNCOMPRESSED) {
639	    errCode = SEC_ERROR_UNSUPPORTED_EC_POINT_FORM;
640	    desc = handshake_failure;
641	    goto alert_loser;
642    }
643
644    if (isTLS12) {
645	rv = ssl3_ConsumeSignatureAndHashAlgorithm(ss, &b, &length,
646						   &sigAndHash);
647	if (rv != SECSuccess) {
648	    goto loser;		/* malformed or unsupported. */
649	}
650	rv = ssl3_CheckSignatureAndHashAlgorithmConsistency(
651		&sigAndHash, ss->sec.peerCert);
652	if (rv != SECSuccess) {
653	    goto loser;
654	}
655    }
656
657    rv = ssl3_ConsumeHandshakeVariable(ss, &signature, 2, &b, &length);
658    if (rv != SECSuccess) {
659	goto loser;		/* malformed. */
660    }
661
662    if (length != 0) {
663	if (isTLS)
664	    desc = decode_error;
665	goto alert_loser;		/* malformed. */
666    }
667
668    PRINT_BUF(60, (NULL, "Server EC params", ec_params.data,
669	ec_params.len));
670    PRINT_BUF(60, (NULL, "Server EC point", ec_point.data, ec_point.len));
671
672    /* failures after this point are not malformed handshakes. */
673    /* TLS: send decrypt_error if signature failed. */
674    desc = isTLS ? decrypt_error : handshake_failure;
675
676    /*
677     *  check to make sure the hash is signed by right guy
678     */
679    rv = ssl3_ComputeECDHKeyHash(sigAndHash.hashAlg, ec_params, ec_point,
680				 &ss->ssl3.hs.client_random,
681				 &ss->ssl3.hs.server_random,
682				 &hashes, ss->opt.bypassPKCS11);
683
684    if (rv != SECSuccess) {
685	errCode =
686	    ssl_MapLowLevelError(SSL_ERROR_SERVER_KEY_EXCHANGE_FAILURE);
687	goto alert_loser;
688    }
689    rv = ssl3_VerifySignedHashes(&hashes, ss->sec.peerCert, &signature,
690				isTLS, ss->pkcs11PinArg);
691    if (rv != SECSuccess)  {
692	errCode =
693	    ssl_MapLowLevelError(SSL_ERROR_SERVER_KEY_EXCHANGE_FAILURE);
694	goto alert_loser;
695    }
696
697    arena = PORT_NewArena(DER_DEFAULT_CHUNKSIZE);
698    if (arena == NULL) {
699	goto no_memory;
700    }
701
702    ss->sec.peerKey = peerKey = PORT_ArenaZNew(arena, SECKEYPublicKey);
703    if (peerKey == NULL) {
704	goto no_memory;
705    }
706
707    peerKey->arena                 = arena;
708    peerKey->keyType               = ecKey;
709
710    /* set up EC parameters in peerKey */
711    if (ssl3_ECName2Params(arena, ec_params.data[2],
712	    &peerKey->u.ec.DEREncodedParams) != SECSuccess) {
713	/* we should never get here since we already
714	 * checked that we are dealing with a supported curve
715	 */
716	errCode = SEC_ERROR_UNSUPPORTED_ELLIPTIC_CURVE;
717	goto alert_loser;
718    }
719
720    /* copy publicValue in peerKey */
721    if (SECITEM_CopyItem(arena, &peerKey->u.ec.publicValue,  &ec_point))
722    {
723	PORT_FreeArena(arena, PR_FALSE);
724	goto no_memory;
725    }
726    peerKey->pkcs11Slot         = NULL;
727    peerKey->pkcs11ID           = CK_INVALID_HANDLE;
728
729    ss->sec.peerKey = peerKey;
730    ss->ssl3.hs.ws = wait_cert_request;
731
732    return SECSuccess;
733
734alert_loser:
735    (void)SSL3_SendAlert(ss, alert_fatal, desc);
736loser:
737    PORT_SetError( errCode );
738    return SECFailure;
739
740no_memory:	/* no-memory error has already been set. */
741    ssl_MapLowLevelError(SSL_ERROR_SERVER_KEY_EXCHANGE_FAILURE);
742    return SECFailure;
743}
744
745SECStatus
746ssl3_SendECDHServerKeyExchange(
747    sslSocket *ss,
748    const SSL3SignatureAndHashAlgorithm *sigAndHash)
749{
750    const ssl3KEADef * kea_def     = ss->ssl3.hs.kea_def;
751    SECStatus          rv          = SECFailure;
752    int                length;
753    PRBool             isTLS, isTLS12;
754    SECItem            signed_hash = {siBuffer, NULL, 0};
755    SSL3Hashes         hashes;
756
757    SECKEYPublicKey *  ecdhePub;
758    SECItem            ec_params = {siBuffer, NULL, 0};
759    unsigned char      paramBuf[3];
760    ECName             curve;
761    SSL3KEAType        certIndex;
762
763    /* Generate ephemeral ECDH key pair and send the public key */
764    curve = ssl3_GetCurveNameForServerSocket(ss);
765    if (curve == ec_noName) {
766    	goto loser;
767    }
768    rv = ssl3_CreateECDHEphemeralKeys(ss, curve);
769    if (rv != SECSuccess) {
770	goto loser; 	/* err set by AppendHandshake. */
771    }
772    ecdhePub = ss->ephemeralECDHKeyPair->pubKey;
773    PORT_Assert(ecdhePub != NULL);
774    if (!ecdhePub) {
775	PORT_SetError(SSL_ERROR_SERVER_KEY_EXCHANGE_FAILURE);
776	return SECFailure;
777    }
778
779    ec_params.len  = sizeof paramBuf;
780    ec_params.data = paramBuf;
781    curve = params2ecName(&ecdhePub->u.ec.DEREncodedParams);
782    if (curve != ec_noName) {
783	ec_params.data[0] = ec_type_named;
784	ec_params.data[1] = 0x00;
785	ec_params.data[2] = curve;
786    } else {
787	PORT_SetError(SEC_ERROR_UNSUPPORTED_ELLIPTIC_CURVE);
788	goto loser;
789    }
790
791    rv = ssl3_ComputeECDHKeyHash(sigAndHash->hashAlg,
792				 ec_params,
793				 ecdhePub->u.ec.publicValue,
794				 &ss->ssl3.hs.client_random,
795				 &ss->ssl3.hs.server_random,
796				 &hashes, ss->opt.bypassPKCS11);
797    if (rv != SECSuccess) {
798	ssl_MapLowLevelError(SSL_ERROR_SERVER_KEY_EXCHANGE_FAILURE);
799	goto loser;
800    }
801
802    isTLS = (PRBool)(ss->ssl3.pwSpec->version > SSL_LIBRARY_VERSION_3_0);
803    isTLS12 = (PRBool)(ss->ssl3.pwSpec->version >= SSL_LIBRARY_VERSION_TLS_1_2);
804
805    /* XXX SSLKEAType isn't really a good choice for
806     * indexing certificates but that's all we have
807     * for now.
808     */
809    if (kea_def->kea == kea_ecdhe_rsa)
810	certIndex = kt_rsa;
811    else /* kea_def->kea == kea_ecdhe_ecdsa */
812	certIndex = kt_ecdh;
813
814    rv = ssl3_SignHashes(&hashes, ss->serverCerts[certIndex].SERVERKEY,
815			 &signed_hash, isTLS);
816    if (rv != SECSuccess) {
817	goto loser;		/* ssl3_SignHashes has set err. */
818    }
819    if (signed_hash.data == NULL) {
820	/* how can this happen and rv == SECSuccess ?? */
821	PORT_SetError(SSL_ERROR_SERVER_KEY_EXCHANGE_FAILURE);
822	goto loser;
823    }
824
825    length = ec_params.len +
826	     1 + ecdhePub->u.ec.publicValue.len +
827	     (isTLS12 ? 2 : 0) + 2 + signed_hash.len;
828
829    rv = ssl3_AppendHandshakeHeader(ss, server_key_exchange, length);
830    if (rv != SECSuccess) {
831	goto loser; 	/* err set by AppendHandshake. */
832    }
833
834    rv = ssl3_AppendHandshake(ss, ec_params.data, ec_params.len);
835    if (rv != SECSuccess) {
836	goto loser; 	/* err set by AppendHandshake. */
837    }
838
839    rv = ssl3_AppendHandshakeVariable(ss, ecdhePub->u.ec.publicValue.data,
840				      ecdhePub->u.ec.publicValue.len, 1);
841    if (rv != SECSuccess) {
842	goto loser; 	/* err set by AppendHandshake. */
843    }
844
845    if (isTLS12) {
846	rv = ssl3_AppendSignatureAndHashAlgorithm(ss, sigAndHash);
847	if (rv != SECSuccess) {
848	    goto loser; 	/* err set by AppendHandshake. */
849	}
850    }
851
852    rv = ssl3_AppendHandshakeVariable(ss, signed_hash.data,
853				      signed_hash.len, 2);
854    if (rv != SECSuccess) {
855	goto loser; 	/* err set by AppendHandshake. */
856    }
857
858    PORT_Free(signed_hash.data);
859    return SECSuccess;
860
861loser:
862    if (signed_hash.data != NULL)
863    	PORT_Free(signed_hash.data);
864    return SECFailure;
865}
866
/* Lists of ECC cipher suites for searching and disabling.  Each list ends
 * with a 0 entry, which ssl3_DisableECCSuites uses as its stop condition. */

/* All ECDH_ECDSA and ECDH_RSA suites (the union of ecdh_ecdsa_suites and
 * ecdh_rsa_suites below). */
static const ssl3CipherSuite ecdh_suites[] = {
    TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA,
    TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA,
    TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA,
    TLS_ECDH_ECDSA_WITH_NULL_SHA,
    TLS_ECDH_ECDSA_WITH_RC4_128_SHA,
    TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA,
    TLS_ECDH_RSA_WITH_AES_128_CBC_SHA,
    TLS_ECDH_RSA_WITH_AES_256_CBC_SHA,
    TLS_ECDH_RSA_WITH_NULL_SHA,
    TLS_ECDH_RSA_WITH_RC4_128_SHA,
    0 /* end of list marker */
};

/* ECDH_ECDSA suites only. */
static const ssl3CipherSuite ecdh_ecdsa_suites[] = {
    TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA,
    TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA,
    TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA,
    TLS_ECDH_ECDSA_WITH_NULL_SHA,
    TLS_ECDH_ECDSA_WITH_RC4_128_SHA,
    0 /* end of list marker */
};

/* ECDH_RSA suites only. */
static const ssl3CipherSuite ecdh_rsa_suites[] = {
    TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA,
    TLS_ECDH_RSA_WITH_AES_128_CBC_SHA,
    TLS_ECDH_RSA_WITH_AES_256_CBC_SHA,
    TLS_ECDH_RSA_WITH_NULL_SHA,
    TLS_ECDH_RSA_WITH_RC4_128_SHA,
    0 /* end of list marker */
};

/* ECDHE_ECDSA suites (disabled when no EC server cert is configured). */
static const ssl3CipherSuite ecdhe_ecdsa_suites[] = {
    TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA,
    TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
    TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
    TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
    TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
    TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
    TLS_ECDHE_ECDSA_WITH_NULL_SHA,
    TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
    0 /* end of list marker */
};

/* ECDHE_RSA suites (disabled when no RSA server cert is configured). */
static const ssl3CipherSuite ecdhe_rsa_suites[] = {
    TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
    TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
    TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
    TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
    TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
    TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
    TLS_ECDHE_RSA_WITH_NULL_SHA,
    TLS_ECDHE_RSA_WITH_RC4_128_SHA,
    0 /* end of list marker */
};
924
/* List of all ECC cipher suites: the ECDHE_ECDSA, ECDHE_RSA, ECDH_ECDSA
 * and ECDH_RSA lists above, combined.  Used by ssl3_DisableECCSuites when
 * it is passed a NULL list.  Keep in sync with the per-KEA lists. */
static const ssl3CipherSuite ecSuites[] = {
    TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA,
    TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
    TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
    TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
    TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
    TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
    TLS_ECDHE_ECDSA_WITH_NULL_SHA,
    TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
    TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
    TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
    TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
    TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
    TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
    TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
    TLS_ECDHE_RSA_WITH_NULL_SHA,
    TLS_ECDHE_RSA_WITH_RC4_128_SHA,
    TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA,
    TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA,
    TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA,
    TLS_ECDH_ECDSA_WITH_NULL_SHA,
    TLS_ECDH_ECDSA_WITH_RC4_128_SHA,
    TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA,
    TLS_ECDH_RSA_WITH_AES_128_CBC_SHA,
    TLS_ECDH_RSA_WITH_AES_256_CBC_SHA,
    TLS_ECDH_RSA_WITH_NULL_SHA,
    TLS_ECDH_RSA_WITH_RC4_128_SHA,
    0 /* end of list marker */
};
955
956/* On this socket, Disable the ECC cipher suites in the argument's list */
957SECStatus
958ssl3_DisableECCSuites(sslSocket * ss, const ssl3CipherSuite * suite)
959{
960    if (!suite)
961    	suite = ecSuites;
962    for (; *suite; ++suite) {
963	SECStatus rv      = ssl3_CipherPrefSet(ss, *suite, PR_FALSE);
964
965	PORT_Assert(rv == SECSuccess); /* else is coding error */
966    }
967    return SECSuccess;
968}
969
970/* Look at the server certs configured on this socket, and disable any
971 * ECC cipher suites that are not supported by those certs.
972 */
973void
974ssl3_FilterECCipherSuitesByServerCerts(sslSocket * ss)
975{
976    CERTCertificate * svrCert;
977
978    svrCert = ss->serverCerts[kt_rsa].serverCert;
979    if (!svrCert) {
980	ssl3_DisableECCSuites(ss, ecdhe_rsa_suites);
981    }
982
983    svrCert = ss->serverCerts[kt_ecdh].serverCert;
984    if (!svrCert) {
985	ssl3_DisableECCSuites(ss, ecdh_suites);
986	ssl3_DisableECCSuites(ss, ecdhe_ecdsa_suites);
987    } else {
988	SECOidTag sigTag = SECOID_GetAlgorithmTag(&svrCert->signature);
989
990	switch (sigTag) {
991	case SEC_OID_PKCS1_RSA_ENCRYPTION:
992	case SEC_OID_PKCS1_MD2_WITH_RSA_ENCRYPTION:
993	case SEC_OID_PKCS1_MD4_WITH_RSA_ENCRYPTION:
994	case SEC_OID_PKCS1_MD5_WITH_RSA_ENCRYPTION:
995	case SEC_OID_PKCS1_SHA1_WITH_RSA_ENCRYPTION:
996	case SEC_OID_PKCS1_SHA224_WITH_RSA_ENCRYPTION:
997	case SEC_OID_PKCS1_SHA256_WITH_RSA_ENCRYPTION:
998	case SEC_OID_PKCS1_SHA384_WITH_RSA_ENCRYPTION:
999	case SEC_OID_PKCS1_SHA512_WITH_RSA_ENCRYPTION:
1000	    ssl3_DisableECCSuites(ss, ecdh_ecdsa_suites);
1001	    break;
1002	case SEC_OID_ANSIX962_ECDSA_SHA1_SIGNATURE:
1003	case SEC_OID_ANSIX962_ECDSA_SHA224_SIGNATURE:
1004	case SEC_OID_ANSIX962_ECDSA_SHA256_SIGNATURE:
1005	case SEC_OID_ANSIX962_ECDSA_SHA384_SIGNATURE:
1006	case SEC_OID_ANSIX962_ECDSA_SHA512_SIGNATURE:
1007	case SEC_OID_ANSIX962_ECDSA_SIGNATURE_RECOMMENDED_DIGEST:
1008	case SEC_OID_ANSIX962_ECDSA_SIGNATURE_SPECIFIED_DIGEST:
1009	    ssl3_DisableECCSuites(ss, ecdh_rsa_suites);
1010	    break;
1011	default:
1012	    ssl3_DisableECCSuites(ss, ecdh_suites);
1013	    break;
1014	}
1015    }
1016}
1017
1018/* Ask: is ANY ECC cipher suite enabled on this socket? */
1019/* Order(N^2).  Yuk.  Also, this ignores export policy. */
1020PRBool
1021ssl3_IsECCEnabled(sslSocket * ss)
1022{
1023    const ssl3CipherSuite * suite;
1024    PK11SlotInfo *slot;
1025
1026    /* make sure we can do ECC */
1027    slot = PK11_GetBestSlot(CKM_ECDH1_DERIVE,  ss->pkcs11PinArg);
1028    if (!slot) {
1029	return PR_FALSE;
1030    }
1031    PK11_FreeSlot(slot);
1032
1033    /* make sure an ECC cipher is enabled */
1034    for (suite = ecSuites; *suite; ++suite) {
1035	PRBool    enabled = PR_FALSE;
1036	SECStatus rv      = ssl3_CipherPrefGet(ss, *suite, &enabled);
1037
1038	PORT_Assert(rv == SECSuccess); /* else is coding error */
1039	if (rv == SECSuccess && enabled)
1040	    return PR_TRUE;
1041    }
1042    return PR_FALSE;
1043}
1044
/* BE(n) expands to the two bytes "0, n": the big-endian encoding of a
 * 16-bit value whose high byte is zero, so it is only correct for
 * 0 <= n <= 255.  Used to spell out the prefabricated extensions below.
 */
#define BE(n) 0, n

/* Prefabricated TLS client hello extension, Elliptic Curves List,
 * offering only 3 curves: the Suite B curves, 23-25 (P-256, P-384 and
 * P-521).
 * Layout: extension type (2 bytes), extension length (2), curve list
 * length (2), then one 2-byte curve id per curve.
 */
static const PRUint8 suiteBECList[12] = {
    BE(10),         /* Extension type */
    BE( 8),         /* octets that follow ( 3 pairs + 1 length pair) */
    BE( 6),         /* octets that follow ( 3 pairs) */
    BE(23), BE(24), BE(25)
};
1056
/* Prefabricated TLS client hello extension, Elliptic Curves List,
 * offering curves 1-25.  It has the same layout as suiteBECList and is sent
 * instead of it only when ssl3_SuiteBOnly() returns PR_FALSE. That never
 * happens while ssl3_SuiteBOnly's slot probe is compiled out.
 */
static const PRUint8 tlsECList[56] = {
    BE(10),         /* Extension type */
    BE(52),         /* octets that follow (25 pairs + 1 length pair) */
    BE(50),         /* octets that follow (25 pairs) */
            BE( 1), BE( 2), BE( 3), BE( 4), BE( 5), BE( 6), BE( 7),
    BE( 8), BE( 9), BE(10), BE(11), BE(12), BE(13), BE(14), BE(15),
    BE(16), BE(17), BE(18), BE(19), BE(20), BE(21), BE(22), BE(23),
    BE(24), BE(25)
};
1069
/* Prefabricated Supported Point Formats extension:
 * extension type 11, extension length 2, list length 1, and the single
 * point format 0 (uncompressed), which is the only format supported here.
 */
static const PRUint8 ecPtFmt[6] = {
    BE(11),         /* Extension type */
    BE( 2),         /* octets that follow */
             1,     /* length of the format list that follows */
                 0  /* uncompressed type only */
};
1076
1077/* This function already presumes we can do ECC, ssl3_IsECCEnabled must be
1078 * called before this function. It looks to see if we have a token which
1079 * is capable of doing smaller than SuiteB curves. If the token can, we
1080 * presume the token can do the whole SSL suite of curves. If it can't we
1081 * presume the token that allowed ECC to be enabled can only do suite B
1082 * curves. */
1083static PRBool
1084ssl3_SuiteBOnly(sslSocket *ss)
1085{
1086#if 0
1087    /* See if we can support small curves (like 163). If not, assume we can
1088     * only support Suite-B curves (P-256, P-384, P-521). */
1089    PK11SlotInfo *slot =
1090	PK11_GetBestSlotWithAttributes(CKM_ECDH1_DERIVE, 0, 163,
1091				       ss ? ss->pkcs11PinArg : NULL);
1092
1093    if (!slot) {
1094	/* nope, presume we can only do suite B */
1095	return PR_TRUE;
1096    }
1097    /* we can, presume we can do all curves */
1098    PK11_FreeSlot(slot);
1099    return PR_FALSE;
1100#else
1101    return PR_TRUE;
1102#endif
1103}
1104
1105/* Send our "canned" (precompiled) Supported Elliptic Curves extension,
1106 * which says that we support all TLS-defined named curves.
1107 */
1108PRInt32
1109ssl3_SendSupportedCurvesXtn(
1110			sslSocket * ss,
1111			PRBool      append,
1112			PRUint32    maxBytes)
1113{
1114    PRInt32 ecListSize = 0;
1115    const PRUint8 *ecList = NULL;
1116
1117    if (!ss || !ssl3_IsECCEnabled(ss))
1118    	return 0;
1119
1120    if (ssl3_SuiteBOnly(ss)) {
1121	ecListSize = sizeof suiteBECList;
1122	ecList = suiteBECList;
1123    } else {
1124	ecListSize = sizeof tlsECList;
1125	ecList = tlsECList;
1126    }
1127
1128    if (append && maxBytes >= ecListSize) {
1129	SECStatus rv = ssl3_AppendHandshake(ss, ecList, ecListSize);
1130	if (rv != SECSuccess)
1131	    return -1;
1132	if (!ss->sec.isServer) {
1133	    TLSExtensionData *xtnData = &ss->xtnData;
1134	    xtnData->advertised[xtnData->numAdvertised++] =
1135		ssl_elliptic_curves_xtn;
1136	}
1137    }
1138    return ecListSize;
1139}
1140
1141PRUint32
1142ssl3_GetSupportedECCurveMask(sslSocket *ss)
1143{
1144    if (ssl3_SuiteBOnly(ss)) {
1145	return SSL3_SUITE_B_SUPPORTED_CURVES_MASK;
1146    }
1147    return SSL3_ALL_SUPPORTED_CURVES_MASK;
1148}
1149
1150/* Send our "canned" (precompiled) Supported Point Formats extension,
1151 * which says that we only support uncompressed points.
1152 */
1153PRInt32
1154ssl3_SendSupportedPointFormatsXtn(
1155			sslSocket * ss,
1156			PRBool      append,
1157			PRUint32    maxBytes)
1158{
1159    if (!ss || !ssl3_IsECCEnabled(ss))
1160    	return 0;
1161    if (append && maxBytes >= (sizeof ecPtFmt)) {
1162	SECStatus rv = ssl3_AppendHandshake(ss, ecPtFmt, (sizeof ecPtFmt));
1163	if (rv != SECSuccess)
1164	    return -1;
1165	if (!ss->sec.isServer) {
1166	    TLSExtensionData *xtnData = &ss->xtnData;
1167	    xtnData->advertised[xtnData->numAdvertised++] =
1168		ssl_ec_point_formats_xtn;
1169	}
1170    }
1171    return (sizeof ecPtFmt);
1172}
1173
1174/* Just make sure that the remote client supports uncompressed points,
1175 * Since that is all we support.  Disable ECC cipher suites if it doesn't.
1176 */
1177SECStatus
1178ssl3_HandleSupportedPointFormatsXtn(sslSocket *ss, PRUint16 ex_type,
1179                                    SECItem *data)
1180{
1181    int i;
1182
1183    if (data->len < 2 || data->len > 255 || !data->data ||
1184        data->len != (unsigned int)data->data[0] + 1) {
1185    	/* malformed */
1186	goto loser;
1187    }
1188    for (i = data->len; --i > 0; ) {
1189    	if (data->data[i] == 0) {
1190	    /* indicate that we should send a reply */
1191	    SECStatus rv;
1192	    rv = ssl3_RegisterServerHelloExtensionSender(ss, ex_type,
1193			      &ssl3_SendSupportedPointFormatsXtn);
1194	    return rv;
1195	}
1196    }
1197loser:
1198    /* evil client doesn't support uncompressed */
1199    ssl3_DisableECCSuites(ss, ecSuites);
1200    return SECFailure;
1201}
1202
1203
/* Public key of the server key pair configured for key type |type| on
 * socket pointer |sock|, or NULL if none is configured.  |sock| is
 * evaluated twice.
 */
#define SSL3_GET_SERVER_PUBLICKEY(sock, type) \
    ((sock)->serverCerts[(type)].serverKeyPair ? \
    (sock)->serverCerts[(type)].serverKeyPair->pubKey : NULL)
1207
1208/* Extract the TLS curve name for the public key in our EC server cert. */
1209ECName ssl3_GetSvrCertCurveName(sslSocket *ss)
1210{
1211    SECKEYPublicKey       *srvPublicKey;
1212    ECName		  ec_curve       = ec_noName;
1213
1214    srvPublicKey = SSL3_GET_SERVER_PUBLICKEY(ss, kt_ecdh);
1215    if (srvPublicKey) {
1216	ec_curve = params2ecName(&srvPublicKey->u.ec.DEREncodedParams);
1217    }
1218    return ec_curve;
1219}
1220
1221/* Ensure that the curve in our server cert is one of the ones suppored
1222 * by the remote client, and disable all ECC cipher suites if not.
1223 */
1224SECStatus
1225ssl3_HandleSupportedCurvesXtn(sslSocket *ss, PRUint16 ex_type, SECItem *data)
1226{
1227    PRInt32  list_len;
1228    PRUint32 peerCurves   = 0;
1229    PRUint32 mutualCurves = 0;
1230    PRUint16 svrCertCurveName;
1231
1232    if (!data->data || data->len < 4 || data->len > 65535)
1233    	goto loser;
1234    /* get the length of elliptic_curve_list */
1235    list_len = ssl3_ConsumeHandshakeNumber(ss, 2, &data->data, &data->len);
1236    if (list_len < 0 || data->len != list_len || (data->len % 2) != 0) {
1237    	/* malformed */
1238	goto loser;
1239    }
1240    /* build bit vector of peer's supported curve names */
1241    while (data->len) {
1242	PRInt32  curve_name =
1243		 ssl3_ConsumeHandshakeNumber(ss, 2, &data->data, &data->len);
1244	if (curve_name > ec_noName && curve_name < ec_pastLastName) {
1245	    peerCurves |= (1U << curve_name);
1246	}
1247    }
1248    /* What curves do we support in common? */
1249    mutualCurves = ss->ssl3.hs.negotiatedECCurves &= peerCurves;
1250    if (!mutualCurves) { /* no mutually supported EC Curves */
1251    	goto loser;
1252    }
1253
1254    /* if our ECC cert doesn't use one of these supported curves,
1255     * disable ECC cipher suites that require an ECC cert.
1256     */
1257    svrCertCurveName = ssl3_GetSvrCertCurveName(ss);
1258    if (svrCertCurveName != ec_noName &&
1259        (mutualCurves & (1U << svrCertCurveName)) != 0) {
1260	return SECSuccess;
1261    }
1262    /* Our EC cert doesn't contain a mutually supported curve.
1263     * Disable all ECC cipher suites that require an EC cert
1264     */
1265    ssl3_DisableECCSuites(ss, ecdh_ecdsa_suites);
1266    ssl3_DisableECCSuites(ss, ecdhe_ecdsa_suites);
1267    return SECFailure;
1268
1269loser:
1270    /* no common curve supported */
1271    ssl3_DisableECCSuites(ss, ecSuites);
1272    return SECFailure;
1273}
1274
1275#endif /* NSS_ENABLE_ECC */
1276