1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4
5/*
6 * DTLS Protocol
7 */
8
9#include "ssl.h"
10#include "sslimpl.h"
11#include "sslproto.h"
12
/* Element count of a true array, for NSPR versions that do not already
 * provide PR_ARRAY_SIZE. Never apply it to a pointer: the result would be
 * sizeof(pointer) / sizeof(element), not the number of elements. */
#if !defined(PR_ARRAY_SIZE)
#define PR_ARRAY_SIZE(arr) (sizeof (arr) / sizeof ((arr)[0]))
#endif
16
17static SECStatus dtls_TransmitMessageFlight(sslSocket *ss);
18static void dtls_RetransmitTimerExpiredCb(sslSocket *ss);
19static SECStatus dtls_SendSavedWriteData(sslSocket *ss);
20
21/* -28 adjusts for the IP/UDP header */
22static const PRUint16 COMMON_MTU_VALUES[] = {
23    1500 - 28,  /* Ethernet MTU */
24    1280 - 28,  /* IPv6 minimum MTU */
25    576 - 28,   /* Common assumption */
26    256 - 28    /* We're in serious trouble now */
27};
28
29#define DTLS_COOKIE_BYTES 32
30
31/* List copied from ssl3con.c:cipherSuites */
32static const ssl3CipherSuite nonDTLSSuites[] = {
33#ifdef NSS_ENABLE_ECC
34    TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
35    TLS_ECDHE_RSA_WITH_RC4_128_SHA,
36#endif  /* NSS_ENABLE_ECC */
37    TLS_DHE_DSS_WITH_RC4_128_SHA,
38#ifdef NSS_ENABLE_ECC
39    TLS_ECDH_RSA_WITH_RC4_128_SHA,
40    TLS_ECDH_ECDSA_WITH_RC4_128_SHA,
41#endif  /* NSS_ENABLE_ECC */
42    SSL_RSA_WITH_RC4_128_MD5,
43    SSL_RSA_WITH_RC4_128_SHA,
44    TLS_RSA_EXPORT1024_WITH_RC4_56_SHA,
45    SSL_RSA_EXPORT_WITH_RC4_40_MD5,
46    0 /* End of list marker */
47};
48
49/* Map back and forth between TLS and DTLS versions in wire format.
50 * Mapping table is:
51 *
52 * TLS             DTLS
53 * 1.1 (0302)      1.0 (feff)
54 */
55SSL3ProtocolVersion
56dtls_TLSVersionToDTLSVersion(SSL3ProtocolVersion tlsv)
57{
58    /* Anything other than TLS 1.1 is an error, so return
59     * the invalid version ffff. */
60    if (tlsv != SSL_LIBRARY_VERSION_TLS_1_1)
61	return 0xffff;
62
63    return SSL_LIBRARY_VERSION_DTLS_1_0_WIRE;
64}
65
66/* Map known DTLS versions to known TLS versions.
67 * - Invalid versions (< 1.0) return a version of 0
68 * - Versions > known return a version one higher than we know of
69 * to accomodate a theoretically newer version */
70SSL3ProtocolVersion
71dtls_DTLSVersionToTLSVersion(SSL3ProtocolVersion dtlsv)
72{
73    if (MSB(dtlsv) == 0xff) {
74	return 0;
75    }
76
77    if (dtlsv == SSL_LIBRARY_VERSION_DTLS_1_0_WIRE)
78	return SSL_LIBRARY_VERSION_TLS_1_1;
79
80    /* Return a fictional higher version than we know of */
81    return SSL_LIBRARY_VERSION_TLS_1_1 + 1;
82}
83
84/* On this socket, Disable non-DTLS cipher suites in the argument's list */
85SECStatus
86ssl3_DisableNonDTLSSuites(sslSocket * ss)
87{
88    const ssl3CipherSuite * suite;
89
90    for (suite = nonDTLSSuites; *suite; ++suite) {
91	SECStatus rv = ssl3_CipherPrefSet(ss, *suite, PR_FALSE);
92
93	PORT_Assert(rv == SECSuccess); /* else is coding error */
94    }
95    return SECSuccess;
96}
97
98/* Allocate a DTLSQueuedMessage.
99 *
100 * Called from dtls_QueueMessage()
101 */
102static DTLSQueuedMessage *
103dtls_AllocQueuedMessage(PRUint16 epoch, SSL3ContentType type,
104			const unsigned char *data, PRUint32 len)
105{
106    DTLSQueuedMessage *msg = NULL;
107
108    msg = PORT_ZAlloc(sizeof(DTLSQueuedMessage));
109    if (!msg)
110	return NULL;
111
112    msg->data = PORT_Alloc(len);
113    if (!msg->data) {
114	PORT_Free(msg);
115        return NULL;
116    }
117    PORT_Memcpy(msg->data, data, len);
118
119    msg->len = len;
120    msg->epoch = epoch;
121    msg->type = type;
122
123    return msg;
124}
125
126/*
127 * Free a handshake message
128 *
129 * Called from dtls_FreeHandshakeMessages()
130 */
131static void
132dtls_FreeHandshakeMessage(DTLSQueuedMessage *msg)
133{
134    if (!msg)
135	return;
136
137    PORT_ZFree(msg->data, msg->len);
138    PORT_Free(msg);
139}
140
141/*
142 * Free a list of handshake messages
143 *
144 * Called from:
145 *              dtls_HandleHandshake()
146 *              ssl3_DestroySSL3Info()
147 */
148void
149dtls_FreeHandshakeMessages(PRCList *list)
150{
151    PRCList *cur_p;
152
153    while (!PR_CLIST_IS_EMPTY(list)) {
154	cur_p = PR_LIST_TAIL(list);
155	PR_REMOVE_LINK(cur_p);
156	dtls_FreeHandshakeMessage((DTLSQueuedMessage *)cur_p);
157    }
158}
159
160/* Called only from ssl3_HandleRecord, for each (deciphered) DTLS record.
161 * origBuf is the decrypted ssl record content and is expected to contain
162 * complete handshake records
163 * Caller must hold the handshake and RecvBuf locks.
164 *
165 * Note that this code uses msg_len for two purposes:
166 *
167 * (1) To pass the length to ssl3_HandleHandshakeMessage()
168 * (2) To carry the length of a message currently being reassembled
169 *
170 * However, unlike ssl3_HandleHandshake(), it is not used to carry
171 * the state of reassembly (i.e., whether one is in progress). That
172 * is carried in recvdHighWater and recvdFragments.
173 */
174#define OFFSET_BYTE(o) (o/8)
175#define OFFSET_MASK(o) (1 << (o%8))
176
177SECStatus
178dtls_HandleHandshake(sslSocket *ss, sslBuffer *origBuf)
179{
180    /* XXX OK for now.
181     * This doesn't work properly with asynchronous certificate validation.
182     * because that returns a WOULDBLOCK error. The current DTLS
183     * applications do not need asynchronous validation, but in the
184     * future we will need to add this.
185     */
186    sslBuffer buf = *origBuf;
187    SECStatus rv = SECSuccess;
188
189    PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss));
190    PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
191
192    while (buf.len > 0) {
193        PRUint8 type;
194        PRUint32 message_length;
195        PRUint16 message_seq;
196        PRUint32 fragment_offset;
197        PRUint32 fragment_length;
198        PRUint32 offset;
199
200        if (buf.len < 12) {
201            PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE);
202            rv = SECFailure;
203            break;
204        }
205
206        /* Parse the header */
207	type = buf.buf[0];
208        message_length = (buf.buf[1] << 16) | (buf.buf[2] << 8) | buf.buf[3];
209        message_seq = (buf.buf[4] << 8) | buf.buf[5];
210        fragment_offset = (buf.buf[6] << 16) | (buf.buf[7] << 8) | buf.buf[8];
211        fragment_length = (buf.buf[9] << 16) | (buf.buf[10] << 8) | buf.buf[11];
212
213#define MAX_HANDSHAKE_MSG_LEN 0x1ffff	/* 128k - 1 */
214	if (message_length > MAX_HANDSHAKE_MSG_LEN) {
215	    (void)ssl3_DecodeError(ss);
216	    PORT_SetError(SSL_ERROR_RX_RECORD_TOO_LONG);
217	    return SECFailure;
218	}
219#undef MAX_HANDSHAKE_MSG_LEN
220
221        buf.buf += 12;
222        buf.len -= 12;
223
224        /* This fragment must be complete */
225        if (buf.len < fragment_length) {
226            PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE);
227            rv = SECFailure;
228            break;
229        }
230
231        /* Sanity check the packet contents */
232	if ((fragment_length + fragment_offset) > message_length) {
233            PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE);
234            rv = SECFailure;
235            break;
236        }
237
238        /* There are three ways we could not be ready for this packet.
239         *
240         * 1. It's a partial next message.
241         * 2. It's a partial or complete message beyond the next
242         * 3. It's a message we've already seen
243         *
244         * If it's the complete next message we accept it right away.
245         * This is the common case for short messages
246         */
247        if ((message_seq == ss->ssl3.hs.recvMessageSeq)
248	    && (fragment_offset == 0)
249	    && (fragment_length == message_length)) {
250            /* Complete next message. Process immediately */
251            ss->ssl3.hs.msg_type = (SSL3HandshakeType)type;
252            ss->ssl3.hs.msg_len = message_length;
253
254            /* At this point we are advancing our state machine, so
255             * we can free our last flight of messages */
256            dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight);
257	    ss->ssl3.hs.recvdHighWater = -1;
258	    dtls_CancelTimer(ss);
259
260	    /* Reset the timer to the initial value if the retry counter
261	     * is 0, per Sec. 4.2.4.1 */
262	    if (ss->ssl3.hs.rtRetries == 0) {
263		ss->ssl3.hs.rtTimeoutMs = INITIAL_DTLS_TIMEOUT_MS;
264	    }
265
266            rv = ssl3_HandleHandshakeMessage(ss, buf.buf, ss->ssl3.hs.msg_len);
267            if (rv == SECFailure) {
268                /* Do not attempt to process rest of messages in this record */
269                break;
270            }
271        } else {
272	    if (message_seq < ss->ssl3.hs.recvMessageSeq) {
273		/* Case 3: we do an immediate retransmit if we're
274		 * in a waiting state*/
275		if (ss->ssl3.hs.rtTimerCb == NULL) {
276		    /* Ignore */
277		} else if (ss->ssl3.hs.rtTimerCb ==
278			 dtls_RetransmitTimerExpiredCb) {
279		    SSL_TRC(30, ("%d: SSL3[%d]: Retransmit detected",
280				 SSL_GETPID(), ss->fd));
281		    /* Check to see if we retransmitted recently. If so,
282		     * suppress the triggered retransmit. This avoids
283		     * retransmit wars after packet loss.
284		     * This is not in RFC 5346 but should be
285		     */
286		    if ((PR_IntervalNow() - ss->ssl3.hs.rtTimerStarted) >
287			(ss->ssl3.hs.rtTimeoutMs / 4)) {
288			    SSL_TRC(30,
289			    ("%d: SSL3[%d]: Shortcutting retransmit timer",
290                            SSL_GETPID(), ss->fd));
291
292			    /* Cancel the timer and call the CB,
293			     * which re-arms the timer */
294			    dtls_CancelTimer(ss);
295			    dtls_RetransmitTimerExpiredCb(ss);
296			    rv = SECSuccess;
297			    break;
298			} else {
299			    SSL_TRC(30,
300			    ("%d: SSL3[%d]: We just retransmitted. Ignoring.",
301                            SSL_GETPID(), ss->fd));
302			    rv = SECSuccess;
303			    break;
304			}
305		} else if (ss->ssl3.hs.rtTimerCb == dtls_FinishedTimerCb) {
306		    /* Retransmit the messages and re-arm the timer
307		     * Note that we are not backing off the timer here.
308		     * The spec isn't clear and my reasoning is that this
309		     * may be a re-ordered packet rather than slowness,
310		     * so let's be aggressive. */
311		    dtls_CancelTimer(ss);
312		    rv = dtls_TransmitMessageFlight(ss);
313		    if (rv == SECSuccess) {
314			rv = dtls_StartTimer(ss, dtls_FinishedTimerCb);
315		    }
316		    if (rv != SECSuccess)
317			return rv;
318		    break;
319		}
320	    } else if (message_seq > ss->ssl3.hs.recvMessageSeq) {
321		/* Case 2
322                 *
323		 * Ignore this message. This means we don't handle out of
324		 * order complete messages that well, but we're still
325		 * compliant and this probably does not happen often
326                 *
327		 * XXX OK for now. Maybe do something smarter at some point?
328		 */
329	    } else {
330		/* Case 1
331                 *
332		 * Buffer the fragment for reassembly
333		 */
334                /* Make room for the message */
335                if (ss->ssl3.hs.recvdHighWater == -1) {
336                    PRUint32 map_length = OFFSET_BYTE(message_length) + 1;
337
338                    rv = sslBuffer_Grow(&ss->ssl3.hs.msg_body, message_length);
339                    if (rv != SECSuccess)
340                        break;
341                    /* Make room for the fragment map */
342                    rv = sslBuffer_Grow(&ss->ssl3.hs.recvdFragments,
343                                        map_length);
344                    if (rv != SECSuccess)
345                        break;
346
347                    /* Reset the reassembly map */
348                    ss->ssl3.hs.recvdHighWater = 0;
349                    PORT_Memset(ss->ssl3.hs.recvdFragments.buf, 0,
350				ss->ssl3.hs.recvdFragments.space);
351		    ss->ssl3.hs.msg_type = (SSL3HandshakeType)type;
352                    ss->ssl3.hs.msg_len = message_length;
353                }
354
355                /* If we have a message length mismatch, abandon the reassembly
356                 * in progress and hope that the next retransmit will give us
357                 * something sane
358                 */
359                if (message_length != ss->ssl3.hs.msg_len) {
360                    ss->ssl3.hs.recvdHighWater = -1;
361                    PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE);
362                    rv = SECFailure;
363                    break;
364                }
365
366                /* Now copy this fragment into the buffer */
367                PORT_Assert((fragment_offset + fragment_length) <=
368                            ss->ssl3.hs.msg_body.space);
369                PORT_Memcpy(ss->ssl3.hs.msg_body.buf + fragment_offset,
370                            buf.buf, fragment_length);
371
372                /* This logic is a bit tricky. We have two values for
373                 * reassembly state:
374                 *
375                 * - recvdHighWater contains the highest contiguous number of
376                 *   bytes received
377                 * - recvdFragments contains a bitmask of packets received
378                 *   above recvdHighWater
379                 *
380                 * This avoids having to fill in the bitmask in the common
381                 * case of adjacent fragments received in sequence
382                 */
383                if (fragment_offset <= ss->ssl3.hs.recvdHighWater) {
384		    /* Either this is the adjacent fragment or an overlapping
385                     * fragment */
386                    ss->ssl3.hs.recvdHighWater = fragment_offset +
387                                                 fragment_length;
388                } else {
389                    for (offset = fragment_offset;
390                         offset < fragment_offset + fragment_length;
391                         offset++) {
392                        ss->ssl3.hs.recvdFragments.buf[OFFSET_BYTE(offset)] |=
393                            OFFSET_MASK(offset);
394                    }
395                }
396
397                /* Now figure out the new high water mark if appropriate */
398                for (offset = ss->ssl3.hs.recvdHighWater;
399                     offset < ss->ssl3.hs.msg_len; offset++) {
400		    /* Note that this loop is not efficient, since it counts
401		     * bit by bit. If we have a lot of out-of-order packets,
402		     * we should optimize this */
403                    if (ss->ssl3.hs.recvdFragments.buf[OFFSET_BYTE(offset)] &
404                        OFFSET_MASK(offset)) {
405                        ss->ssl3.hs.recvdHighWater++;
406                    } else {
407                        break;
408                    }
409                }
410
411                /* If we have all the bytes, then we are good to go */
412                if (ss->ssl3.hs.recvdHighWater == ss->ssl3.hs.msg_len) {
413                    ss->ssl3.hs.recvdHighWater = -1;
414
415                    rv = ssl3_HandleHandshakeMessage(ss,
416                                                     ss->ssl3.hs.msg_body.buf,
417                                                     ss->ssl3.hs.msg_len);
418                    if (rv == SECFailure)
419                        break; /* Skip rest of record */
420
421		    /* At this point we are advancing our state machine, so
422		     * we can free our last flight of messages */
423		    dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight);
424		    dtls_CancelTimer(ss);
425
426		    /* If there have been no retries this time, reset the
427		     * timer value to the default per Section 4.2.4.1 */
428		    if (ss->ssl3.hs.rtRetries == 0) {
429			ss->ssl3.hs.rtTimeoutMs = INITIAL_DTLS_TIMEOUT_MS;
430		    }
431                }
432            }
433        }
434
435	buf.buf += fragment_length;
436        buf.len -= fragment_length;
437    }
438
439    origBuf->len = 0;	/* So ssl3_GatherAppDataRecord will keep looping. */
440
441    /* XXX OK for now. In future handle rv == SECWouldBlock safely in order
442     * to deal with asynchronous certificate verification */
443    return rv;
444}
445
446/* Enqueue a message (either handshake or CCS)
447 *
448 * Called from:
449 *              dtls_StageHandshakeMessage()
450 *              ssl3_SendChangeCipherSpecs()
451 */
452SECStatus dtls_QueueMessage(sslSocket *ss, SSL3ContentType type,
453    const SSL3Opaque *pIn, PRInt32 nIn)
454{
455    SECStatus rv = SECSuccess;
456    DTLSQueuedMessage *msg = NULL;
457
458    PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
459    PORT_Assert(ss->opt.noLocks || ssl_HaveXmitBufLock(ss));
460
461    msg = dtls_AllocQueuedMessage(ss->ssl3.cwSpec->epoch, type, pIn, nIn);
462
463    if (!msg) {
464	PORT_SetError(SEC_ERROR_NO_MEMORY);
465	rv = SECFailure;
466    } else {
467	PR_APPEND_LINK(&msg->link, &ss->ssl3.hs.lastMessageFlight);
468    }
469
470    return rv;
471}
472
473/* Add DTLS handshake message to the pending queue
474 * Empty the sendBuf buffer.
475 * This function returns SECSuccess or SECFailure, never SECWouldBlock.
476 * Always set sendBuf.len to 0, even when returning SECFailure.
477 *
478 * Called from:
479 *              ssl3_AppendHandshakeHeader()
480 *              dtls_FlushHandshake()
481 */
482SECStatus
483dtls_StageHandshakeMessage(sslSocket *ss)
484{
485    SECStatus rv = SECSuccess;
486
487    PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
488    PORT_Assert(ss->opt.noLocks || ssl_HaveXmitBufLock(ss));
489
490    /* This function is sometimes called when no data is actually to
491     * be staged, so just return SECSuccess. */
492    if (!ss->sec.ci.sendBuf.buf || !ss->sec.ci.sendBuf.len)
493	return rv;
494
495    rv = dtls_QueueMessage(ss, content_handshake,
496                           ss->sec.ci.sendBuf.buf, ss->sec.ci.sendBuf.len);
497
498    /* Whether we succeeded or failed, toss the old handshake data. */
499    ss->sec.ci.sendBuf.len = 0;
500    return rv;
501}
502
503/* Enqueue the handshake message in sendBuf (if any) and then
504 * transmit the resulting flight of handshake messages.
505 *
506 * Called from:
507 *              ssl3_FlushHandshake()
508 */
509SECStatus
510dtls_FlushHandshakeMessages(sslSocket *ss, PRInt32 flags)
511{
512    SECStatus rv = SECSuccess;
513
514    PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
515    PORT_Assert(ss->opt.noLocks || ssl_HaveXmitBufLock(ss));
516
517    rv = dtls_StageHandshakeMessage(ss);
518    if (rv != SECSuccess)
519        return rv;
520
521    if (!(flags & ssl_SEND_FLAG_FORCE_INTO_BUFFER)) {
522        rv = dtls_TransmitMessageFlight(ss);
523        if (rv != SECSuccess)
524            return rv;
525
526	if (!(flags & ssl_SEND_FLAG_NO_RETRANSMIT)) {
527	    ss->ssl3.hs.rtRetries = 0;
528	    rv = dtls_StartTimer(ss, dtls_RetransmitTimerExpiredCb);
529	}
530    }
531
532    return rv;
533}
534
535/* The callback for when the retransmit timer expires
536 *
537 * Called from:
538 *              dtls_CheckTimer()
539 *              dtls_HandleHandshake()
540 */
541static void
542dtls_RetransmitTimerExpiredCb(sslSocket *ss)
543{
544    SECStatus rv = SECFailure;
545
546    ss->ssl3.hs.rtRetries++;
547
548    if (!(ss->ssl3.hs.rtRetries % 3)) {
549	/* If one of the messages was potentially greater than > MTU,
550	 * then downgrade. Do this every time we have retransmitted a
551	 * message twice, per RFC 6347 Sec. 4.1.1 */
552	dtls_SetMTU(ss, ss->ssl3.hs.maxMessageSent - 1);
553    }
554
555    rv = dtls_TransmitMessageFlight(ss);
556    if (rv == SECSuccess) {
557
558	/* Re-arm the timer */
559	rv = dtls_RestartTimer(ss, PR_TRUE, dtls_RetransmitTimerExpiredCb);
560    }
561
562    if (rv == SECFailure) {
563	/* XXX OK for now. In future maybe signal the stack that we couldn't
564	 * transmit. For now, let the read handle any real network errors */
565    }
566}
567
568/* Transmit a flight of handshake messages, stuffing them
569 * into as few records as seems reasonable
570 *
571 * Called from:
572 *             dtls_FlushHandshake()
573 *             dtls_RetransmitTimerExpiredCb()
574 */
575static SECStatus
576dtls_TransmitMessageFlight(sslSocket *ss)
577{
578    SECStatus rv = SECSuccess;
579    PRCList *msg_p;
580    PRUint16 room_left = ss->ssl3.mtu;
581    PRInt32 sent;
582
583    ssl_GetXmitBufLock(ss);
584    ssl_GetSpecReadLock(ss);
585
586    /* DTLS does not buffer its handshake messages in
587     * ss->pendingBuf, but rather in the lastMessageFlight
588     * structure. This is just a sanity check that
589     * some programming error hasn't inadvertantly
590     * stuffed something in ss->pendingBuf
591     */
592    PORT_Assert(!ss->pendingBuf.len);
593    for (msg_p = PR_LIST_HEAD(&ss->ssl3.hs.lastMessageFlight);
594	 msg_p != &ss->ssl3.hs.lastMessageFlight;
595	 msg_p = PR_NEXT_LINK(msg_p)) {
596        DTLSQueuedMessage *msg = (DTLSQueuedMessage *)msg_p;
597
598        /* The logic here is:
599         *
600	 * 1. If this is a message that will not fit into the remaining
601	 *    space, then flush.
602	 * 2. If the message will now fit into the remaining space,
603         *    encrypt, buffer, and loop.
604         * 3. If the message will not fit, then fragment.
605         *
606	 * At the end of the function, flush.
607         */
608        if ((msg->len + SSL3_BUFFER_FUDGE) > room_left) {
609	    /* The message will not fit into the remaining space, so flush */
610	    rv = dtls_SendSavedWriteData(ss);
611	    if (rv != SECSuccess)
612		break;
613
614            room_left = ss->ssl3.mtu;
615	}
616
617        if ((msg->len + SSL3_BUFFER_FUDGE) <= room_left) {
618            /* The message will fit, so encrypt and then continue with the
619	     * next packet */
620            sent = ssl3_SendRecord(ss, msg->epoch, msg->type,
621				   msg->data, msg->len,
622				   ssl_SEND_FLAG_FORCE_INTO_BUFFER |
623				   ssl_SEND_FLAG_USE_EPOCH);
624            if (sent != msg->len) {
625		rv = SECFailure;
626		if (sent != -1) {
627		    PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
628		}
629                break;
630	    }
631
632            room_left = ss->ssl3.mtu - ss->pendingBuf.len;
633        } else {
634            /* The message will not fit, so fragment.
635             *
636	     * XXX OK for now. Arrange to coalesce the last fragment
637	     * of this message with the next message if possible.
638	     * That would be more efficient.
639	     */
640            PRUint32 fragment_offset = 0;
641            unsigned char fragment[DTLS_MAX_MTU]; /* >= than largest
642                                                   * plausible MTU */
643
644	    /* Assert that we have already flushed */
645	    PORT_Assert(room_left == ss->ssl3.mtu);
646
647            /* Case 3: We now need to fragment this message
648             * DTLS only supports fragmenting handshaking messages */
649            PORT_Assert(msg->type == content_handshake);
650
651	    /* The headers consume 12 bytes so the smalles possible
652	     *  message (i.e., an empty one) is 12 bytes
653	     */
654	    PORT_Assert(msg->len >= 12);
655
656            while ((fragment_offset + 12) < msg->len) {
657                PRUint32 fragment_len;
658                const unsigned char *content = msg->data + 12;
659                PRUint32 content_len = msg->len - 12;
660
661		/* The reason we use 8 here is that that's the length of
662		 * the new DTLS data that we add to the header */
663                fragment_len = PR_MIN(room_left - (SSL3_BUFFER_FUDGE + 8),
664                                      content_len - fragment_offset);
665		PORT_Assert(fragment_len < DTLS_MAX_MTU - 12);
666		/* Make totally sure that we are within the buffer.
667		 * Note that the only way that fragment len could get
668		 * adjusted here is if
669                 *
670		 * (a) we are in release mode so the PORT_Assert is compiled out
671		 * (b) either the MTU table is inconsistent with DTLS_MAX_MTU
672		 * or ss->ssl3.mtu has become corrupt.
673		 */
674		fragment_len = PR_MIN(fragment_len, DTLS_MAX_MTU - 12);
675
676                /* Construct an appropriate-sized fragment */
677                /* Type, length, sequence */
678                PORT_Memcpy(fragment, msg->data, 6);
679
680                /* Offset */
681                fragment[6] = (fragment_offset >> 16) & 0xff;
682                fragment[7] = (fragment_offset >> 8) & 0xff;
683                fragment[8] = (fragment_offset) & 0xff;
684
685                /* Fragment length */
686                fragment[9] = (fragment_len >> 16) & 0xff;
687                fragment[10] = (fragment_len >> 8) & 0xff;
688                fragment[11] = (fragment_len) & 0xff;
689
690                PORT_Memcpy(fragment + 12, content + fragment_offset,
691                            fragment_len);
692
693                /*
694		 *  Send the record. We do this in two stages
695		 * 1. Encrypt
696		 */
697                sent = ssl3_SendRecord(ss, msg->epoch, msg->type,
698                                       fragment, fragment_len + 12,
699                                       ssl_SEND_FLAG_FORCE_INTO_BUFFER |
700				       ssl_SEND_FLAG_USE_EPOCH);
701                if (sent != (fragment_len + 12)) {
702		    rv = SECFailure;
703		    if (sent != -1) {
704			PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
705		    }
706		    break;
707		}
708
709		/* 2. Flush */
710		rv = dtls_SendSavedWriteData(ss);
711		if (rv != SECSuccess)
712		    break;
713
714                fragment_offset += fragment_len;
715            }
716        }
717    }
718
719    /* Finally, we need to flush */
720    if (rv == SECSuccess)
721	rv = dtls_SendSavedWriteData(ss);
722
723    /* Give up the locks */
724    ssl_ReleaseSpecReadLock(ss);
725    ssl_ReleaseXmitBufLock(ss);
726
727    return rv;
728}
729
730/* Flush the data in the pendingBuf and update the max message sent
731 * so we can adjust the MTU estimate if we need to.
732 * Wrapper for ssl_SendSavedWriteData.
733 *
734 * Called from dtls_TransmitMessageFlight()
735 */
736static
737SECStatus dtls_SendSavedWriteData(sslSocket *ss)
738{
739    PRInt32 sent;
740
741    sent = ssl_SendSavedWriteData(ss);
742    if (sent < 0)
743	return SECFailure;
744
745    /* We should always have complete writes b/c datagram sockets
746     * don't really block */
747    if (ss->pendingBuf.len > 0) {
748	ssl_MapLowLevelError(SSL_ERROR_SOCKET_WRITE_FAILURE);
749    	return SECFailure;
750    }
751
752    /* Update the largest message sent so we can adjust the MTU
753     * estimate if necessary */
754    if (sent > ss->ssl3.hs.maxMessageSent)
755	ss->ssl3.hs.maxMessageSent = sent;
756
757    return SECSuccess;
758}
759
760/* Compress, MAC, encrypt a DTLS record. Allows specification of
761 * the epoch using epoch value. If use_epoch is PR_TRUE then
762 * we use the provided epoch. If use_epoch is PR_FALSE then
763 * whatever the current value is in effect is used.
764 *
765 * Called from ssl3_SendRecord()
766 */
767SECStatus
768dtls_CompressMACEncryptRecord(sslSocket *        ss,
769                              DTLSEpoch          epoch,
770			      PRBool             use_epoch,
771                              SSL3ContentType    type,
772		              const SSL3Opaque * pIn,
773		              PRUint32           contentLen,
774			      sslBuffer        * wrBuf)
775{
776    SECStatus rv = SECFailure;
777    ssl3CipherSpec *          cwSpec;
778
779    ssl_GetSpecReadLock(ss);	/********************************/
780
781    /* The reason for this switch-hitting code is that we might have
782     * a flight of records spanning an epoch boundary, e.g.,
783     *
784     * ClientKeyExchange (epoch = 0)
785     * ChangeCipherSpec (epoch = 0)
786     * Finished (epoch = 1)
787     *
788     * Thus, each record needs a different cipher spec. The information
789     * about which epoch to use is carried with the record.
790     */
791    if (use_epoch) {
792	if (ss->ssl3.cwSpec->epoch == epoch)
793	    cwSpec = ss->ssl3.cwSpec;
794	else if (ss->ssl3.pwSpec->epoch == epoch)
795	    cwSpec = ss->ssl3.pwSpec;
796	else
797	    cwSpec = NULL;
798    } else {
799	cwSpec = ss->ssl3.cwSpec;
800    }
801
802    if (cwSpec) {
803        rv = ssl3_CompressMACEncryptRecord(cwSpec, ss->sec.isServer, PR_TRUE,
804					   PR_FALSE, type, pIn, contentLen,
805					   wrBuf);
806    } else {
807        PR_NOT_REACHED("Couldn't find a cipher spec matching epoch");
808	PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
809    }
810    ssl_ReleaseSpecReadLock(ss); /************************************/
811
812    return rv;
813}
814
815/* Start a timer
816 *
817 * Called from:
818 *             dtls_HandleHandshake()
819 *             dtls_FlushHAndshake()
820 *             dtls_RestartTimer()
821 */
822SECStatus
823dtls_StartTimer(sslSocket *ss, DTLSTimerCb cb)
824{
825    PORT_Assert(ss->ssl3.hs.rtTimerCb == NULL);
826
827    ss->ssl3.hs.rtTimerStarted = PR_IntervalNow();
828    ss->ssl3.hs.rtTimerCb = cb;
829
830    return SECSuccess;
831}
832
833/* Restart a timer with optional backoff
834 *
835 * Called from dtls_RetransmitTimerExpiredCb()
836 */
837SECStatus
838dtls_RestartTimer(sslSocket *ss, PRBool backoff, DTLSTimerCb cb)
839{
840    if (backoff) {
841	ss->ssl3.hs.rtTimeoutMs *= 2;
842	if (ss->ssl3.hs.rtTimeoutMs > MAX_DTLS_TIMEOUT_MS)
843	    ss->ssl3.hs.rtTimeoutMs = MAX_DTLS_TIMEOUT_MS;
844    }
845
846    return dtls_StartTimer(ss, cb);
847}
848
849/* Cancel a pending timer
850 *
851 * Called from:
852 *              dtls_HandleHandshake()
853 *              dtls_CheckTimer()
854 */
855void
856dtls_CancelTimer(sslSocket *ss)
857{
858    PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss));
859
860    ss->ssl3.hs.rtTimerCb = NULL;
861}
862
863/* Check the pending timer and fire the callback if it expired
864 *
865 * Called from ssl3_GatherCompleteHandshake()
866 */
867void
868dtls_CheckTimer(sslSocket *ss)
869{
870    if (!ss->ssl3.hs.rtTimerCb)
871	return;
872
873    if ((PR_IntervalNow() - ss->ssl3.hs.rtTimerStarted) >
874	PR_MillisecondsToInterval(ss->ssl3.hs.rtTimeoutMs)) {
875	/* Timer has expired */
876	DTLSTimerCb cb = ss->ssl3.hs.rtTimerCb;
877
878	/* Cancel the timer so that we can call the CB safely */
879	dtls_CancelTimer(ss);
880
881	/* Now call the CB */
882	cb(ss);
883    }
884}
885
886/* The callback to fire when the holddown timer for the Finished
887 * message expires and we can delete it
888 *
889 * Called from dtls_CheckTimer()
890 */
891void
892dtls_FinishedTimerCb(sslSocket *ss)
893{
894    ssl3_DestroyCipherSpec(ss->ssl3.pwSpec, PR_FALSE);
895}
896
897/* Cancel the Finished hold-down timer and destroy the
898 * pending cipher spec. Note that this means that
899 * successive rehandshakes will fail if the Finished is
900 * lost.
901 *
902 * XXX OK for now. Figure out how to handle the combination
903 * of Finished lost and rehandshake
904 */
905void
906dtls_RehandshakeCleanup(sslSocket *ss)
907{
908    dtls_CancelTimer(ss);
909    ssl3_DestroyCipherSpec(ss->ssl3.pwSpec, PR_FALSE);
910    ss->ssl3.hs.sendMessageSeq = 0;
911    ss->ssl3.hs.recvMessageSeq = 0;
912}
913
914/* Set the MTU to the next step less than or equal to the
915 * advertised value. Also used to downgrade the MTU by
916 * doing dtls_SetMTU(ss, biggest packet set).
917 *
918 * Passing 0 means set this to the largest MTU known
919 * (effectively resetting the PMTU backoff value).
920 *
921 * Called by:
922 *            ssl3_InitState()
923 *            dtls_RetransmitTimerExpiredCb()
924 */
925void
926dtls_SetMTU(sslSocket *ss, PRUint16 advertised)
927{
928    int i;
929
930    if (advertised == 0) {
931	ss->ssl3.mtu = COMMON_MTU_VALUES[0];
932	SSL_TRC(30, ("Resetting MTU to %d", ss->ssl3.mtu));
933	return;
934    }
935
936    for (i = 0; i < PR_ARRAY_SIZE(COMMON_MTU_VALUES); i++) {
937	if (COMMON_MTU_VALUES[i] <= advertised) {
938	    ss->ssl3.mtu = COMMON_MTU_VALUES[i];
939	    SSL_TRC(30, ("Resetting MTU to %d", ss->ssl3.mtu));
940	    return;
941	}
942    }
943
944    /* Fallback */
945    ss->ssl3.mtu = COMMON_MTU_VALUES[PR_ARRAY_SIZE(COMMON_MTU_VALUES)-1];
946    SSL_TRC(30, ("Resetting MTU to %d", ss->ssl3.mtu));
947}
948
949/* Called from ssl3_HandleHandshakeMessage() when it has deciphered a
950 * DTLS hello_verify_request
951 * Caller must hold Handshake and RecvBuf locks.
952 */
953SECStatus
954dtls_HandleHelloVerifyRequest(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
955{
956    int                 errCode	= SSL_ERROR_RX_MALFORMED_HELLO_VERIFY_REQUEST;
957    SECStatus           rv;
958    PRInt32             temp;
959    SECItem             cookie = {siBuffer, NULL, 0};
960    SSL3AlertDescription desc   = illegal_parameter;
961
962    SSL_TRC(3, ("%d: SSL3[%d]: handle hello_verify_request handshake",
963    	SSL_GETPID(), ss->fd));
964    PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss));
965    PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
966
967    if (ss->ssl3.hs.ws != wait_server_hello) {
968        errCode = SSL_ERROR_RX_UNEXPECTED_HELLO_VERIFY_REQUEST;
969	desc    = unexpected_message;
970	goto alert_loser;
971    }
972
973    /* The version */
974    temp = ssl3_ConsumeHandshakeNumber(ss, 2, &b, &length);
975    if (temp < 0) {
976    	goto loser; 	/* alert has been sent */
977    }
978
979    if (temp != SSL_LIBRARY_VERSION_DTLS_1_0_WIRE) {
980	/* Note: this will need adjustment for DTLS 1.2 per Section 4.2.1 */
981	goto alert_loser;
982    }
983
984    /* The cookie */
985    rv = ssl3_ConsumeHandshakeVariable(ss, &cookie, 1, &b, &length);
986    if (rv != SECSuccess) {
987    	goto loser; 	/* alert has been sent */
988    }
989    if (cookie.len > DTLS_COOKIE_BYTES) {
990	desc = decode_error;
991	goto alert_loser;	/* malformed. */
992    }
993
994    PORT_Memcpy(ss->ssl3.hs.cookie, cookie.data, cookie.len);
995    ss->ssl3.hs.cookieLen = cookie.len;
996
997
998    ssl_GetXmitBufLock(ss);		/*******************************/
999
1000    /* Now re-send the client hello */
1001    rv = ssl3_SendClientHello(ss, PR_TRUE);
1002
1003    ssl_ReleaseXmitBufLock(ss);		/*******************************/
1004
1005    if (rv == SECSuccess)
1006	return rv;
1007
1008alert_loser:
1009    (void)SSL3_SendAlert(ss, alert_fatal, desc);
1010
1011loser:
1012    errCode = ssl_MapLowLevelError(errCode);
1013    return SECFailure;
1014}
1015
1016/* Initialize the DTLS anti-replay window
1017 *
1018 * Called from:
1019 *              ssl3_SetupPendingCipherSpec()
1020 *              ssl3_InitCipherSpec()
1021 */
1022void
1023dtls_InitRecvdRecords(DTLSRecvdRecords *records)
1024{
1025    PORT_Memset(records->data, 0, sizeof(records->data));
1026    records->left = 0;
1027    records->right = DTLS_RECVD_RECORDS_WINDOW - 1;
1028}
1029
1030/*
1031 * Has this DTLS record been received? Return values are:
1032 * -1 -- out of range to the left
1033 *  0 -- not received yet
1034 *  1 -- replay
1035 *
1036 *  Called from: dtls_HandleRecord()
1037 */
1038int
1039dtls_RecordGetRecvd(DTLSRecvdRecords *records, PRUint64 seq)
1040{
1041    PRUint64 offset;
1042
1043    /* Out of range to the left */
1044    if (seq < records->left) {
1045	return -1;
1046    }
1047
1048    /* Out of range to the right; since we advance the window on
1049     * receipt, that means that this packet has not been received
1050     * yet */
1051    if (seq > records->right)
1052	return 0;
1053
1054    offset = seq % DTLS_RECVD_RECORDS_WINDOW;
1055
1056    return !!(records->data[offset / 8] & (1 << (offset % 8)));
1057}
1058
1059/* Update the DTLS anti-replay window
1060 *
1061 * Called from ssl3_HandleRecord()
1062 */
1063void
1064dtls_RecordSetRecvd(DTLSRecvdRecords *records, PRUint64 seq)
1065{
1066    PRUint64 offset;
1067
1068    if (seq < records->left)
1069	return;
1070
1071    if (seq > records->right) {
1072	PRUint64 new_left;
1073	PRUint64 new_right;
1074	PRUint64 right;
1075
1076	/* Slide to the right; this is the tricky part
1077         *
1078	 * 1. new_top is set to have room for seq, on the
1079	 *    next byte boundary by setting the right 8
1080	 *    bits of seq
1081         * 2. new_left is set to compensate.
1082         * 3. Zero all bits between top and new_top. Since
1083         *    this is a ring, this zeroes everything as-yet
1084	 *    unseen. Because we always operate on byte
1085	 *    boundaries, we can zero one byte at a time
1086	 */
1087	new_right = seq | 0x07;
1088	new_left = (new_right - DTLS_RECVD_RECORDS_WINDOW) + 1;
1089
1090	for (right = records->right + 8; right <= new_right; right += 8) {
1091	    offset = right % DTLS_RECVD_RECORDS_WINDOW;
1092	    records->data[offset / 8] = 0;
1093	}
1094
1095	records->right = new_right;
1096	records->left = new_left;
1097    }
1098
1099    offset = seq % DTLS_RECVD_RECORDS_WINDOW;
1100
1101    records->data[offset / 8] |= (1 << (offset % 8));
1102}
1103
1104SECStatus
1105DTLS_GetHandshakeTimeout(PRFileDesc *socket, PRIntervalTime *timeout)
1106{
1107    sslSocket * ss = NULL;
1108    PRIntervalTime elapsed;
1109    PRIntervalTime desired;
1110
1111    ss = ssl_FindSocket(socket);
1112
1113    if (!ss)
1114        return SECFailure;
1115
1116    if (!IS_DTLS(ss))
1117        return SECFailure;
1118
1119    if (!ss->ssl3.hs.rtTimerCb)
1120        return SECFailure;
1121
1122    elapsed = PR_IntervalNow() - ss->ssl3.hs.rtTimerStarted;
1123    desired = PR_MillisecondsToInterval(ss->ssl3.hs.rtTimeoutMs);
1124    if (elapsed > desired) {
1125        /* Timer expired */
1126        *timeout = PR_INTERVAL_NO_WAIT;
1127    } else {
1128        *timeout = desired - elapsed;
1129    }
1130
1131    return SECSuccess;
1132}
1133