Commit 6392fb8e authored by Matt Caswell's avatar Matt Caswell
Browse files

Move setting of the handshake header up one more level



We now set the handshake header, and close the packet directly in the
write_state_machine. This is now possible because it is common for all
messages.

Reviewed-by: default avatarRich Salz <rsalz@openssl.org>
parent 229185e6
Loading
Loading
Loading
Loading
+11 −4
Original line number Diff line number Diff line
@@ -708,8 +708,12 @@ static SUB_STATE_RETURN write_state_machine(SSL *s)
    WRITE_TRAN(*transition) (SSL *s);
    WORK_STATE(*pre_work) (SSL *s, WORK_STATE wst);
    WORK_STATE(*post_work) (SSL *s, WORK_STATE wst);
    int (*construct_message) (SSL *s, WPACKET *pkt);
    int (*get_construct_message_f) (SSL *s, WPACKET *pkt,
                                    int (**confunc) (SSL *s, WPACKET *pkt),
                                    int *mt);
    void (*cb) (const SSL *ssl, int type, int val) = NULL;
    int (*confunc) (SSL *s, WPACKET *pkt);
    int mt;
    WPACKET pkt;

    cb = get_callback(s);
@@ -718,12 +722,12 @@ static SUB_STATE_RETURN write_state_machine(SSL *s)
        transition = ossl_statem_server_write_transition;
        pre_work = ossl_statem_server_pre_work;
        post_work = ossl_statem_server_post_work;
        construct_message = ossl_statem_server_construct_message;
        get_construct_message_f = ossl_statem_server_construct_message;
    } else {
        transition = ossl_statem_client_write_transition;
        pre_work = ossl_statem_client_pre_work;
        post_work = ossl_statem_client_post_work;
        construct_message = ossl_statem_client_construct_message;
        get_construct_message_f = ossl_statem_client_construct_message;
    }

    while (1) {
@@ -766,7 +770,10 @@ static SUB_STATE_RETURN write_state_machine(SSL *s)
                return SUB_STATE_END_HANDSHAKE;
            }
            if (!WPACKET_init(&pkt, s->init_buf)
                    || !construct_message(s, &pkt)
                    || !get_construct_message_f(s, &pkt, &confunc, &mt)
                    || !ssl_set_handshake_header(s, &pkt, mt)
                    || (confunc != NULL && !confunc(s, &pkt))
                    || !ssl_close_construct_packet(s, &pkt, mt)
                    || !WPACKET_finish(&pkt)) {
                WPACKET_cleanup(&pkt);
                ossl_statem_set_error(s);
+20 −26
Original line number Diff line number Diff line
@@ -504,17 +504,18 @@ WORK_STATE ossl_statem_client_post_work(SSL *s, WORK_STATE wst)
}

/*
 * Construct a message to be sent from the client to the server.
 * Get the message construction function and message type for sending from the
 * client
 *
 * Valid return values are:
 *   1: Success
 *   0: Error
 */
int ossl_statem_client_construct_message(SSL *s, WPACKET *pkt)
int ossl_statem_client_construct_message(SSL *s, WPACKET *pkt,
                                         int (**confunc) (SSL *s, WPACKET *pkt),
                                         int *mt)
{
    OSSL_STATEM *st = &s->statem;
    int (*confunc) (SSL *s, WPACKET *pkt) = NULL;
    int mt;

    switch (st->hand_state) {
    default:
@@ -523,51 +524,44 @@ int ossl_statem_client_construct_message(SSL *s, WPACKET *pkt)

    case TLS_ST_CW_CHANGE:
        if (SSL_IS_DTLS(s))
            confunc = dtls_construct_change_cipher_spec;
            *confunc = dtls_construct_change_cipher_spec;
        else
            confunc = tls_construct_change_cipher_spec;
        mt = SSL3_MT_CHANGE_CIPHER_SPEC;
            *confunc = tls_construct_change_cipher_spec;
        *mt = SSL3_MT_CHANGE_CIPHER_SPEC;
        break;

    case TLS_ST_CW_CLNT_HELLO:
        confunc = tls_construct_client_hello;
        mt = SSL3_MT_CLIENT_HELLO;
        *confunc = tls_construct_client_hello;
        *mt = SSL3_MT_CLIENT_HELLO;
        break;

    case TLS_ST_CW_CERT:
        confunc = tls_construct_client_certificate;
        mt = SSL3_MT_CERTIFICATE;
        *confunc = tls_construct_client_certificate;
        *mt = SSL3_MT_CERTIFICATE;
        break;

    case TLS_ST_CW_KEY_EXCH:
        confunc = tls_construct_client_key_exchange;
        mt = SSL3_MT_CLIENT_KEY_EXCHANGE;
        *confunc = tls_construct_client_key_exchange;
        *mt = SSL3_MT_CLIENT_KEY_EXCHANGE;
        break;

    case TLS_ST_CW_CERT_VRFY:
        confunc = tls_construct_client_verify;
        mt = SSL3_MT_CERTIFICATE_VERIFY;
        *confunc = tls_construct_client_verify;
        *mt = SSL3_MT_CERTIFICATE_VERIFY;
        break;

#if !defined(OPENSSL_NO_NEXTPROTONEG)
    case TLS_ST_CW_NEXT_PROTO:
        confunc = tls_construct_next_proto;
        mt = SSL3_MT_NEXT_PROTO;
        *confunc = tls_construct_next_proto;
        *mt = SSL3_MT_NEXT_PROTO;
        break;
#endif
    case TLS_ST_CW_FINISHED:
        confunc = tls_construct_finished;
        mt = SSL3_MT_FINISHED;
        *confunc = tls_construct_finished;
        *mt = SSL3_MT_FINISHED;
        break;
    }

    if (!ssl_set_handshake_header(s, pkt, mt)
            || !confunc(s, pkt)
            || !ssl_close_construct_packet(s, pkt, mt)) {
        SSLerr(SSL_F_OSSL_STATEM_CLIENT_CONSTRUCT_MESSAGE,
               ERR_R_INTERNAL_ERROR);
        return 0;
    }
    return 1;
}

+6 −2
Original line number Diff line number Diff line
@@ -50,7 +50,9 @@ int ossl_statem_client_read_transition(SSL *s, int mt);
WRITE_TRAN ossl_statem_client_write_transition(SSL *s);
WORK_STATE ossl_statem_client_pre_work(SSL *s, WORK_STATE wst);
WORK_STATE ossl_statem_client_post_work(SSL *s, WORK_STATE wst);
int ossl_statem_client_construct_message(SSL *s, WPACKET *pkt);
int ossl_statem_client_construct_message(SSL *s, WPACKET *pkt,
                                         int (**confunc) (SSL *s, WPACKET *pkt),
                                         int *mt);
unsigned long ossl_statem_client_max_message_size(SSL *s);
MSG_PROCESS_RETURN ossl_statem_client_process_message(SSL *s, PACKET *pkt);
WORK_STATE ossl_statem_client_post_process_message(SSL *s, WORK_STATE wst);
@@ -62,7 +64,9 @@ int ossl_statem_server_read_transition(SSL *s, int mt);
WRITE_TRAN ossl_statem_server_write_transition(SSL *s);
WORK_STATE ossl_statem_server_pre_work(SSL *s, WORK_STATE wst);
WORK_STATE ossl_statem_server_post_work(SSL *s, WORK_STATE wst);
int ossl_statem_server_construct_message(SSL *s, WPACKET *pkt);
int ossl_statem_server_construct_message(SSL *s, WPACKET *pkt,
                                         int (**confunc) (SSL *s, WPACKET *pkt),
                                         int *mt);
unsigned long ossl_statem_server_max_message_size(SSL *s);
MSG_PROCESS_RETURN ossl_statem_server_process_message(SSL *s, PACKET *pkt);
WORK_STATE ossl_statem_server_post_process_message(SSL *s, WORK_STATE wst);
+28 −34
Original line number Diff line number Diff line
@@ -613,17 +613,18 @@ WORK_STATE ossl_statem_server_post_work(SSL *s, WORK_STATE wst)
}

/*
 * Construct a message to be sent from the server to the client.
 * Get the message construction function and message type for sending from the
 * server
 *
 * Valid return values are:
 *   1: Success
 *   0: Error
 */
int ossl_statem_server_construct_message(SSL *s, WPACKET *pkt)
int ossl_statem_server_construct_message(SSL *s, WPACKET *pkt,
                                         int (**confunc) (SSL *s, WPACKET *pkt),
                                         int *mt)
{
    OSSL_STATEM *st = &s->statem;
    int (*confunc) (SSL *s, WPACKET *pkt) = NULL;
    int mt;

    switch (st->hand_state) {
    default:
@@ -632,71 +633,64 @@ int ossl_statem_server_construct_message(SSL *s, WPACKET *pkt)

    case TLS_ST_SW_CHANGE:
        if (SSL_IS_DTLS(s))
            confunc = dtls_construct_change_cipher_spec;
            *confunc = dtls_construct_change_cipher_spec;
        else
            confunc = tls_construct_change_cipher_spec;
        mt = SSL3_MT_CHANGE_CIPHER_SPEC;
            *confunc = tls_construct_change_cipher_spec;
        *mt = SSL3_MT_CHANGE_CIPHER_SPEC;
        break;

    case DTLS_ST_SW_HELLO_VERIFY_REQUEST:
        confunc = dtls_construct_hello_verify_request;
        mt = DTLS1_MT_HELLO_VERIFY_REQUEST;
        *confunc = dtls_construct_hello_verify_request;
        *mt = DTLS1_MT_HELLO_VERIFY_REQUEST;
        break;

    case TLS_ST_SW_HELLO_REQ:
        /* No construction function needed */
        mt = SSL3_MT_HELLO_REQUEST;
        *confunc = NULL;
        *mt = SSL3_MT_HELLO_REQUEST;
        break;

    case TLS_ST_SW_SRVR_HELLO:
        confunc = tls_construct_server_hello;
        mt = SSL3_MT_SERVER_HELLO;
        *confunc = tls_construct_server_hello;
        *mt = SSL3_MT_SERVER_HELLO;
        break;

    case TLS_ST_SW_CERT:
        confunc = tls_construct_server_certificate;
        mt = SSL3_MT_CERTIFICATE;
        *confunc = tls_construct_server_certificate;
        *mt = SSL3_MT_CERTIFICATE;
        break;

    case TLS_ST_SW_KEY_EXCH:
        confunc = tls_construct_server_key_exchange;
        mt = SSL3_MT_SERVER_KEY_EXCHANGE;
        *confunc = tls_construct_server_key_exchange;
        *mt = SSL3_MT_SERVER_KEY_EXCHANGE;
        break;

    case TLS_ST_SW_CERT_REQ:
        confunc = tls_construct_certificate_request;
        mt = SSL3_MT_CERTIFICATE_REQUEST;
        *confunc = tls_construct_certificate_request;
        *mt = SSL3_MT_CERTIFICATE_REQUEST;
        break;

    case TLS_ST_SW_SRVR_DONE:
        confunc = tls_construct_server_done;
        mt = SSL3_MT_SERVER_DONE;
        *confunc = tls_construct_server_done;
        *mt = SSL3_MT_SERVER_DONE;
        break;

    case TLS_ST_SW_SESSION_TICKET:
        confunc = tls_construct_new_session_ticket;
        mt = SSL3_MT_NEWSESSION_TICKET;
        *confunc = tls_construct_new_session_ticket;
        *mt = SSL3_MT_NEWSESSION_TICKET;
        break;

    case TLS_ST_SW_CERT_STATUS:
        confunc = tls_construct_cert_status;
        mt = SSL3_MT_CERTIFICATE_STATUS;
        *confunc = tls_construct_cert_status;
        *mt = SSL3_MT_CERTIFICATE_STATUS;
        break;

    case TLS_ST_SW_FINISHED:
        confunc = tls_construct_finished;
        mt = SSL3_MT_FINISHED;
        *confunc = tls_construct_finished;
        *mt = SSL3_MT_FINISHED;
        break;
    }

    if (!ssl_set_handshake_header(s, pkt, mt)
            || (confunc != NULL && !confunc(s, pkt))
            || !ssl_close_construct_packet(s, pkt, mt)) {
        SSLerr(SSL_F_OSSL_STATEM_SERVER_CONSTRUCT_MESSAGE,
               ERR_R_INTERNAL_ERROR);
        return 0;
    }

    return 1;
}