Commit 9ab930b2 authored by Matt Caswell's avatar Matt Caswell
Browse files

Split ssl3_get_message



The function ssl3_get_message gets a whole message from the underlying bio
and returns it to the state machine code. The new state machine code will
split this into two discrete steps: get the message header and get the
message body. This commit splits the existing function into these two
sub steps to facilitate the state machine implementation.

Reviewed-by: default avatarTim Hudson <tjh@openssl.org>
Reviewed-by: default avatarRichard Levitte <levitte@openssl.org>
parent 94b3664a
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -2099,6 +2099,8 @@ void ERR_load_SSL_strings(void);
# define SSL_F_TLS1_PROCESS_HEARTBEAT                     341
# define SSL_F_TLS1_SETUP_KEY_BLOCK                       211
# define SSL_F_TLS1_SET_SERVER_SIGALGS                    335
# define SSL_F_TLS_GET_MESSAGE_BODY                       351
# define SSL_F_TLS_GET_MESSAGE_HEADER                     350
# define SSL_F_USE_CERTIFICATE_CHAIN_FILE                 220

/* Reason codes. */
+163 −113
Original line number Diff line number Diff line
@@ -411,9 +411,8 @@ unsigned long ssl3_output_cert_chain(SSL *s, CERT_PKEY *cpk)
long ssl3_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
{
    unsigned char *p;
    unsigned long l;
    long n;
    int i, al, recvd_type;
    int al, mtin;

    if (s->s3->tmp.reuse_message) {
        s->s3->tmp.reuse_message = 0;
@@ -432,20 +431,14 @@ long ssl3_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
    p = (unsigned char *)s->init_buf->data;

    if (s->state == st1) {
        /* s->init_num < SSL3_HM_HEADER_LENGTH */
        int skip_message;

        do {
            while (s->init_num < SSL3_HM_HEADER_LENGTH) {
                i = s->method->ssl_read_bytes(s, SSL3_RT_HANDSHAKE, &recvd_type,
                    &p[s->init_num], SSL3_HM_HEADER_LENGTH - s->init_num, 0);
                if (i <= 0) {
                    s->rwstate = SSL_READING;
        if (tls_get_message_header(s, &mtin) == 0) {
            /* Could be NBIO */
            *ok = 0;
                    return i;
            return -1;
        }
        s->state = stn;
        if (s->init_num == 0
                        && recvd_type == SSL3_RT_CHANGE_CIPHER_SPEC
                && mtin == SSL3_MT_CHANGE_CIPHER_SPEC
                && (mt < 0 || mt == SSL3_MT_CHANGE_CIPHER_SPEC)) {
            if (*p != SSL3_MT_CCS) {
                al = SSL_AD_UNEXPECTED_MESSAGE;
@@ -453,20 +446,67 @@ long ssl3_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
                       SSL_R_UNEXPECTED_MESSAGE);
                goto f_err;
            }
                    s->init_num = i - 1;
            s->init_msg = p + 1;
            s->s3->tmp.message_type = SSL3_MT_CHANGE_CIPHER_SPEC;
                    s->s3->tmp.message_size = i - 1;
                    s->state = stn;
            s->s3->tmp.message_size = s->init_num;
            *ok = 1;
            if (s->msg_callback)
                s->msg_callback(0, s->version,
                                SSL3_RT_CHANGE_CIPHER_SPEC, p, 1, s,
                                s->msg_callback_arg);
                    return i - 1;
            return s->init_num;
        }
        if (s->s3->tmp.message_size > (unsigned long)max) {
            al = SSL_AD_ILLEGAL_PARAMETER;
            SSLerr(SSL_F_SSL3_GET_MESSAGE, SSL_R_EXCESSIVE_MESSAGE_SIZE);
            goto f_err;
        }
        if ((mt >= 0) && (mtin != mt)) {
            al = SSL_AD_UNEXPECTED_MESSAGE;
            SSLerr(SSL_F_SSL3_GET_MESSAGE, SSL_R_UNEXPECTED_MESSAGE);
            goto f_err;
        }
    }

    /* next state (stn) */
    if (tls_get_message_body(s, (unsigned long *)&n) == 0) {
        *ok = 0;
        return n;
    }

    *ok = 1;
    return n;
 f_err:
    ssl3_send_alert(s, SSL3_AL_FATAL, al);
    *ok = 0;
    return 0;
}

int tls_get_message_header(SSL *s, int *mt)
{
    /* s->init_num < SSL3_HM_HEADER_LENGTH */
    int skip_message, i, recvd_type, al;
    unsigned char *p;
    unsigned long l;

    p = (unsigned char *)s->init_buf->data;

    do {
        while (s->init_num < SSL3_HM_HEADER_LENGTH) {
            i = s->method->ssl_read_bytes(s, SSL3_RT_HANDSHAKE, &recvd_type,
                &p[s->init_num], SSL3_HM_HEADER_LENGTH - s->init_num, 0);
            if (i <= 0) {
                s->rwstate = SSL_READING;
                return 0;
            }
            if (recvd_type == SSL3_RT_CHANGE_CIPHER_SPEC) {
                s->s3->tmp.message_type = *mt = SSL3_MT_CHANGE_CIPHER_SPEC;
                s->init_num = i - 1;
                s->s3->tmp.message_size = i;
                return 1;
            } else if (recvd_type != SSL3_RT_HANDSHAKE) {
                al = SSL_AD_UNEXPECTED_MESSAGE;
                    SSLerr(SSL_F_SSL3_GET_MESSAGE, SSL_R_CCS_RECEIVED_EARLY);
                SSLerr(SSL_F_TLS_GET_MESSAGE_HEADER, SSL_R_CCS_RECEIVED_EARLY);
                goto f_err;
            }
            s->init_num += i;
@@ -493,12 +533,7 @@ long ssl3_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
    } while (skip_message);
    /* s->init_num == SSL3_HM_HEADER_LENGTH */

        if ((mt >= 0) && (*p != mt)) {
            al = SSL_AD_UNEXPECTED_MESSAGE;
            SSLerr(SSL_F_SSL3_GET_MESSAGE, SSL_R_UNEXPECTED_MESSAGE);
            goto f_err;
        }

    *mt = *p;
    s->s3->tmp.message_type = *(p++);

    if(RECORD_LAYER_is_sslv2_record(&s->rlayer)) {
@@ -513,41 +548,51 @@ long ssl3_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
        l = RECORD_LAYER_get_rrec_length(&s->rlayer)
            + SSL3_HM_HEADER_LENGTH;
        if (l && !BUF_MEM_grow_clean(s->init_buf, (int)l)) {
                SSLerr(SSL_F_SSL3_GET_MESSAGE, ERR_R_BUF_LIB);
            SSLerr(SSL_F_TLS_GET_MESSAGE_HEADER, ERR_R_BUF_LIB);
            goto err;
        }
        s->s3->tmp.message_size = l;
            s->state = stn;

        s->init_msg = s->init_buf->data;
        s->init_num = SSL3_HM_HEADER_LENGTH;
    } else {
        n2l3(p, l);
            if (l > (unsigned long)max) {
                al = SSL_AD_ILLEGAL_PARAMETER;
                SSLerr(SSL_F_SSL3_GET_MESSAGE, SSL_R_EXCESSIVE_MESSAGE_SIZE);
                goto f_err;
            }
        /* BUF_MEM_grow takes an 'int' parameter */
        if (l > (INT_MAX - SSL3_HM_HEADER_LENGTH)) {
            al = SSL_AD_ILLEGAL_PARAMETER;
                SSLerr(SSL_F_SSL3_GET_MESSAGE, SSL_R_EXCESSIVE_MESSAGE_SIZE);
            SSLerr(SSL_F_TLS_GET_MESSAGE_HEADER, SSL_R_EXCESSIVE_MESSAGE_SIZE);
            goto f_err;
        }
        if (l && !BUF_MEM_grow_clean(s->init_buf,
                                    (int)l + SSL3_HM_HEADER_LENGTH)) {
                SSLerr(SSL_F_SSL3_GET_MESSAGE, ERR_R_BUF_LIB);
            SSLerr(SSL_F_TLS_GET_MESSAGE_HEADER, ERR_R_BUF_LIB);
            goto err;
        }
        s->s3->tmp.message_size = l;
            s->state = stn;

        s->init_msg = s->init_buf->data + SSL3_HM_HEADER_LENGTH;
        s->init_num = 0;
    }

    return 1;
 f_err:
    ssl3_send_alert(s, SSL3_AL_FATAL, al);
 err:
    return 0;
}

int tls_get_message_body(SSL *s, unsigned long *len)
{
    long n;
    unsigned char *p;
    int i;

    if (s->s3->tmp.message_type == SSL3_MT_CHANGE_CIPHER_SPEC) {
        /* We've already read everything in */
        *len = (unsigned long)s->init_num;
        return 1;
    }

    /* next state (stn) */
    p = s->init_msg;
    n = s->s3->tmp.message_size - s->init_num;
    while (n > 0) {
@@ -555,8 +600,8 @@ long ssl3_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
                                      &p[s->init_num], n, 0);
        if (i <= 0) {
            s->rwstate = SSL_READING;
            *ok = 0;
            return i;
            *len = 0;
            return 0;
        }
        s->init_num += i;
        n -= i;
@@ -586,13 +631,18 @@ long ssl3_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
                            s->msg_callback_arg);
    }

    *ok = 1;
    return s->init_num;
 f_err:
    ssl3_send_alert(s, SSL3_AL_FATAL, al);
 err:
    *ok = 0;
    return (-1);
    /*
     * init_num should never be negative...should probably be declared
     * unsigned
     */
    if (s->init_num < 0) {
        SSLerr(SSL_F_TLS_GET_MESSAGE_BODY, ERR_R_INTERNAL_ERROR);
        ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
        *len = 0;
        return 0;
    }
    *len = (unsigned long)s->init_num;
    return 1;
}

int ssl_cert_type(X509 *x, EVP_PKEY *pkey)
+2 −0
Original line number Diff line number Diff line
@@ -331,6 +331,8 @@ static ERR_STRING_DATA SSL_str_functs[] = {
    {ERR_FUNC(SSL_F_TLS1_PROCESS_HEARTBEAT), "tls1_process_heartbeat"},
    {ERR_FUNC(SSL_F_TLS1_SETUP_KEY_BLOCK), "tls1_setup_key_block"},
    {ERR_FUNC(SSL_F_TLS1_SET_SERVER_SIGALGS), "tls1_set_server_sigalgs"},
    {ERR_FUNC(SSL_F_TLS_GET_MESSAGE_BODY), "tls_get_message_body"},
    {ERR_FUNC(SSL_F_TLS_GET_MESSAGE_HEADER), "tls_get_message_header"},
    {ERR_FUNC(SSL_F_USE_CERTIFICATE_CHAIN_FILE), "use_certificate_chain_file"},
    {0, NULL}
};
+2 −0
Original line number Diff line number Diff line
@@ -1930,6 +1930,8 @@ __owur int ssl3_generate_master_secret(SSL *s, unsigned char *out,
                                unsigned char *p, int len);
__owur int ssl3_get_req_cert_type(SSL *s, unsigned char *p);
__owur long ssl3_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok);
__owur int tls_get_message_header(SSL *s, int *mt);
__owur int tls_get_message_body(SSL *s, unsigned long *len);
__owur int ssl3_send_finished(SSL *s, int a, int b, const char *sender, int slen);
__owur int ssl3_num_ciphers(void);
__owur const SSL_CIPHER *ssl3_get_cipher(unsigned int u);