Commit 76af3037 authored by Matt Caswell's avatar Matt Caswell
Browse files

dtls_get_message changes for state machine move



Create a dtls_get_message function similar to the old dtls1_get_message but
in the format required for the new state machine code. The old function will
eventually be deleted in later commits.

Reviewed-by: default avatarTim Hudson <tjh@openssl.org>
Reviewed-by: default avatarRichard Levitte <levitte@openssl.org>
parent f6a2f2da
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -1928,6 +1928,7 @@ void ERR_load_SSL_strings(void);
# define SSL_F_DTLS1_SEND_SERVER_HELLO                    266
# define SSL_F_DTLS1_SEND_SERVER_KEY_EXCHANGE             267
# define SSL_F_DTLS1_WRITE_APP_DATA_BYTES                 268
# define SSL_F_DTLS_GET_REASSEMBLED_MESSAGE               370
# define SSL_F_READ_STATE_MACHINE                         352
# define SSL_F_SSL3_ACCEPT                                128
# define SSL_F_SSL3_ADD_CERT_TO_BUF                       296
+142 −67
Original line number Diff line number Diff line
@@ -161,7 +161,8 @@ static void dtls1_set_message_header_int(SSL *s, unsigned char mt,
                                         unsigned long frag_off,
                                         unsigned long frag_len);
static long dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt,
                                       long max, int *ok);
                                       int *ok);
static int dtls_get_reassembled_message(SSL *s, long *len);

static hm_fragment *dtls1_hm_fragment_new(unsigned long frag_len,
                                          int reassembly)
@@ -481,7 +482,7 @@ long dtls1_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
    memset(msg_hdr, 0, sizeof(*msg_hdr));

 again:
    i = dtls1_get_message_fragment(s, st1, stn, mt, max, ok);
    i = dtls1_get_message_fragment(s, st1, stn, mt, ok);
    if (i == DTLS1_HM_BAD_FRAGMENT || i == DTLS1_HM_FRAGMENT_RETRY) {
        /* bad fragment received */
        goto again;
@@ -523,6 +524,12 @@ long dtls1_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
        msg_len += DTLS1_HM_HEADER_LENGTH;
    }

    if (msg_len > (unsigned long)max) {
        al = SSL_AD_ILLEGAL_PARAMETER;
        SSLerr(SSL_F_DTLS1_GET_MESSAGE, SSL_R_EXCESSIVE_MESSAGE_SIZE);
        goto f_err;
    }

    ssl3_finish_mac(s, p, msg_len);
    if (s->msg_callback)
        s->msg_callback(0, s->version, SSL3_RT_HANDSHAKE,
@@ -542,8 +549,72 @@ long dtls1_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
    return -1;
}

static int dtls1_preprocess_fragment(SSL *s, struct hm_header_st *msg_hdr,
                                     int max)
int dtls_get_message(SSL *s, int *mt, unsigned long *len)
{
    struct hm_header_st *msg_hdr;
    unsigned char *p;
    unsigned long msg_len;
    int ok;
    long tmplen;

    msg_hdr = &s->d1->r_msg_hdr;
    memset(msg_hdr, 0, sizeof(*msg_hdr));

 again:
    ok = dtls_get_reassembled_message(s, &tmplen);
    if (tmplen == DTLS1_HM_BAD_FRAGMENT
        || tmplen == DTLS1_HM_FRAGMENT_RETRY) {
        /* bad fragment received */
        goto again;
    } else if (tmplen <= 0 && !ok) {
        return 0;
    }

    *mt = s->s3->tmp.message_type;

    p = (unsigned char *)s->init_buf->data;

    if (*mt == SSL3_MT_CHANGE_CIPHER_SPEC) {
        if (s->msg_callback) {
            s->msg_callback(0, s->version, SSL3_RT_CHANGE_CIPHER_SPEC,
                            p, 1, s, s->msg_callback_arg);
        }
        /*
         * This isn't a real handshake message so skip the processing below.
         */
        return 1;
    }

    msg_len = msg_hdr->msg_len;

    /* reconstruct message header */
    *(p++) = msg_hdr->type;
    l2n3(msg_len, p);
    s2n(msg_hdr->seq, p);
    l2n3(0, p);
    l2n3(msg_len, p);
    if (s->version != DTLS1_BAD_VER) {
        p -= DTLS1_HM_HEADER_LENGTH;
        msg_len += DTLS1_HM_HEADER_LENGTH;
    }

    ssl3_finish_mac(s, p, msg_len);
    if (s->msg_callback)
        s->msg_callback(0, s->version, SSL3_RT_HANDSHAKE,
                        p, msg_len, s, s->msg_callback_arg);

    memset(msg_hdr, 0, sizeof(*msg_hdr));

    s->d1->handshake_read_seq++;


    s->init_msg = s->init_buf->data + DTLS1_HM_HEADER_LENGTH;
    *len = s->init_num;

    return 1;
}

static int dtls1_preprocess_fragment(SSL *s, struct hm_header_st *msg_hdr)
{
    size_t frag_off, frag_len, msg_len;

@@ -557,11 +628,6 @@ static int dtls1_preprocess_fragment(SSL *s, struct hm_header_st *msg_hdr,
        return SSL_AD_ILLEGAL_PARAMETER;
    }

    if ((frag_off + frag_len) > (unsigned long)max) {
        SSLerr(SSL_F_DTLS1_PREPROCESS_FRAGMENT, SSL_R_EXCESSIVE_MESSAGE_SIZE);
        return SSL_AD_ILLEGAL_PARAMETER;
    }

    if (s->d1->r_msg_hdr.frag_off == 0) { /* first fragment */
        /*
         * msg_len is limited to 2^24, but is effectively checked against max
@@ -590,7 +656,7 @@ static int dtls1_preprocess_fragment(SSL *s, struct hm_header_st *msg_hdr,
    return 0;                   /* no error */
}

static int dtls1_retrieve_buffered_fragment(SSL *s, long max, int *ok)
static int dtls1_retrieve_buffered_fragment(SSL *s, int *ok)
{
    /*-
     * (0) check whether the desired fragment is available
@@ -617,7 +683,7 @@ static int dtls1_retrieve_buffered_fragment(SSL *s, long max, int *ok)
        unsigned long frag_len = frag->msg_header.frag_len;
        pqueue_pop(s->d1->buffered_messages);

        al = dtls1_preprocess_fragment(s, &frag->msg_header, max);
        al = dtls1_preprocess_fragment(s, &frag->msg_header);

        if (al == 0) {          /* no alert */
            unsigned char *p =
@@ -859,19 +925,44 @@ dtls1_process_out_of_seq_message(SSL *s, const struct hm_header_st *msg_hdr,
}

static long
dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt, long max, int *ok)
dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt, int *ok)
{
    long len;

    do {
        *ok = dtls_get_reassembled_message(s, &len);
        /* A CCS isn't a real handshake message, so if we get one there is no
         * message sequence number to give us confidence that this was really
         * intended to be at this point in the handshake sequence. Therefore we
         * only allow this if we were explicitly looking for it (i.e. if |mt|
         * is -1 we still don't allow it). If we get one when we're not
         * expecting it then probably something got re-ordered or this is a
         * retransmit. We should drop this and try again.
         */
    } while (*ok && mt != SSL3_MT_CHANGE_CIPHER_SPEC
             && s->s3->tmp.message_type == SSL3_MT_CHANGE_CIPHER_SPEC);

    if (*ok)
        s->state = stn;

    return len;
}

static int dtls_get_reassembled_message(SSL *s, long *len)
{
    unsigned char wire[DTLS1_HM_HEADER_LENGTH];
    unsigned long len, frag_off, frag_len;
    unsigned long mlen, frag_off, frag_len;
    int i, al, recvd_type;
    struct hm_header_st msg_hdr;
    int ok;

 redo:
    /* see if we have the required fragment already */
    if ((frag_len = dtls1_retrieve_buffered_fragment(s, max, ok)) || *ok) {
        if (*ok)
    if ((frag_len = dtls1_retrieve_buffered_fragment(s, &ok)) || ok) {
        if (ok)
            s->init_num = frag_len;
        return frag_len;
        *len = frag_len;
        return ok;
    }

    /* read handshake message header */
@@ -879,20 +970,14 @@ dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt, long max, int *ok)
                                  DTLS1_HM_HEADER_LENGTH, 0);
    if (i <= 0) {               /* nbio, or an error */
        s->rwstate = SSL_READING;
        *ok = 0;
        return i;
        *len = i;
        return 0;
    }
    if(recvd_type == SSL3_RT_CHANGE_CIPHER_SPEC) {
        /* This isn't a real handshake message - its a CCS.
         * There is no message sequence number in a CCS to give us confidence
         * that this was really intended to be at this point in the handshake
         * sequence. Therefore we only allow this if we were explicitly looking
         * for it (i.e. if |mt| is -1 we still don't allow it).
         */
        if(mt == SSL3_MT_CHANGE_CIPHER_SPEC) {
        if (wire[0] != SSL3_MT_CCS) {
            al = SSL_AD_UNEXPECTED_MESSAGE;
                SSLerr(SSL_F_DTLS1_GET_MESSAGE_FRAGMENT, SSL_R_BAD_CHANGE_CIPHER_SPEC);
            SSLerr(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE,
                   SSL_R_BAD_CHANGE_CIPHER_SPEC);
            goto f_err;
        }

@@ -901,31 +986,21 @@ dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt, long max, int *ok)
        s->init_msg = s->init_buf->data + 1;
        s->s3->tmp.message_type = SSL3_MT_CHANGE_CIPHER_SPEC;
        s->s3->tmp.message_size = i - 1;
            s->state = stn;
            *ok = 1;
            return i-1;
        } else {
            /*
             * We weren't expecting a CCS yet. Probably something got
             * re-ordered or this is a retransmit. We should drop this and try
             * again.
             */
            s->init_num = 0;
            goto redo;
        }
        *len = i - 1;
        return 1;
    }

    /* Handshake fails if message header is incomplete */
    if (i != DTLS1_HM_HEADER_LENGTH) {
        al = SSL_AD_UNEXPECTED_MESSAGE;
        SSLerr(SSL_F_DTLS1_GET_MESSAGE_FRAGMENT, SSL_R_UNEXPECTED_MESSAGE);
        SSLerr(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE, SSL_R_UNEXPECTED_MESSAGE);
        goto f_err;
    }

    /* parse the message fragment header */
    dtls1_get_message_header(wire, &msg_hdr);

    len = msg_hdr.msg_len;
    mlen = msg_hdr.msg_len;
    frag_off = msg_hdr.frag_off;
    frag_len = msg_hdr.frag_len;

@@ -935,7 +1010,7 @@ dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt, long max, int *ok)
     */
    if (frag_len > RECORD_LAYER_get_rrec_length(&s->rlayer)) {
        al = SSL3_AD_ILLEGAL_PARAMETER;
        SSLerr(SSL_F_DTLS1_GET_MESSAGE_FRAGMENT, SSL_R_BAD_LENGTH);
        SSLerr(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE, SSL_R_BAD_LENGTH);
        goto f_err;
    }

@@ -945,11 +1020,15 @@ dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt, long max, int *ok)
     * While listening, we accept seq 1 (ClientHello with cookie)
     * although we're still expecting seq 0 (ClientHello)
     */
    if (msg_hdr.seq != s->d1->handshake_read_seq)
        return dtls1_process_out_of_seq_message(s, &msg_hdr, ok);
    if (msg_hdr.seq != s->d1->handshake_read_seq) {
        *len = dtls1_process_out_of_seq_message(s, &msg_hdr, &ok);
        return ok;
    }

    if (frag_len && frag_len < len)
        return dtls1_reassemble_fragment(s, &msg_hdr, ok);
    if (frag_len && frag_len < mlen) {
        *len = dtls1_reassemble_fragment(s, &msg_hdr, &ok);
        return ok;
    }

    if (!s->server && s->d1->r_msg_hdr.frag_off == 0 &&
        wire[0] == SSL3_MT_HELLO_REQUEST) {
@@ -969,13 +1048,13 @@ dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt, long max, int *ok)
        } else {                /* Incorrectly formated Hello request */

            al = SSL_AD_UNEXPECTED_MESSAGE;
            SSLerr(SSL_F_DTLS1_GET_MESSAGE_FRAGMENT,
            SSLerr(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE,
                   SSL_R_UNEXPECTED_MESSAGE);
            goto f_err;
        }
    }

    if ((al = dtls1_preprocess_fragment(s, &msg_hdr, max)))
    if ((al = dtls1_preprocess_fragment(s, &msg_hdr)))
        goto f_err;

    if (frag_len > 0) {
@@ -991,8 +1070,8 @@ dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt, long max, int *ok)
         */
        if (i <= 0) {
            s->rwstate = SSL_READING;
            *ok = 0;
            return i;
            *len = i;
            return 0;
        }
    } else
        i = 0;
@@ -1003,28 +1082,24 @@ dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt, long max, int *ok)
     */
    if (i != (int)frag_len) {
        al = SSL3_AD_ILLEGAL_PARAMETER;
        SSLerr(SSL_F_DTLS1_GET_MESSAGE_FRAGMENT, SSL3_AD_ILLEGAL_PARAMETER);
        SSLerr(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE, SSL3_AD_ILLEGAL_PARAMETER);
        goto f_err;
    }

    *ok = 1;
    s->state = stn;

    /*
     * Note that s->init_num is *not* used as current offset in
     * s->init_buf->data, but as a counter summing up fragments' lengths: as
     * soon as they sum up to handshake packet length, we assume we have got
     * all the fragments.
     */
    s->init_num = frag_len;
    return frag_len;
    *len = s->init_num = frag_len;
    return 1;

 f_err:
    ssl3_send_alert(s, SSL3_AL_FATAL, al);
    s->init_num = 0;

    *ok = 0;
    return (-1);
    *len = -1;
    return 0;
}

/*-
+2 −0
Original line number Diff line number Diff line
@@ -112,6 +112,8 @@ static ERR_STRING_DATA SSL_str_functs[] = {
    {ERR_FUNC(SSL_F_DTLS1_SEND_SERVER_KEY_EXCHANGE),
     "dtls1_send_server_key_exchange"},
    {ERR_FUNC(SSL_F_DTLS1_WRITE_APP_DATA_BYTES), "dtls1_write_app_data_bytes"},
    {ERR_FUNC(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE),
     "DTLS_GET_REASSEMBLED_MESSAGE"},
    {ERR_FUNC(SSL_F_READ_STATE_MACHINE), "READ_STATE_MACHINE"},
    {ERR_FUNC(SSL_F_SSL3_ACCEPT), "ssl3_accept"},
    {ERR_FUNC(SSL_F_SSL3_ADD_CERT_TO_BUF), "SSL3_ADD_CERT_TO_BUF"},
+1 −0
Original line number Diff line number Diff line
@@ -2236,6 +2236,7 @@ long dtls1_ctrl(SSL *s, int cmd, long larg, void *parg);
__owur int dtls1_shutdown(SSL *s);

__owur long dtls1_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok);
__owur int dtls_get_message(SSL *s, int *mt, unsigned long *len);
__owur int dtls1_dispatch_alert(SSL *s);

__owur int ssl_init_wbio_buffer(SSL *s, int push);
+8 −1
Original line number Diff line number Diff line
@@ -464,7 +464,14 @@ static enum SUB_STATE_RETURN read_state_machine(SSL *s) {
        case READ_STATE_HEADER:
            s->init_num = 0;
            /* Get the state the peer wants to move to */
            if (SSL_IS_DTLS(s)) {
                /*
                 * In DTLS we get the whole message in one go - header and body
                 */
                ret = dtls_get_message(s, &mt, &len);
            } else {
                ret = tls_get_message_header(s, &mt);
            }

            if (ret == 0) {
                /* Could be non-blocking IO */