Commit 66fab923 authored by Matt Caswell's avatar Matt Caswell
Browse files

Mark DTLS records as read when we have finished with them



The TLS code marks records as read when its finished using a record. The DTLS code did
not do that. However SSL_has_pending() relies on it. So we should make DTLS consistent.

Reviewed-by: default avatarRich Salz <rsalz@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/6159)
parent 0d8da779
Loading
Loading
Loading
Loading
+23 −4
Original line number Diff line number Diff line
@@ -363,8 +363,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
        return -1;
    }

    if (!ossl_statem_get_in_handshake(s) && SSL_in_init(s))
    {
    if (!ossl_statem_get_in_handshake(s) && SSL_in_init(s)) {
        /* type == SSL3_RT_APPLICATION_DATA */
        i = s->handshake_func(s);
        /* SSLfatal() already called if appropriate */
@@ -473,6 +472,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
            return -1;
        }
        SSL3_RECORD_set_length(rr, 0);
        SSL3_RECORD_set_read(rr);
        goto start;
    }

@@ -482,6 +482,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
     */
    if (s->shutdown & SSL_RECEIVED_SHUTDOWN) {
        SSL3_RECORD_set_length(rr, 0);
        SSL3_RECORD_set_read(rr);
        s->rwstate = SSL_NOTHING;
        return 0;
    }
@@ -508,8 +509,16 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
        if (recvd_type != NULL)
            *recvd_type = SSL3_RECORD_get_type(rr);

        if (len == 0)
        if (len == 0) {
            /*
             * Mark a zero length record as read. This ensures multiple calls to
             * SSL_read() with a zero length buffer will eventually cause
             * SSL_pending() to report data as being available.
             */
            if (SSL3_RECORD_get_length(rr) == 0)
                SSL3_RECORD_set_read(rr);
            return 0;
        }

        if (len > SSL3_RECORD_get_length(rr))
            n = SSL3_RECORD_get_length(rr);
@@ -517,12 +526,16 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
            n = len;

        memcpy(buf, &(SSL3_RECORD_get_data(rr)[SSL3_RECORD_get_off(rr)]), n);
        if (!peek) {
        if (peek) {
            if (SSL3_RECORD_get_length(rr) == 0)
                SSL3_RECORD_set_read(rr);
        } else {
            SSL3_RECORD_sub_length(rr, n);
            SSL3_RECORD_add_off(rr, n);
            if (SSL3_RECORD_get_length(rr) == 0) {
                s->rlayer.rstate = SSL_ST_READ_HEADER;
                SSL3_RECORD_set_off(rr, 0);
                SSL3_RECORD_set_read(rr);
            }
        }
#ifndef OPENSSL_NO_SCTP
@@ -578,6 +591,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,

        if (alert_level == SSL3_AL_WARNING) {
            s->s3->warn_alert = alert_descr;
            SSL3_RECORD_set_read(rr);

            s->rlayer.alert_count++;
            if (s->rlayer.alert_count == MAX_WARN_ALERT_COUNT) {
@@ -615,6 +629,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
            BIO_snprintf(tmp, sizeof tmp, "%d", alert_descr);
            ERR_add_error_data(2, "SSL alert number ", tmp);
            s->shutdown |= SSL_RECEIVED_SHUTDOWN;
            SSL3_RECORD_set_read(rr);
            SSL_CTX_remove_session(s->session_ctx, s->session);
            return 0;
        } else {
@@ -630,6 +645,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
                                            * shutdown */
        s->rwstate = SSL_NOTHING;
        SSL3_RECORD_set_length(rr, 0);
        SSL3_RECORD_set_read(rr);
        return 0;
    }

@@ -639,6 +655,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
         * are still missing, so just drop it.
         */
        SSL3_RECORD_set_length(rr, 0);
        SSL3_RECORD_set_read(rr);
        goto start;
    }

@@ -656,6 +673,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
        if (SSL3_RECORD_get_epoch(rr) != s->rlayer.d->r_epoch
                || SSL3_RECORD_get_length(rr) < DTLS1_HM_HEADER_LENGTH) {
            SSL3_RECORD_set_length(rr, 0);
            SSL3_RECORD_set_read(rr);
            goto start;
        }

@@ -677,6 +695,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
                    return -1;
            }
            SSL3_RECORD_set_length(rr, 0);
            SSL3_RECORD_set_read(rr);
            if (!(s->mode & SSL_MODE_AUTO_RETRY)) {
                if (SSL3_BUFFER_get_left(&s->rlayer.rbuf) == 0) {
                    /* no read-ahead left? */
+12 −1
Original line number Diff line number Diff line
@@ -1882,6 +1882,7 @@ int dtls1_get_record(SSL *s)
        p += 6;

        n2s(p, rr->length);
        rr->read = 0;

        /*
         * Lets check the version. We tolerate alerts that don't have the exact
@@ -1891,6 +1892,7 @@ int dtls1_get_record(SSL *s)
            if (version != s->version) {
                /* unexpected version, silently discard */
                rr->length = 0;
                rr->read = 1;
                RECORD_LAYER_reset_packet_length(&s->rlayer);
                goto again;
            }
@@ -1899,6 +1901,7 @@ int dtls1_get_record(SSL *s)
        if ((version & 0xff00) != (s->version & 0xff00)) {
            /* wrong version, silently discard record */
            rr->length = 0;
            rr->read = 1;
            RECORD_LAYER_reset_packet_length(&s->rlayer);
            goto again;
        }
@@ -1906,6 +1909,7 @@ int dtls1_get_record(SSL *s)
        if (rr->length > SSL3_RT_MAX_ENCRYPTED_LENGTH) {
            /* record too long, silently discard it */
            rr->length = 0;
            rr->read = 1;
            RECORD_LAYER_reset_packet_length(&s->rlayer);
            goto again;
        }
@@ -1915,6 +1919,7 @@ int dtls1_get_record(SSL *s)
                && rr->length > GET_MAX_FRAGMENT_LENGTH(s->session)) {
            /* record too long, silently discard it */
            rr->length = 0;
            rr->read = 1;
            RECORD_LAYER_reset_packet_length(&s->rlayer);
            goto again;
        }
@@ -1936,6 +1941,7 @@ int dtls1_get_record(SSL *s)
                return -1;
            }
            rr->length = 0;
            rr->read = 1;
            RECORD_LAYER_reset_packet_length(&s->rlayer);
            goto again;
        }
@@ -1966,6 +1972,7 @@ int dtls1_get_record(SSL *s)
         */
        if (!dtls1_record_replay_check(s, bitmap)) {
            rr->length = 0;
            rr->read = 1;
            RECORD_LAYER_reset_packet_length(&s->rlayer); /* dump this record */
            goto again;         /* get another record */
        }
@@ -1974,8 +1981,10 @@ int dtls1_get_record(SSL *s)
#endif

    /* just read a 0 length packet */
    if (rr->length == 0)
    if (rr->length == 0) {
        rr->read = 1;
        goto again;
    }

    /*
     * If this record is from the next epoch (either HM or ALERT), and a
@@ -1992,6 +2001,7 @@ int dtls1_get_record(SSL *s)
            }
        }
        rr->length = 0;
        rr->read = 1;
        RECORD_LAYER_reset_packet_length(&s->rlayer);
        goto again;
    }
@@ -2002,6 +2012,7 @@ int dtls1_get_record(SSL *s)
            return -1;
        }
        rr->length = 0;
        rr->read = 1;
        RECORD_LAYER_reset_packet_length(&s->rlayer); /* dump this record */
        goto again;             /* get another record */
    }