Commit 5923ad4b authored by Matt Caswell's avatar Matt Caswell
Browse files

Don't set the handshake header in every message



Move setting the handshake header up a level into the state machine code
in order to reduce boilerplate.

Reviewed-by: default avatarRich Salz <rsalz@openssl.org>
parent 7cea05dc
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -2078,7 +2078,9 @@ int ERR_load_SSL_strings(void);
# define SSL_F_DTLS_GET_REASSEMBLED_MESSAGE               370
# define SSL_F_DTLS_PROCESS_HELLO_VERIFY                  386
# define SSL_F_OPENSSL_INIT_SSL                           342
# define SSL_F_OSSL_STATEM_CLIENT_CONSTRUCT_MESSAGE       430
# define SSL_F_OSSL_STATEM_CLIENT_READ_TRANSITION         417
# define SSL_F_OSSL_STATEM_SERVER_CONSTRUCT_MESSAGE       431
# define SSL_F_OSSL_STATEM_SERVER_READ_TRANSITION         418
# define SSL_F_READ_STATE_MACHINE                         352
# define SSL_F_SSL3_CHANGE_CIPHER_STATE                   129
+4 −0
Original line number Diff line number Diff line
@@ -49,8 +49,12 @@ static ERR_STRING_DATA SSL_str_functs[] = {
     "dtls_get_reassembled_message"},
    {ERR_FUNC(SSL_F_DTLS_PROCESS_HELLO_VERIFY), "dtls_process_hello_verify"},
    {ERR_FUNC(SSL_F_OPENSSL_INIT_SSL), "OPENSSL_init_ssl"},
    {ERR_FUNC(SSL_F_OSSL_STATEM_CLIENT_CONSTRUCT_MESSAGE),
     "ossl_statem_client_construct_message"},
    {ERR_FUNC(SSL_F_OSSL_STATEM_CLIENT_READ_TRANSITION),
     "ossl_statem_client_read_transition"},
    {ERR_FUNC(SSL_F_OSSL_STATEM_SERVER_CONSTRUCT_MESSAGE),
     "ossl_statem_server_construct_message"},
    {ERR_FUNC(SSL_F_OSSL_STATEM_SERVER_READ_TRANSITION),
     "ossl_statem_server_read_transition"},
    {ERR_FUNC(SSL_F_READ_STATE_MACHINE), "read_state_machine"},
+58 −68
Original line number Diff line number Diff line
@@ -513,41 +513,74 @@ WORK_STATE ossl_statem_client_post_work(SSL *s, WORK_STATE wst)
int ossl_statem_client_construct_message(SSL *s, WPACKET *pkt)
{
    OSSL_STATEM *st = &s->statem;
    int (*confunc) (SSL *s, WPACKET *pkt) = NULL;
    int ret = 1, mt;

    if (st->hand_state == TLS_ST_CW_CHANGE) {
        /* Special case becase it is a different content type */
        if (SSL_IS_DTLS(s))
            return dtls_construct_change_cipher_spec(s, pkt);

        return tls_construct_change_cipher_spec(s, pkt);
    } else {
        switch (st->hand_state) {
        default:
            /* Shouldn't happen */
            return 0;

        case TLS_ST_CW_CLNT_HELLO:
        return tls_construct_client_hello(s, pkt);
            confunc = tls_construct_client_hello;
            mt = SSL3_MT_CLIENT_HELLO;
            break;

        case TLS_ST_CW_CERT:
        return tls_construct_client_certificate(s, pkt);
            confunc = tls_construct_client_certificate;
            mt = SSL3_MT_CERTIFICATE;
            break;

        case TLS_ST_CW_KEY_EXCH:
        return tls_construct_client_key_exchange(s, pkt);
            confunc = tls_construct_client_key_exchange;
            mt = SSL3_MT_CLIENT_KEY_EXCHANGE;
            break;

        case TLS_ST_CW_CERT_VRFY:
        return tls_construct_client_verify(s, pkt);

    case TLS_ST_CW_CHANGE:
        if (SSL_IS_DTLS(s))
            return dtls_construct_change_cipher_spec(s, pkt);
        else
            return tls_construct_change_cipher_spec(s, pkt);
            confunc = tls_construct_client_verify;
            mt = SSL3_MT_CERTIFICATE_VERIFY;
            break;

#if !defined(OPENSSL_NO_NEXTPROTONEG)
        case TLS_ST_CW_NEXT_PROTO:
        return tls_construct_next_proto(s, pkt);
            confunc = tls_construct_next_proto;
            mt = SSL3_MT_NEXT_PROTO;
            break;
#endif
        case TLS_ST_CW_FINISHED:
        return tls_construct_finished(s, pkt,
            mt = SSL3_MT_FINISHED;
            break;
        }

        if (!ssl_set_handshake_header(s, pkt, mt)) {
            SSLerr(SSL_F_OSSL_STATEM_CLIENT_CONSTRUCT_MESSAGE,
                   ERR_R_INTERNAL_ERROR);
            return 0;
        }

        if (st->hand_state == TLS_ST_CW_FINISHED)
            ret = tls_construct_finished(s, pkt,
                                         s->method->
                                         ssl3_enc->client_finished_label,
                                         s->method->
                                         ssl3_enc->client_finished_label_len);
        else
            ret = confunc(s, pkt);

        if (!ret || !ssl_close_construct_packet(s, pkt)) {
            SSLerr(SSL_F_OSSL_STATEM_CLIENT_CONSTRUCT_MESSAGE,
                   ERR_R_INTERNAL_ERROR);
            return 0;
        }
    }
    return 1;
}

/*
@@ -736,12 +769,6 @@ int tls_construct_client_hello(SSL *s, WPACKET *pkt)
    if (i && ssl_fill_hello_random(s, 0, p, sizeof(s->s3->client_random)) <= 0)
        return 0;

    if (!ssl_set_handshake_header(s, pkt, SSL3_MT_CLIENT_HELLO)) {
        ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_HELLO, ERR_R_INTERNAL_ERROR);
        return 0;
    }

    /*-
     * version indicates the negotiated version: for example from
     * an SSLv2/v3 compatible client hello). The client_version
@@ -855,11 +882,6 @@ int tls_construct_client_hello(SSL *s, WPACKET *pkt)
        return 0;
    }

    if (!ssl_close_construct_packet(s, pkt)) {
        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_HELLO, ERR_R_INTERNAL_ERROR);
        return 0;
    }

    return 1;
}

@@ -2455,12 +2477,6 @@ int tls_construct_client_key_exchange(SSL *s, WPACKET *pkt)
    unsigned long alg_k;
    int al = -1;

    if (!ssl_set_handshake_header(s, pkt, SSL3_MT_CLIENT_KEY_EXCHANGE)) {
        ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
        goto err;
    }

    alg_k = s->s3->tmp.new_cipher->algorithm_mkey;

    if ((alg_k & SSL_PSK)
@@ -2488,12 +2504,6 @@ int tls_construct_client_key_exchange(SSL *s, WPACKET *pkt)
        goto err;
    }

    if (!ssl_close_construct_packet(s, pkt)) {
        ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
        goto err;
    }

    return 1;
 err:
    if (al != -1)
@@ -2582,11 +2592,6 @@ int tls_construct_client_verify(SSL *s, WPACKET *pkt)
    void *hdata;
    unsigned char *sig = NULL;

    if (!ssl_set_handshake_header(s, pkt, SSL3_MT_CERTIFICATE_VERIFY)) {
        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_VERIFY, ERR_R_INTERNAL_ERROR);
        goto err;
    }

    mctx = EVP_MD_CTX_new();
    if (mctx == NULL) {
        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_VERIFY, ERR_R_MALLOC_FAILURE);
@@ -2640,11 +2645,6 @@ int tls_construct_client_verify(SSL *s, WPACKET *pkt)
    if (!ssl3_digest_cached_records(s, 0))
        goto err;

    if (!ssl_close_construct_packet(s, pkt)) {
        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_VERIFY, ERR_R_INTERNAL_ERROR);
        goto err;
    }

    OPENSSL_free(sig);
    EVP_MD_CTX_free(mctx);
    return 1;
@@ -2846,11 +2846,6 @@ int tls_construct_next_proto(SSL *s, WPACKET *pkt)
    size_t len, padding_len;
    unsigned char *padding = NULL;

    if (!ssl_set_handshake_header(s, pkt, SSL3_MT_NEXT_PROTO)) {
        SSLerr(SSL_F_TLS_CONSTRUCT_NEXT_PROTO, ERR_R_INTERNAL_ERROR);
        goto err;
    }

    len = s->next_proto_negotiated_len;
    padding_len = 32 - ((len + 2) % 32);

@@ -2862,11 +2857,6 @@ int tls_construct_next_proto(SSL *s, WPACKET *pkt)

    memset(padding, 0, padding_len);

    if (!ssl_close_construct_packet(s, pkt)) {
        SSLerr(SSL_F_TLS_CONSTRUCT_NEXT_PROTO, ERR_R_INTERNAL_ERROR);
        goto err;
    }

    return 1;
 err:
    ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
+2 −14
Original line number Diff line number Diff line
@@ -75,11 +75,6 @@ int tls_construct_finished(SSL *s, WPACKET *pkt, const char *sender, int slen)
{
    int i;

    if (!ssl_set_handshake_header(s, pkt, SSL3_MT_FINISHED)) {
        SSLerr(SSL_F_TLS_CONSTRUCT_FINISHED, ERR_R_INTERNAL_ERROR);
        goto err;
    }

    i = s->method->ssl3_enc->final_finish_mac(s,
                                              sender, slen,
                                              s->s3->tmp.finish_md);
@@ -108,11 +103,6 @@ int tls_construct_finished(SSL *s, WPACKET *pkt, const char *sender, int slen)
        s->s3->previous_server_finished_len = i;
    }

    if (!ssl_close_construct_packet(s, pkt)) {
        SSLerr(SSL_F_TLS_CONSTRUCT_FINISHED, ERR_R_INTERNAL_ERROR);
        goto err;
    }

    return 1;
 err:
    ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
@@ -278,11 +268,9 @@ int tls_construct_change_cipher_spec(SSL *s, WPACKET *pkt)

unsigned long ssl3_output_cert_chain(SSL *s, WPACKET *pkt, CERT_PKEY *cpk)
{
    if (!ssl_set_handshake_header(s, pkt, SSL3_MT_CERTIFICATE)
            || !WPACKET_start_sub_packet_u24(pkt)
    if (!WPACKET_start_sub_packet_u24(pkt)
            || !ssl_add_cert_chain(s, pkt, cpk)
            || !WPACKET_close(pkt)
            || !ssl_close_construct_packet(s, pkt)) {
            || !WPACKET_close(pkt)) {
        SSLerr(SSL_F_SSL3_OUTPUT_CERT_CHAIN, ERR_R_INTERNAL_ERROR);
        return 0;
    }
+0 −1
Original line number Diff line number Diff line
@@ -109,7 +109,6 @@ __owur MSG_PROCESS_RETURN dtls_process_hello_verify(SSL *s, PACKET *pkt);
__owur MSG_PROCESS_RETURN tls_process_client_hello(SSL *s, PACKET *pkt);
__owur WORK_STATE tls_post_process_client_hello(SSL *s, WORK_STATE wst);
__owur int tls_construct_server_hello(SSL *s, WPACKET *pkt);
__owur int tls_construct_hello_request(SSL *s, WPACKET *pkt);
__owur int dtls_construct_hello_verify_request(SSL *s, WPACKET *pkt);
__owur int tls_construct_server_certificate(SSL *s, WPACKET *pkt);
__owur int tls_construct_server_key_exchange(SSL *s, WPACKET *pkt);
Loading