Commit 7171f71e authored by Matt Caswell's avatar Matt Caswell
Browse files

Mark DTLS records as read when we have finished with them



The TLS code marks records as read when its finished using a record. The DTLS code did
not do that. However SSL_has_pending() relies on it. So we should make DTLS consistent.

Reviewed-by: default avatarRich Salz <rsalz@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/6160)
parent f1e92861
Loading
Loading
Loading
Loading
+32 −5
Original line number Diff line number Diff line
@@ -473,6 +473,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
            return -1;
        }
        SSL3_RECORD_set_length(rr, 0);
        SSL3_RECORD_set_read(rr);
        goto start;
    }

@@ -482,8 +483,9 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
     */
    if (s->shutdown & SSL_RECEIVED_SHUTDOWN) {
        SSL3_RECORD_set_length(rr, 0);
        SSL3_RECORD_set_read(rr);
        s->rwstate = SSL_NOTHING;
        return (0);
        return 0;
    }

    if (type == SSL3_RECORD_get_type(rr)
@@ -508,8 +510,16 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
        if (recvd_type != NULL)
            *recvd_type = SSL3_RECORD_get_type(rr);

        if (len <= 0)
            return (len);
        if (len <= 0) {
            /*
             * Mark a zero length record as read. This ensures multiple calls to
             * SSL_read() with a zero length buffer will eventually cause
             * SSL_pending() to report data as being available.
             */
            if (SSL3_RECORD_get_length(rr) == 0)
                SSL3_RECORD_set_read(rr);
            return len;
        }

        if ((unsigned int)len > SSL3_RECORD_get_length(rr))
            n = SSL3_RECORD_get_length(rr);
@@ -517,12 +527,16 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
            n = (unsigned int)len;

        memcpy(buf, &(SSL3_RECORD_get_data(rr)[SSL3_RECORD_get_off(rr)]), n);
        if (!peek) {
        if (peek) {
            if (SSL3_RECORD_get_length(rr) == 0)
                SSL3_RECORD_set_read(rr);
        } else {
            SSL3_RECORD_sub_length(rr, n);
            SSL3_RECORD_add_off(rr, n);
            if (SSL3_RECORD_get_length(rr) == 0) {
                s->rlayer.rstate = SSL_ST_READ_HEADER;
                SSL3_RECORD_set_off(rr, 0);
                SSL3_RECORD_set_read(rr);
            }
        }
#ifndef OPENSSL_NO_SCTP
@@ -573,6 +587,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
            }
            /* Exit and notify application to read again */
            SSL3_RECORD_set_length(rr, 0);
            SSL3_RECORD_set_read(rr);
            s->rwstate = SSL_READING;
            BIO_clear_retry_flags(SSL_get_rbio(s));
            BIO_set_retry_read(SSL_get_rbio(s));
@@ -617,6 +632,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
#endif
                s->rlayer.rstate = SSL_ST_READ_HEADER;
                SSL3_RECORD_set_length(rr, 0);
                SSL3_RECORD_set_read(rr);
                goto start;
            }

@@ -626,6 +642,8 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
                SSL3_RECORD_add_off(rr, 1);
                SSL3_RECORD_add_length(rr, -1);
            }
            if (SSL3_RECORD_get_length(rr) == 0)
                SSL3_RECORD_set_read(rr);
            *dest_len = dest_maxlen;
        }
    }
@@ -696,6 +714,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
            }
        } else {
            SSL3_RECORD_set_length(rr, 0);
            SSL3_RECORD_set_read(rr);
            ssl3_send_alert(s, SSL3_AL_WARNING, SSL_AD_NO_RENEGOTIATION);
        }
        /*
@@ -720,6 +739,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
                || (s->options & SSL_OP_NO_RENEGOTIATION) != 0)) {
        s->rlayer.d->handshake_fragment_len = 0;
        SSL3_RECORD_set_length(rr, 0);
        SSL3_RECORD_set_read(rr);
        ssl3_send_alert(s, SSL3_AL_WARNING, SSL_AD_NO_RENEGOTIATION);
        goto start;
    }
@@ -747,6 +767,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,

        if (alert_level == SSL3_AL_WARNING) {
            s->s3->warn_alert = alert_descr;
            SSL3_RECORD_set_read(rr);

            s->rlayer.alert_count++;
            if (s->rlayer.alert_count == MAX_WARN_ALERT_COUNT) {
@@ -811,6 +832,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
            BIO_snprintf(tmp, sizeof(tmp), "%d", alert_descr);
            ERR_add_error_data(2, "SSL alert number ", tmp);
            s->shutdown |= SSL_RECEIVED_SHUTDOWN;
            SSL3_RECORD_set_read(rr);
            SSL_CTX_remove_session(s->session_ctx, s->session);
            return (0);
        } else {
@@ -826,7 +848,8 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
                                            * shutdown */
        s->rwstate = SSL_NOTHING;
        SSL3_RECORD_set_length(rr, 0);
        return (0);
        SSL3_RECORD_set_read(rr);
        return 0;
    }

    if (SSL3_RECORD_get_type(rr) == SSL3_RT_CHANGE_CIPHER_SPEC) {
@@ -835,6 +858,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
         * are still missing, so just drop it.
         */
        SSL3_RECORD_set_length(rr, 0);
        SSL3_RECORD_set_read(rr);
        goto start;
    }

@@ -849,6 +873,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
        dtls1_get_message_header(rr->data, &msg_hdr);
        if (SSL3_RECORD_get_epoch(rr) != s->rlayer.d->r_epoch) {
            SSL3_RECORD_set_length(rr, 0);
            SSL3_RECORD_set_read(rr);
            goto start;
        }

@@ -862,6 +887,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,

            dtls1_retransmit_buffered_messages(s);
            SSL3_RECORD_set_length(rr, 0);
            SSL3_RECORD_set_read(rr);
            if (!(s->mode & SSL_MODE_AUTO_RETRY)) {
                if (SSL3_BUFFER_get_left(&s->rlayer.rbuf) == 0) {
                    /* no read-ahead left? */
@@ -916,6 +942,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
        /* TLS just ignores unknown message types */
        if (s->version == TLS1_VERSION) {
            SSL3_RECORD_set_length(rr, 0);
            SSL3_RECORD_set_read(rr);
            goto start;
        }
        al = SSL_AD_UNEXPECTED_MESSAGE;
+12 −2
Original line number Diff line number Diff line
@@ -1531,6 +1531,7 @@ int dtls1_get_record(SSL *s)
        p += 6;

        n2s(p, rr->length);
        rr->read = 0;

        /*
         * Lets check the version. We tolerate alerts that don't have the exact
@@ -1540,6 +1541,7 @@ int dtls1_get_record(SSL *s)
            if (version != s->version) {
                /* unexpected version, silently discard */
                rr->length = 0;
                rr->read = 1;
                RECORD_LAYER_reset_packet_length(&s->rlayer);
                goto again;
            }
@@ -1548,6 +1550,7 @@ int dtls1_get_record(SSL *s)
        if ((version & 0xff00) != (s->version & 0xff00)) {
            /* wrong version, silently discard record */
            rr->length = 0;
            rr->read = 1;
            RECORD_LAYER_reset_packet_length(&s->rlayer);
            goto again;
        }
@@ -1555,10 +1558,10 @@ int dtls1_get_record(SSL *s)
        if (rr->length > SSL3_RT_MAX_ENCRYPTED_LENGTH) {
            /* record too long, silently discard it */
            rr->length = 0;
            rr->read = 1;
            RECORD_LAYER_reset_packet_length(&s->rlayer);
            goto again;
        }

        /* now s->rlayer.rstate == SSL_ST_READ_BODY */
    }

@@ -1572,6 +1575,7 @@ int dtls1_get_record(SSL *s)
        /* this packet contained a partial record, dump it */
        if (n != i) {
            rr->length = 0;
            rr->read = 1;
            RECORD_LAYER_reset_packet_length(&s->rlayer);
            goto again;
        }
@@ -1588,6 +1592,7 @@ int dtls1_get_record(SSL *s)
    bitmap = dtls1_get_bitmap(s, rr, &is_next_epoch);
    if (bitmap == NULL) {
        rr->length = 0;
        rr->read = 1;
        RECORD_LAYER_reset_packet_length(&s->rlayer); /* dump this record */
        goto again;             /* get another record */
    }
@@ -1602,6 +1607,7 @@ int dtls1_get_record(SSL *s)
         */
        if (!dtls1_record_replay_check(s, bitmap)) {
            rr->length = 0;
            rr->read = 1;
            RECORD_LAYER_reset_packet_length(&s->rlayer); /* dump this record */
            goto again;         /* get another record */
        }
@@ -1610,8 +1616,10 @@ int dtls1_get_record(SSL *s)
#endif

    /* just read a 0 length packet */
    if (rr->length == 0)
    if (rr->length == 0) {
        rr->read = 1;
        goto again;
    }

    /*
     * If this record is from the next epoch (either HM or ALERT), and a
@@ -1626,12 +1634,14 @@ int dtls1_get_record(SSL *s)
                return -1;
        }
        rr->length = 0;
        rr->read = 1;
        RECORD_LAYER_reset_packet_length(&s->rlayer);
        goto again;
    }

    if (!dtls1_process_record(s, bitmap)) {
        rr->length = 0;
        rr->read = 1;
        RECORD_LAYER_reset_packet_length(&s->rlayer); /* dump this record */
        goto again;             /* get another record */
    }