michael@0: /* This Source Code Form is subject to the terms of the Mozilla Public michael@0: * License, v. 2.0. If a copy of the MPL was not distributed with this michael@0: * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ michael@0: michael@0: /* michael@0: * DTLS Protocol michael@0: */ michael@0: michael@0: #include "ssl.h" michael@0: #include "sslimpl.h" michael@0: #include "sslproto.h" michael@0: michael@0: #ifndef PR_ARRAY_SIZE michael@0: #define PR_ARRAY_SIZE(a) (sizeof(a)/sizeof((a)[0])) michael@0: #endif michael@0: michael@0: static SECStatus dtls_TransmitMessageFlight(sslSocket *ss); michael@0: static void dtls_RetransmitTimerExpiredCb(sslSocket *ss); michael@0: static SECStatus dtls_SendSavedWriteData(sslSocket *ss); michael@0: michael@0: /* -28 adjusts for the IP/UDP header */ michael@0: static const PRUint16 COMMON_MTU_VALUES[] = { michael@0: 1500 - 28, /* Ethernet MTU */ michael@0: 1280 - 28, /* IPv6 minimum MTU */ michael@0: 576 - 28, /* Common assumption */ michael@0: 256 - 28 /* We're in serious trouble now */ michael@0: }; michael@0: michael@0: #define DTLS_COOKIE_BYTES 32 michael@0: michael@0: /* List copied from ssl3con.c:cipherSuites */ michael@0: static const ssl3CipherSuite nonDTLSSuites[] = { michael@0: #ifndef NSS_DISABLE_ECC michael@0: TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, michael@0: TLS_ECDHE_RSA_WITH_RC4_128_SHA, michael@0: #endif /* NSS_DISABLE_ECC */ michael@0: TLS_DHE_DSS_WITH_RC4_128_SHA, michael@0: #ifndef NSS_DISABLE_ECC michael@0: TLS_ECDH_RSA_WITH_RC4_128_SHA, michael@0: TLS_ECDH_ECDSA_WITH_RC4_128_SHA, michael@0: #endif /* NSS_DISABLE_ECC */ michael@0: TLS_RSA_WITH_RC4_128_MD5, michael@0: TLS_RSA_WITH_RC4_128_SHA, michael@0: TLS_RSA_EXPORT1024_WITH_RC4_56_SHA, michael@0: TLS_RSA_EXPORT_WITH_RC4_40_MD5, michael@0: 0 /* End of list marker */ michael@0: }; michael@0: michael@0: /* Map back and forth between TLS and DTLS versions in wire format. michael@0: * Mapping table is: michael@0: * michael@0: * TLS DTLS michael@0: * 1.1 (0302) 1.0 (feff) michael@0: * 1.2 (0303) 1.2 (fefd) michael@0: */ michael@0: SSL3ProtocolVersion michael@0: dtls_TLSVersionToDTLSVersion(SSL3ProtocolVersion tlsv) michael@0: { michael@0: if (tlsv == SSL_LIBRARY_VERSION_TLS_1_1) { michael@0: return SSL_LIBRARY_VERSION_DTLS_1_0_WIRE; michael@0: } michael@0: if (tlsv == SSL_LIBRARY_VERSION_TLS_1_2) { michael@0: return SSL_LIBRARY_VERSION_DTLS_1_2_WIRE; michael@0: } michael@0: michael@0: /* Anything other than TLS 1.1 or 1.2 is an error, so return michael@0: * the invalid version 0xffff. */ michael@0: return 0xffff; michael@0: } michael@0: michael@0: /* Map known DTLS versions to known TLS versions. michael@0: * - Invalid versions (< 1.0) return a version of 0 michael@0: * - Versions > known return a version one higher than we know of michael@0: * to accomodate a theoretically newer version */ michael@0: SSL3ProtocolVersion michael@0: dtls_DTLSVersionToTLSVersion(SSL3ProtocolVersion dtlsv) michael@0: { michael@0: if (MSB(dtlsv) == 0xff) { michael@0: return 0; michael@0: } michael@0: michael@0: if (dtlsv == SSL_LIBRARY_VERSION_DTLS_1_0_WIRE) { michael@0: return SSL_LIBRARY_VERSION_TLS_1_1; michael@0: } michael@0: if (dtlsv == SSL_LIBRARY_VERSION_DTLS_1_2_WIRE) { michael@0: return SSL_LIBRARY_VERSION_TLS_1_2; michael@0: } michael@0: michael@0: /* Return a fictional higher version than we know of */ michael@0: return SSL_LIBRARY_VERSION_TLS_1_2 + 1; michael@0: } michael@0: michael@0: /* On this socket, Disable non-DTLS cipher suites in the argument's list */ michael@0: SECStatus michael@0: ssl3_DisableNonDTLSSuites(sslSocket * ss) michael@0: { michael@0: const ssl3CipherSuite * suite; michael@0: michael@0: for (suite = nonDTLSSuites; *suite; ++suite) { michael@0: SECStatus rv = ssl3_CipherPrefSet(ss, *suite, PR_FALSE); michael@0: michael@0: PORT_Assert(rv == SECSuccess); /* else is coding error */ michael@0: } michael@0: return SECSuccess; michael@0: } michael@0: michael@0: /* Allocate a DTLSQueuedMessage. michael@0: * michael@0: * Called from dtls_QueueMessage() michael@0: */ michael@0: static DTLSQueuedMessage * michael@0: dtls_AllocQueuedMessage(PRUint16 epoch, SSL3ContentType type, michael@0: const unsigned char *data, PRUint32 len) michael@0: { michael@0: DTLSQueuedMessage *msg = NULL; michael@0: michael@0: msg = PORT_ZAlloc(sizeof(DTLSQueuedMessage)); michael@0: if (!msg) michael@0: return NULL; michael@0: michael@0: msg->data = PORT_Alloc(len); michael@0: if (!msg->data) { michael@0: PORT_Free(msg); michael@0: return NULL; michael@0: } michael@0: PORT_Memcpy(msg->data, data, len); michael@0: michael@0: msg->len = len; michael@0: msg->epoch = epoch; michael@0: msg->type = type; michael@0: michael@0: return msg; michael@0: } michael@0: michael@0: /* michael@0: * Free a handshake message michael@0: * michael@0: * Called from dtls_FreeHandshakeMessages() michael@0: */ michael@0: static void michael@0: dtls_FreeHandshakeMessage(DTLSQueuedMessage *msg) michael@0: { michael@0: if (!msg) michael@0: return; michael@0: michael@0: PORT_ZFree(msg->data, msg->len); michael@0: PORT_Free(msg); michael@0: } michael@0: michael@0: /* michael@0: * Free a list of handshake messages michael@0: * michael@0: * Called from: michael@0: * dtls_HandleHandshake() michael@0: * ssl3_DestroySSL3Info() michael@0: */ michael@0: void michael@0: dtls_FreeHandshakeMessages(PRCList *list) michael@0: { michael@0: PRCList *cur_p; michael@0: michael@0: while (!PR_CLIST_IS_EMPTY(list)) { michael@0: cur_p = PR_LIST_TAIL(list); michael@0: PR_REMOVE_LINK(cur_p); michael@0: dtls_FreeHandshakeMessage((DTLSQueuedMessage *)cur_p); michael@0: } michael@0: } michael@0: michael@0: /* Called only from ssl3_HandleRecord, for each (deciphered) DTLS record. michael@0: * origBuf is the decrypted ssl record content and is expected to contain michael@0: * complete handshake records michael@0: * Caller must hold the handshake and RecvBuf locks. michael@0: * michael@0: * Note that this code uses msg_len for two purposes: michael@0: * michael@0: * (1) To pass the length to ssl3_HandleHandshakeMessage() michael@0: * (2) To carry the length of a message currently being reassembled michael@0: * michael@0: * However, unlike ssl3_HandleHandshake(), it is not used to carry michael@0: * the state of reassembly (i.e., whether one is in progress). That michael@0: * is carried in recvdHighWater and recvdFragments. michael@0: */ michael@0: #define OFFSET_BYTE(o) (o/8) michael@0: #define OFFSET_MASK(o) (1 << (o%8)) michael@0: michael@0: SECStatus michael@0: dtls_HandleHandshake(sslSocket *ss, sslBuffer *origBuf) michael@0: { michael@0: /* XXX OK for now. michael@0: * This doesn't work properly with asynchronous certificate validation. michael@0: * because that returns a WOULDBLOCK error. The current DTLS michael@0: * applications do not need asynchronous validation, but in the michael@0: * future we will need to add this. michael@0: */ michael@0: sslBuffer buf = *origBuf; michael@0: SECStatus rv = SECSuccess; michael@0: michael@0: PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss)); michael@0: PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss)); michael@0: michael@0: while (buf.len > 0) { michael@0: PRUint8 type; michael@0: PRUint32 message_length; michael@0: PRUint16 message_seq; michael@0: PRUint32 fragment_offset; michael@0: PRUint32 fragment_length; michael@0: PRUint32 offset; michael@0: michael@0: if (buf.len < 12) { michael@0: PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE); michael@0: rv = SECFailure; michael@0: break; michael@0: } michael@0: michael@0: /* Parse the header */ michael@0: type = buf.buf[0]; michael@0: message_length = (buf.buf[1] << 16) | (buf.buf[2] << 8) | buf.buf[3]; michael@0: message_seq = (buf.buf[4] << 8) | buf.buf[5]; michael@0: fragment_offset = (buf.buf[6] << 16) | (buf.buf[7] << 8) | buf.buf[8]; michael@0: fragment_length = (buf.buf[9] << 16) | (buf.buf[10] << 8) | buf.buf[11]; michael@0: michael@0: #define MAX_HANDSHAKE_MSG_LEN 0x1ffff /* 128k - 1 */ michael@0: if (message_length > MAX_HANDSHAKE_MSG_LEN) { michael@0: (void)ssl3_DecodeError(ss); michael@0: PORT_SetError(SSL_ERROR_RX_RECORD_TOO_LONG); michael@0: return SECFailure; michael@0: } michael@0: #undef MAX_HANDSHAKE_MSG_LEN michael@0: michael@0: buf.buf += 12; michael@0: buf.len -= 12; michael@0: michael@0: /* This fragment must be complete */ michael@0: if (buf.len < fragment_length) { michael@0: PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE); michael@0: rv = SECFailure; michael@0: break; michael@0: } michael@0: michael@0: /* Sanity check the packet contents */ michael@0: if ((fragment_length + fragment_offset) > message_length) { michael@0: PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE); michael@0: rv = SECFailure; michael@0: break; michael@0: } michael@0: michael@0: /* There are three ways we could not be ready for this packet. michael@0: * michael@0: * 1. It's a partial next message. michael@0: * 2. It's a partial or complete message beyond the next michael@0: * 3. It's a message we've already seen michael@0: * michael@0: * If it's the complete next message we accept it right away. michael@0: * This is the common case for short messages michael@0: */ michael@0: if ((message_seq == ss->ssl3.hs.recvMessageSeq) michael@0: && (fragment_offset == 0) michael@0: && (fragment_length == message_length)) { michael@0: /* Complete next message. Process immediately */ michael@0: ss->ssl3.hs.msg_type = (SSL3HandshakeType)type; michael@0: ss->ssl3.hs.msg_len = message_length; michael@0: michael@0: /* At this point we are advancing our state machine, so michael@0: * we can free our last flight of messages */ michael@0: dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight); michael@0: ss->ssl3.hs.recvdHighWater = -1; michael@0: dtls_CancelTimer(ss); michael@0: michael@0: /* Reset the timer to the initial value if the retry counter michael@0: * is 0, per Sec. 4.2.4.1 */ michael@0: if (ss->ssl3.hs.rtRetries == 0) { michael@0: ss->ssl3.hs.rtTimeoutMs = INITIAL_DTLS_TIMEOUT_MS; michael@0: } michael@0: michael@0: rv = ssl3_HandleHandshakeMessage(ss, buf.buf, ss->ssl3.hs.msg_len); michael@0: if (rv == SECFailure) { michael@0: /* Do not attempt to process rest of messages in this record */ michael@0: break; michael@0: } michael@0: } else { michael@0: if (message_seq < ss->ssl3.hs.recvMessageSeq) { michael@0: /* Case 3: we do an immediate retransmit if we're michael@0: * in a waiting state*/ michael@0: if (ss->ssl3.hs.rtTimerCb == NULL) { michael@0: /* Ignore */ michael@0: } else if (ss->ssl3.hs.rtTimerCb == michael@0: dtls_RetransmitTimerExpiredCb) { michael@0: SSL_TRC(30, ("%d: SSL3[%d]: Retransmit detected", michael@0: SSL_GETPID(), ss->fd)); michael@0: /* Check to see if we retransmitted recently. If so, michael@0: * suppress the triggered retransmit. This avoids michael@0: * retransmit wars after packet loss. michael@0: * This is not in RFC 5346 but should be michael@0: */ michael@0: if ((PR_IntervalNow() - ss->ssl3.hs.rtTimerStarted) > michael@0: (ss->ssl3.hs.rtTimeoutMs / 4)) { michael@0: SSL_TRC(30, michael@0: ("%d: SSL3[%d]: Shortcutting retransmit timer", michael@0: SSL_GETPID(), ss->fd)); michael@0: michael@0: /* Cancel the timer and call the CB, michael@0: * which re-arms the timer */ michael@0: dtls_CancelTimer(ss); michael@0: dtls_RetransmitTimerExpiredCb(ss); michael@0: rv = SECSuccess; michael@0: break; michael@0: } else { michael@0: SSL_TRC(30, michael@0: ("%d: SSL3[%d]: We just retransmitted. Ignoring.", michael@0: SSL_GETPID(), ss->fd)); michael@0: rv = SECSuccess; michael@0: break; michael@0: } michael@0: } else if (ss->ssl3.hs.rtTimerCb == dtls_FinishedTimerCb) { michael@0: /* Retransmit the messages and re-arm the timer michael@0: * Note that we are not backing off the timer here. michael@0: * The spec isn't clear and my reasoning is that this michael@0: * may be a re-ordered packet rather than slowness, michael@0: * so let's be aggressive. */ michael@0: dtls_CancelTimer(ss); michael@0: rv = dtls_TransmitMessageFlight(ss); michael@0: if (rv == SECSuccess) { michael@0: rv = dtls_StartTimer(ss, dtls_FinishedTimerCb); michael@0: } michael@0: if (rv != SECSuccess) michael@0: return rv; michael@0: break; michael@0: } michael@0: } else if (message_seq > ss->ssl3.hs.recvMessageSeq) { michael@0: /* Case 2 michael@0: * michael@0: * Ignore this message. This means we don't handle out of michael@0: * order complete messages that well, but we're still michael@0: * compliant and this probably does not happen often michael@0: * michael@0: * XXX OK for now. Maybe do something smarter at some point? michael@0: */ michael@0: } else { michael@0: /* Case 1 michael@0: * michael@0: * Buffer the fragment for reassembly michael@0: */ michael@0: /* Make room for the message */ michael@0: if (ss->ssl3.hs.recvdHighWater == -1) { michael@0: PRUint32 map_length = OFFSET_BYTE(message_length) + 1; michael@0: michael@0: rv = sslBuffer_Grow(&ss->ssl3.hs.msg_body, message_length); michael@0: if (rv != SECSuccess) michael@0: break; michael@0: /* Make room for the fragment map */ michael@0: rv = sslBuffer_Grow(&ss->ssl3.hs.recvdFragments, michael@0: map_length); michael@0: if (rv != SECSuccess) michael@0: break; michael@0: michael@0: /* Reset the reassembly map */ michael@0: ss->ssl3.hs.recvdHighWater = 0; michael@0: PORT_Memset(ss->ssl3.hs.recvdFragments.buf, 0, michael@0: ss->ssl3.hs.recvdFragments.space); michael@0: ss->ssl3.hs.msg_type = (SSL3HandshakeType)type; michael@0: ss->ssl3.hs.msg_len = message_length; michael@0: } michael@0: michael@0: /* If we have a message length mismatch, abandon the reassembly michael@0: * in progress and hope that the next retransmit will give us michael@0: * something sane michael@0: */ michael@0: if (message_length != ss->ssl3.hs.msg_len) { michael@0: ss->ssl3.hs.recvdHighWater = -1; michael@0: PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE); michael@0: rv = SECFailure; michael@0: break; michael@0: } michael@0: michael@0: /* Now copy this fragment into the buffer */ michael@0: PORT_Assert((fragment_offset + fragment_length) <= michael@0: ss->ssl3.hs.msg_body.space); michael@0: PORT_Memcpy(ss->ssl3.hs.msg_body.buf + fragment_offset, michael@0: buf.buf, fragment_length); michael@0: michael@0: /* This logic is a bit tricky. We have two values for michael@0: * reassembly state: michael@0: * michael@0: * - recvdHighWater contains the highest contiguous number of michael@0: * bytes received michael@0: * - recvdFragments contains a bitmask of packets received michael@0: * above recvdHighWater michael@0: * michael@0: * This avoids having to fill in the bitmask in the common michael@0: * case of adjacent fragments received in sequence michael@0: */ michael@0: if (fragment_offset <= ss->ssl3.hs.recvdHighWater) { michael@0: /* Either this is the adjacent fragment or an overlapping michael@0: * fragment */ michael@0: ss->ssl3.hs.recvdHighWater = fragment_offset + michael@0: fragment_length; michael@0: } else { michael@0: for (offset = fragment_offset; michael@0: offset < fragment_offset + fragment_length; michael@0: offset++) { michael@0: ss->ssl3.hs.recvdFragments.buf[OFFSET_BYTE(offset)] |= michael@0: OFFSET_MASK(offset); michael@0: } michael@0: } michael@0: michael@0: /* Now figure out the new high water mark if appropriate */ michael@0: for (offset = ss->ssl3.hs.recvdHighWater; michael@0: offset < ss->ssl3.hs.msg_len; offset++) { michael@0: /* Note that this loop is not efficient, since it counts michael@0: * bit by bit. If we have a lot of out-of-order packets, michael@0: * we should optimize this */ michael@0: if (ss->ssl3.hs.recvdFragments.buf[OFFSET_BYTE(offset)] & michael@0: OFFSET_MASK(offset)) { michael@0: ss->ssl3.hs.recvdHighWater++; michael@0: } else { michael@0: break; michael@0: } michael@0: } michael@0: michael@0: /* If we have all the bytes, then we are good to go */ michael@0: if (ss->ssl3.hs.recvdHighWater == ss->ssl3.hs.msg_len) { michael@0: ss->ssl3.hs.recvdHighWater = -1; michael@0: michael@0: rv = ssl3_HandleHandshakeMessage(ss, michael@0: ss->ssl3.hs.msg_body.buf, michael@0: ss->ssl3.hs.msg_len); michael@0: if (rv == SECFailure) michael@0: break; /* Skip rest of record */ michael@0: michael@0: /* At this point we are advancing our state machine, so michael@0: * we can free our last flight of messages */ michael@0: dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight); michael@0: dtls_CancelTimer(ss); michael@0: michael@0: /* If there have been no retries this time, reset the michael@0: * timer value to the default per Section 4.2.4.1 */ michael@0: if (ss->ssl3.hs.rtRetries == 0) { michael@0: ss->ssl3.hs.rtTimeoutMs = INITIAL_DTLS_TIMEOUT_MS; michael@0: } michael@0: } michael@0: } michael@0: } michael@0: michael@0: buf.buf += fragment_length; michael@0: buf.len -= fragment_length; michael@0: } michael@0: michael@0: origBuf->len = 0; /* So ssl3_GatherAppDataRecord will keep looping. */ michael@0: michael@0: /* XXX OK for now. In future handle rv == SECWouldBlock safely in order michael@0: * to deal with asynchronous certificate verification */ michael@0: return rv; michael@0: } michael@0: michael@0: /* Enqueue a message (either handshake or CCS) michael@0: * michael@0: * Called from: michael@0: * dtls_StageHandshakeMessage() michael@0: * ssl3_SendChangeCipherSpecs() michael@0: */ michael@0: SECStatus dtls_QueueMessage(sslSocket *ss, SSL3ContentType type, michael@0: const SSL3Opaque *pIn, PRInt32 nIn) michael@0: { michael@0: SECStatus rv = SECSuccess; michael@0: DTLSQueuedMessage *msg = NULL; michael@0: michael@0: PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss)); michael@0: PORT_Assert(ss->opt.noLocks || ssl_HaveXmitBufLock(ss)); michael@0: michael@0: msg = dtls_AllocQueuedMessage(ss->ssl3.cwSpec->epoch, type, pIn, nIn); michael@0: michael@0: if (!msg) { michael@0: PORT_SetError(SEC_ERROR_NO_MEMORY); michael@0: rv = SECFailure; michael@0: } else { michael@0: PR_APPEND_LINK(&msg->link, &ss->ssl3.hs.lastMessageFlight); michael@0: } michael@0: michael@0: return rv; michael@0: } michael@0: michael@0: /* Add DTLS handshake message to the pending queue michael@0: * Empty the sendBuf buffer. michael@0: * This function returns SECSuccess or SECFailure, never SECWouldBlock. michael@0: * Always set sendBuf.len to 0, even when returning SECFailure. michael@0: * michael@0: * Called from: michael@0: * ssl3_AppendHandshakeHeader() michael@0: * dtls_FlushHandshake() michael@0: */ michael@0: SECStatus michael@0: dtls_StageHandshakeMessage(sslSocket *ss) michael@0: { michael@0: SECStatus rv = SECSuccess; michael@0: michael@0: PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss)); michael@0: PORT_Assert(ss->opt.noLocks || ssl_HaveXmitBufLock(ss)); michael@0: michael@0: /* This function is sometimes called when no data is actually to michael@0: * be staged, so just return SECSuccess. */ michael@0: if (!ss->sec.ci.sendBuf.buf || !ss->sec.ci.sendBuf.len) michael@0: return rv; michael@0: michael@0: rv = dtls_QueueMessage(ss, content_handshake, michael@0: ss->sec.ci.sendBuf.buf, ss->sec.ci.sendBuf.len); michael@0: michael@0: /* Whether we succeeded or failed, toss the old handshake data. */ michael@0: ss->sec.ci.sendBuf.len = 0; michael@0: return rv; michael@0: } michael@0: michael@0: /* Enqueue the handshake message in sendBuf (if any) and then michael@0: * transmit the resulting flight of handshake messages. michael@0: * michael@0: * Called from: michael@0: * ssl3_FlushHandshake() michael@0: */ michael@0: SECStatus michael@0: dtls_FlushHandshakeMessages(sslSocket *ss, PRInt32 flags) michael@0: { michael@0: SECStatus rv = SECSuccess; michael@0: michael@0: PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss)); michael@0: PORT_Assert(ss->opt.noLocks || ssl_HaveXmitBufLock(ss)); michael@0: michael@0: rv = dtls_StageHandshakeMessage(ss); michael@0: if (rv != SECSuccess) michael@0: return rv; michael@0: michael@0: if (!(flags & ssl_SEND_FLAG_FORCE_INTO_BUFFER)) { michael@0: rv = dtls_TransmitMessageFlight(ss); michael@0: if (rv != SECSuccess) michael@0: return rv; michael@0: michael@0: if (!(flags & ssl_SEND_FLAG_NO_RETRANSMIT)) { michael@0: ss->ssl3.hs.rtRetries = 0; michael@0: rv = dtls_StartTimer(ss, dtls_RetransmitTimerExpiredCb); michael@0: } michael@0: } michael@0: michael@0: return rv; michael@0: } michael@0: michael@0: /* The callback for when the retransmit timer expires michael@0: * michael@0: * Called from: michael@0: * dtls_CheckTimer() michael@0: * dtls_HandleHandshake() michael@0: */ michael@0: static void michael@0: dtls_RetransmitTimerExpiredCb(sslSocket *ss) michael@0: { michael@0: SECStatus rv = SECFailure; michael@0: michael@0: ss->ssl3.hs.rtRetries++; michael@0: michael@0: if (!(ss->ssl3.hs.rtRetries % 3)) { michael@0: /* If one of the messages was potentially greater than > MTU, michael@0: * then downgrade. Do this every time we have retransmitted a michael@0: * message twice, per RFC 6347 Sec. 4.1.1 */ michael@0: dtls_SetMTU(ss, ss->ssl3.hs.maxMessageSent - 1); michael@0: } michael@0: michael@0: rv = dtls_TransmitMessageFlight(ss); michael@0: if (rv == SECSuccess) { michael@0: michael@0: /* Re-arm the timer */ michael@0: rv = dtls_RestartTimer(ss, PR_TRUE, dtls_RetransmitTimerExpiredCb); michael@0: } michael@0: michael@0: if (rv == SECFailure) { michael@0: /* XXX OK for now. In future maybe signal the stack that we couldn't michael@0: * transmit. For now, let the read handle any real network errors */ michael@0: } michael@0: } michael@0: michael@0: /* Transmit a flight of handshake messages, stuffing them michael@0: * into as few records as seems reasonable michael@0: * michael@0: * Called from: michael@0: * dtls_FlushHandshake() michael@0: * dtls_RetransmitTimerExpiredCb() michael@0: */ michael@0: static SECStatus michael@0: dtls_TransmitMessageFlight(sslSocket *ss) michael@0: { michael@0: SECStatus rv = SECSuccess; michael@0: PRCList *msg_p; michael@0: PRUint16 room_left = ss->ssl3.mtu; michael@0: PRInt32 sent; michael@0: michael@0: ssl_GetXmitBufLock(ss); michael@0: ssl_GetSpecReadLock(ss); michael@0: michael@0: /* DTLS does not buffer its handshake messages in michael@0: * ss->pendingBuf, but rather in the lastMessageFlight michael@0: * structure. This is just a sanity check that michael@0: * some programming error hasn't inadvertantly michael@0: * stuffed something in ss->pendingBuf michael@0: */ michael@0: PORT_Assert(!ss->pendingBuf.len); michael@0: for (msg_p = PR_LIST_HEAD(&ss->ssl3.hs.lastMessageFlight); michael@0: msg_p != &ss->ssl3.hs.lastMessageFlight; michael@0: msg_p = PR_NEXT_LINK(msg_p)) { michael@0: DTLSQueuedMessage *msg = (DTLSQueuedMessage *)msg_p; michael@0: michael@0: /* The logic here is: michael@0: * michael@0: * 1. If this is a message that will not fit into the remaining michael@0: * space, then flush. michael@0: * 2. If the message will now fit into the remaining space, michael@0: * encrypt, buffer, and loop. michael@0: * 3. If the message will not fit, then fragment. michael@0: * michael@0: * At the end of the function, flush. michael@0: */ michael@0: if ((msg->len + SSL3_BUFFER_FUDGE) > room_left) { michael@0: /* The message will not fit into the remaining space, so flush */ michael@0: rv = dtls_SendSavedWriteData(ss); michael@0: if (rv != SECSuccess) michael@0: break; michael@0: michael@0: room_left = ss->ssl3.mtu; michael@0: } michael@0: michael@0: if ((msg->len + SSL3_BUFFER_FUDGE) <= room_left) { michael@0: /* The message will fit, so encrypt and then continue with the michael@0: * next packet */ michael@0: sent = ssl3_SendRecord(ss, msg->epoch, msg->type, michael@0: msg->data, msg->len, michael@0: ssl_SEND_FLAG_FORCE_INTO_BUFFER | michael@0: ssl_SEND_FLAG_USE_EPOCH); michael@0: if (sent != msg->len) { michael@0: rv = SECFailure; michael@0: if (sent != -1) { michael@0: PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); michael@0: } michael@0: break; michael@0: } michael@0: michael@0: room_left = ss->ssl3.mtu - ss->pendingBuf.len; michael@0: } else { michael@0: /* The message will not fit, so fragment. michael@0: * michael@0: * XXX OK for now. Arrange to coalesce the last fragment michael@0: * of this message with the next message if possible. michael@0: * That would be more efficient. michael@0: */ michael@0: PRUint32 fragment_offset = 0; michael@0: unsigned char fragment[DTLS_MAX_MTU]; /* >= than largest michael@0: * plausible MTU */ michael@0: michael@0: /* Assert that we have already flushed */ michael@0: PORT_Assert(room_left == ss->ssl3.mtu); michael@0: michael@0: /* Case 3: We now need to fragment this message michael@0: * DTLS only supports fragmenting handshaking messages */ michael@0: PORT_Assert(msg->type == content_handshake); michael@0: michael@0: /* The headers consume 12 bytes so the smalles possible michael@0: * message (i.e., an empty one) is 12 bytes michael@0: */ michael@0: PORT_Assert(msg->len >= 12); michael@0: michael@0: while ((fragment_offset + 12) < msg->len) { michael@0: PRUint32 fragment_len; michael@0: const unsigned char *content = msg->data + 12; michael@0: PRUint32 content_len = msg->len - 12; michael@0: michael@0: /* The reason we use 8 here is that that's the length of michael@0: * the new DTLS data that we add to the header */ michael@0: fragment_len = PR_MIN(room_left - (SSL3_BUFFER_FUDGE + 8), michael@0: content_len - fragment_offset); michael@0: PORT_Assert(fragment_len < DTLS_MAX_MTU - 12); michael@0: /* Make totally sure that we are within the buffer. michael@0: * Note that the only way that fragment len could get michael@0: * adjusted here is if michael@0: * michael@0: * (a) we are in release mode so the PORT_Assert is compiled out michael@0: * (b) either the MTU table is inconsistent with DTLS_MAX_MTU michael@0: * or ss->ssl3.mtu has become corrupt. michael@0: */ michael@0: fragment_len = PR_MIN(fragment_len, DTLS_MAX_MTU - 12); michael@0: michael@0: /* Construct an appropriate-sized fragment */ michael@0: /* Type, length, sequence */ michael@0: PORT_Memcpy(fragment, msg->data, 6); michael@0: michael@0: /* Offset */ michael@0: fragment[6] = (fragment_offset >> 16) & 0xff; michael@0: fragment[7] = (fragment_offset >> 8) & 0xff; michael@0: fragment[8] = (fragment_offset) & 0xff; michael@0: michael@0: /* Fragment length */ michael@0: fragment[9] = (fragment_len >> 16) & 0xff; michael@0: fragment[10] = (fragment_len >> 8) & 0xff; michael@0: fragment[11] = (fragment_len) & 0xff; michael@0: michael@0: PORT_Memcpy(fragment + 12, content + fragment_offset, michael@0: fragment_len); michael@0: michael@0: /* michael@0: * Send the record. We do this in two stages michael@0: * 1. Encrypt michael@0: */ michael@0: sent = ssl3_SendRecord(ss, msg->epoch, msg->type, michael@0: fragment, fragment_len + 12, michael@0: ssl_SEND_FLAG_FORCE_INTO_BUFFER | michael@0: ssl_SEND_FLAG_USE_EPOCH); michael@0: if (sent != (fragment_len + 12)) { michael@0: rv = SECFailure; michael@0: if (sent != -1) { michael@0: PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); michael@0: } michael@0: break; michael@0: } michael@0: michael@0: /* 2. Flush */ michael@0: rv = dtls_SendSavedWriteData(ss); michael@0: if (rv != SECSuccess) michael@0: break; michael@0: michael@0: fragment_offset += fragment_len; michael@0: } michael@0: } michael@0: } michael@0: michael@0: /* Finally, we need to flush */ michael@0: if (rv == SECSuccess) michael@0: rv = dtls_SendSavedWriteData(ss); michael@0: michael@0: /* Give up the locks */ michael@0: ssl_ReleaseSpecReadLock(ss); michael@0: ssl_ReleaseXmitBufLock(ss); michael@0: michael@0: return rv; michael@0: } michael@0: michael@0: /* Flush the data in the pendingBuf and update the max message sent michael@0: * so we can adjust the MTU estimate if we need to. michael@0: * Wrapper for ssl_SendSavedWriteData. michael@0: * michael@0: * Called from dtls_TransmitMessageFlight() michael@0: */ michael@0: static michael@0: SECStatus dtls_SendSavedWriteData(sslSocket *ss) michael@0: { michael@0: PRInt32 sent; michael@0: michael@0: sent = ssl_SendSavedWriteData(ss); michael@0: if (sent < 0) michael@0: return SECFailure; michael@0: michael@0: /* We should always have complete writes b/c datagram sockets michael@0: * don't really block */ michael@0: if (ss->pendingBuf.len > 0) { michael@0: ssl_MapLowLevelError(SSL_ERROR_SOCKET_WRITE_FAILURE); michael@0: return SECFailure; michael@0: } michael@0: michael@0: /* Update the largest message sent so we can adjust the MTU michael@0: * estimate if necessary */ michael@0: if (sent > ss->ssl3.hs.maxMessageSent) michael@0: ss->ssl3.hs.maxMessageSent = sent; michael@0: michael@0: return SECSuccess; michael@0: } michael@0: michael@0: /* Compress, MAC, encrypt a DTLS record. Allows specification of michael@0: * the epoch using epoch value. If use_epoch is PR_TRUE then michael@0: * we use the provided epoch. If use_epoch is PR_FALSE then michael@0: * whatever the current value is in effect is used. michael@0: * michael@0: * Called from ssl3_SendRecord() michael@0: */ michael@0: SECStatus michael@0: dtls_CompressMACEncryptRecord(sslSocket * ss, michael@0: DTLSEpoch epoch, michael@0: PRBool use_epoch, michael@0: SSL3ContentType type, michael@0: const SSL3Opaque * pIn, michael@0: PRUint32 contentLen, michael@0: sslBuffer * wrBuf) michael@0: { michael@0: SECStatus rv = SECFailure; michael@0: ssl3CipherSpec * cwSpec; michael@0: michael@0: ssl_GetSpecReadLock(ss); /********************************/ michael@0: michael@0: /* The reason for this switch-hitting code is that we might have michael@0: * a flight of records spanning an epoch boundary, e.g., michael@0: * michael@0: * ClientKeyExchange (epoch = 0) michael@0: * ChangeCipherSpec (epoch = 0) michael@0: * Finished (epoch = 1) michael@0: * michael@0: * Thus, each record needs a different cipher spec. The information michael@0: * about which epoch to use is carried with the record. michael@0: */ michael@0: if (use_epoch) { michael@0: if (ss->ssl3.cwSpec->epoch == epoch) michael@0: cwSpec = ss->ssl3.cwSpec; michael@0: else if (ss->ssl3.pwSpec->epoch == epoch) michael@0: cwSpec = ss->ssl3.pwSpec; michael@0: else michael@0: cwSpec = NULL; michael@0: } else { michael@0: cwSpec = ss->ssl3.cwSpec; michael@0: } michael@0: michael@0: if (cwSpec) { michael@0: rv = ssl3_CompressMACEncryptRecord(cwSpec, ss->sec.isServer, PR_TRUE, michael@0: PR_FALSE, type, pIn, contentLen, michael@0: wrBuf); michael@0: } else { michael@0: PR_NOT_REACHED("Couldn't find a cipher spec matching epoch"); michael@0: PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); michael@0: } michael@0: ssl_ReleaseSpecReadLock(ss); /************************************/ michael@0: michael@0: return rv; michael@0: } michael@0: michael@0: /* Start a timer michael@0: * michael@0: * Called from: michael@0: * dtls_HandleHandshake() michael@0: * dtls_FlushHAndshake() michael@0: * dtls_RestartTimer() michael@0: */ michael@0: SECStatus michael@0: dtls_StartTimer(sslSocket *ss, DTLSTimerCb cb) michael@0: { michael@0: PORT_Assert(ss->ssl3.hs.rtTimerCb == NULL); michael@0: michael@0: ss->ssl3.hs.rtTimerStarted = PR_IntervalNow(); michael@0: ss->ssl3.hs.rtTimerCb = cb; michael@0: michael@0: return SECSuccess; michael@0: } michael@0: michael@0: /* Restart a timer with optional backoff michael@0: * michael@0: * Called from dtls_RetransmitTimerExpiredCb() michael@0: */ michael@0: SECStatus michael@0: dtls_RestartTimer(sslSocket *ss, PRBool backoff, DTLSTimerCb cb) michael@0: { michael@0: if (backoff) { michael@0: ss->ssl3.hs.rtTimeoutMs *= 2; michael@0: if (ss->ssl3.hs.rtTimeoutMs > MAX_DTLS_TIMEOUT_MS) michael@0: ss->ssl3.hs.rtTimeoutMs = MAX_DTLS_TIMEOUT_MS; michael@0: } michael@0: michael@0: return dtls_StartTimer(ss, cb); michael@0: } michael@0: michael@0: /* Cancel a pending timer michael@0: * michael@0: * Called from: michael@0: * dtls_HandleHandshake() michael@0: * dtls_CheckTimer() michael@0: */ michael@0: void michael@0: dtls_CancelTimer(sslSocket *ss) michael@0: { michael@0: PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss)); michael@0: michael@0: ss->ssl3.hs.rtTimerCb = NULL; michael@0: } michael@0: michael@0: /* Check the pending timer and fire the callback if it expired michael@0: * michael@0: * Called from ssl3_GatherCompleteHandshake() michael@0: */ michael@0: void michael@0: dtls_CheckTimer(sslSocket *ss) michael@0: { michael@0: if (!ss->ssl3.hs.rtTimerCb) michael@0: return; michael@0: michael@0: if ((PR_IntervalNow() - ss->ssl3.hs.rtTimerStarted) > michael@0: PR_MillisecondsToInterval(ss->ssl3.hs.rtTimeoutMs)) { michael@0: /* Timer has expired */ michael@0: DTLSTimerCb cb = ss->ssl3.hs.rtTimerCb; michael@0: michael@0: /* Cancel the timer so that we can call the CB safely */ michael@0: dtls_CancelTimer(ss); michael@0: michael@0: /* Now call the CB */ michael@0: cb(ss); michael@0: } michael@0: } michael@0: michael@0: /* The callback to fire when the holddown timer for the Finished michael@0: * message expires and we can delete it michael@0: * michael@0: * Called from dtls_CheckTimer() michael@0: */ michael@0: void michael@0: dtls_FinishedTimerCb(sslSocket *ss) michael@0: { michael@0: ssl3_DestroyCipherSpec(ss->ssl3.pwSpec, PR_FALSE); michael@0: } michael@0: michael@0: /* Cancel the Finished hold-down timer and destroy the michael@0: * pending cipher spec. Note that this means that michael@0: * successive rehandshakes will fail if the Finished is michael@0: * lost. michael@0: * michael@0: * XXX OK for now. Figure out how to handle the combination michael@0: * of Finished lost and rehandshake michael@0: */ michael@0: void michael@0: dtls_RehandshakeCleanup(sslSocket *ss) michael@0: { michael@0: dtls_CancelTimer(ss); michael@0: ssl3_DestroyCipherSpec(ss->ssl3.pwSpec, PR_FALSE); michael@0: ss->ssl3.hs.sendMessageSeq = 0; michael@0: ss->ssl3.hs.recvMessageSeq = 0; michael@0: } michael@0: michael@0: /* Set the MTU to the next step less than or equal to the michael@0: * advertised value. Also used to downgrade the MTU by michael@0: * doing dtls_SetMTU(ss, biggest packet set). michael@0: * michael@0: * Passing 0 means set this to the largest MTU known michael@0: * (effectively resetting the PMTU backoff value). michael@0: * michael@0: * Called by: michael@0: * ssl3_InitState() michael@0: * dtls_RetransmitTimerExpiredCb() michael@0: */ michael@0: void michael@0: dtls_SetMTU(sslSocket *ss, PRUint16 advertised) michael@0: { michael@0: int i; michael@0: michael@0: if (advertised == 0) { michael@0: ss->ssl3.mtu = COMMON_MTU_VALUES[0]; michael@0: SSL_TRC(30, ("Resetting MTU to %d", ss->ssl3.mtu)); michael@0: return; michael@0: } michael@0: michael@0: for (i = 0; i < PR_ARRAY_SIZE(COMMON_MTU_VALUES); i++) { michael@0: if (COMMON_MTU_VALUES[i] <= advertised) { michael@0: ss->ssl3.mtu = COMMON_MTU_VALUES[i]; michael@0: SSL_TRC(30, ("Resetting MTU to %d", ss->ssl3.mtu)); michael@0: return; michael@0: } michael@0: } michael@0: michael@0: /* Fallback */ michael@0: ss->ssl3.mtu = COMMON_MTU_VALUES[PR_ARRAY_SIZE(COMMON_MTU_VALUES)-1]; michael@0: SSL_TRC(30, ("Resetting MTU to %d", ss->ssl3.mtu)); michael@0: } michael@0: michael@0: /* Called from ssl3_HandleHandshakeMessage() when it has deciphered a michael@0: * DTLS hello_verify_request michael@0: * Caller must hold Handshake and RecvBuf locks. michael@0: */ michael@0: SECStatus michael@0: dtls_HandleHelloVerifyRequest(sslSocket *ss, SSL3Opaque *b, PRUint32 length) michael@0: { michael@0: int errCode = SSL_ERROR_RX_MALFORMED_HELLO_VERIFY_REQUEST; michael@0: SECStatus rv; michael@0: PRInt32 temp; michael@0: SECItem cookie = {siBuffer, NULL, 0}; michael@0: SSL3AlertDescription desc = illegal_parameter; michael@0: michael@0: SSL_TRC(3, ("%d: SSL3[%d]: handle hello_verify_request handshake", michael@0: SSL_GETPID(), ss->fd)); michael@0: PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss)); michael@0: PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss)); michael@0: michael@0: if (ss->ssl3.hs.ws != wait_server_hello) { michael@0: errCode = SSL_ERROR_RX_UNEXPECTED_HELLO_VERIFY_REQUEST; michael@0: desc = unexpected_message; michael@0: goto alert_loser; michael@0: } michael@0: michael@0: /* The version */ michael@0: temp = ssl3_ConsumeHandshakeNumber(ss, 2, &b, &length); michael@0: if (temp < 0) { michael@0: goto loser; /* alert has been sent */ michael@0: } michael@0: michael@0: if (temp != SSL_LIBRARY_VERSION_DTLS_1_0_WIRE && michael@0: temp != SSL_LIBRARY_VERSION_DTLS_1_2_WIRE) { michael@0: goto alert_loser; michael@0: } michael@0: michael@0: /* The cookie */ michael@0: rv = ssl3_ConsumeHandshakeVariable(ss, &cookie, 1, &b, &length); michael@0: if (rv != SECSuccess) { michael@0: goto loser; /* alert has been sent */ michael@0: } michael@0: if (cookie.len > DTLS_COOKIE_BYTES) { michael@0: desc = decode_error; michael@0: goto alert_loser; /* malformed. */ michael@0: } michael@0: michael@0: PORT_Memcpy(ss->ssl3.hs.cookie, cookie.data, cookie.len); michael@0: ss->ssl3.hs.cookieLen = cookie.len; michael@0: michael@0: michael@0: ssl_GetXmitBufLock(ss); /*******************************/ michael@0: michael@0: /* Now re-send the client hello */ michael@0: rv = ssl3_SendClientHello(ss, PR_TRUE); michael@0: michael@0: ssl_ReleaseXmitBufLock(ss); /*******************************/ michael@0: michael@0: if (rv == SECSuccess) michael@0: return rv; michael@0: michael@0: alert_loser: michael@0: (void)SSL3_SendAlert(ss, alert_fatal, desc); michael@0: michael@0: loser: michael@0: errCode = ssl_MapLowLevelError(errCode); michael@0: return SECFailure; michael@0: } michael@0: michael@0: /* Initialize the DTLS anti-replay window michael@0: * michael@0: * Called from: michael@0: * ssl3_SetupPendingCipherSpec() michael@0: * ssl3_InitCipherSpec() michael@0: */ michael@0: void michael@0: dtls_InitRecvdRecords(DTLSRecvdRecords *records) michael@0: { michael@0: PORT_Memset(records->data, 0, sizeof(records->data)); michael@0: records->left = 0; michael@0: records->right = DTLS_RECVD_RECORDS_WINDOW - 1; michael@0: } michael@0: michael@0: /* michael@0: * Has this DTLS record been received? Return values are: michael@0: * -1 -- out of range to the left michael@0: * 0 -- not received yet michael@0: * 1 -- replay michael@0: * michael@0: * Called from: dtls_HandleRecord() michael@0: */ michael@0: int michael@0: dtls_RecordGetRecvd(DTLSRecvdRecords *records, PRUint64 seq) michael@0: { michael@0: PRUint64 offset; michael@0: michael@0: /* Out of range to the left */ michael@0: if (seq < records->left) { michael@0: return -1; michael@0: } michael@0: michael@0: /* Out of range to the right; since we advance the window on michael@0: * receipt, that means that this packet has not been received michael@0: * yet */ michael@0: if (seq > records->right) michael@0: return 0; michael@0: michael@0: offset = seq % DTLS_RECVD_RECORDS_WINDOW; michael@0: michael@0: return !!(records->data[offset / 8] & (1 << (offset % 8))); michael@0: } michael@0: michael@0: /* Update the DTLS anti-replay window michael@0: * michael@0: * Called from ssl3_HandleRecord() michael@0: */ michael@0: void michael@0: dtls_RecordSetRecvd(DTLSRecvdRecords *records, PRUint64 seq) michael@0: { michael@0: PRUint64 offset; michael@0: michael@0: if (seq < records->left) michael@0: return; michael@0: michael@0: if (seq > records->right) { michael@0: PRUint64 new_left; michael@0: PRUint64 new_right; michael@0: PRUint64 right; michael@0: michael@0: /* Slide to the right; this is the tricky part michael@0: * michael@0: * 1. new_top is set to have room for seq, on the michael@0: * next byte boundary by setting the right 8 michael@0: * bits of seq michael@0: * 2. new_left is set to compensate. michael@0: * 3. Zero all bits between top and new_top. Since michael@0: * this is a ring, this zeroes everything as-yet michael@0: * unseen. Because we always operate on byte michael@0: * boundaries, we can zero one byte at a time michael@0: */ michael@0: new_right = seq | 0x07; michael@0: new_left = (new_right - DTLS_RECVD_RECORDS_WINDOW) + 1; michael@0: michael@0: for (right = records->right + 8; right <= new_right; right += 8) { michael@0: offset = right % DTLS_RECVD_RECORDS_WINDOW; michael@0: records->data[offset / 8] = 0; michael@0: } michael@0: michael@0: records->right = new_right; michael@0: records->left = new_left; michael@0: } michael@0: michael@0: offset = seq % DTLS_RECVD_RECORDS_WINDOW; michael@0: michael@0: records->data[offset / 8] |= (1 << (offset % 8)); michael@0: } michael@0: michael@0: SECStatus michael@0: DTLS_GetHandshakeTimeout(PRFileDesc *socket, PRIntervalTime *timeout) michael@0: { michael@0: sslSocket * ss = NULL; michael@0: PRIntervalTime elapsed; michael@0: PRIntervalTime desired; michael@0: michael@0: ss = ssl_FindSocket(socket); michael@0: michael@0: if (!ss) michael@0: return SECFailure; michael@0: michael@0: if (!IS_DTLS(ss)) michael@0: return SECFailure; michael@0: michael@0: if (!ss->ssl3.hs.rtTimerCb) michael@0: return SECFailure; michael@0: michael@0: elapsed = PR_IntervalNow() - ss->ssl3.hs.rtTimerStarted; michael@0: desired = PR_MillisecondsToInterval(ss->ssl3.hs.rtTimeoutMs); michael@0: if (elapsed > desired) { michael@0: /* Timer expired */ michael@0: *timeout = PR_INTERVAL_NO_WAIT; michael@0: } else { michael@0: *timeout = desired - elapsed; michael@0: } michael@0: michael@0: return SECSuccess; michael@0: }