Commit 1ab3836b authored by Matt Caswell's avatar Matt Caswell
Browse files

Refactor ClientHello processing so that extensions get parsed earlier

parent e3fb4d3d
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -2256,6 +2256,7 @@ int ERR_load_SSL_strings(void);
# define SSL_F_TLS_CONSTRUCT_SERVER_KEY_EXCHANGE          377
# define SSL_F_TLS_GET_MESSAGE_BODY                       351
# define SSL_F_TLS_GET_MESSAGE_HEADER                     387
# define SSL_F_TLS_PARSE_RAW_EXTENSIONS                   432
# define SSL_F_TLS_POST_PROCESS_CLIENT_HELLO              378
# define SSL_F_TLS_POST_PROCESS_CLIENT_KEY_EXCHANGE       384
# define SSL_F_TLS_PREPARE_CLIENT_CERTIFICATE             360
+1 −0
Original line number Diff line number Diff line
@@ -279,6 +279,7 @@ static ERR_STRING_DATA SSL_str_functs[] = {
     "tls_construct_server_key_exchange"},
    {ERR_FUNC(SSL_F_TLS_GET_MESSAGE_BODY), "tls_get_message_body"},
    {ERR_FUNC(SSL_F_TLS_GET_MESSAGE_HEADER), "tls_get_message_header"},
    {ERR_FUNC(SSL_F_TLS_PARSE_RAW_EXTENSIONS), "tls_parse_raw_extensions"},
    {ERR_FUNC(SSL_F_TLS_POST_PROCESS_CLIENT_HELLO),
     "tls_post_process_client_hello"},
    {ERR_FUNC(SSL_F_TLS_POST_PROCESS_CLIENT_KEY_EXCHANGE),
+29 −7
Original line number Diff line number Diff line
@@ -1624,6 +1624,29 @@ typedef struct ssl3_comp_st {
} SSL3_COMP;
# endif

typedef struct {
    unsigned int type;
    PACKET data;
} RAW_EXTENSION;

#define MAX_COMPRESSIONS_SIZE   255

typedef struct {
    unsigned int isv2;
    unsigned int version;
    unsigned char random[SSL3_RANDOM_SIZE];
    size_t session_id_len;
    unsigned char session_id[SSL_MAX_SSL_SESSION_ID_LENGTH];
    size_t dtls_cookie_len;
    unsigned char dtls_cookie[DTLS1_COOKIE_LENGTH];
    PACKET ciphersuites;
    size_t compressions_len;
    unsigned char compressions[MAX_COMPRESSIONS_SIZE];
    PACKET extensions;
    size_t num_extensions;
    RAW_EXTENSION *pre_proc_exts;
} CLIENTHELLO_MSG;

extern SSL3_ENC_METHOD ssl3_undef_enc_method;

__owur const SSL_METHOD *ssl_bad_method(int ver);
@@ -1797,8 +1820,7 @@ __owur CERT *ssl_cert_dup(CERT *cert);
void ssl_cert_clear_certs(CERT *c);
void ssl_cert_free(CERT *c);
__owur int ssl_get_new_session(SSL *s, int session);
__owur int ssl_get_prev_session(SSL *s, const PACKET *ext,
                                const PACKET *session_id);
__owur int ssl_get_prev_session(SSL *s, CLIENTHELLO_MSG *hello);
__owur SSL_SESSION *ssl_session_dup(SSL_SESSION *src, int ticket);
__owur int ssl_cipher_id_cmp(const SSL_CIPHER *a, const SSL_CIPHER *b);
DECLARE_OBJ_BSEARCH_GLOBAL_CMP_FN(SSL_CIPHER, SSL_CIPHER, ssl_cipher_id);
@@ -1919,7 +1941,7 @@ __owur int ssl_version_supported(const SSL *s, int version);
__owur int ssl_set_client_hello_version(SSL *s);
__owur int ssl_check_version_downgrade(SSL *s);
__owur int ssl_set_version_bound(int method_version, int version, int *bound);
__owur int ssl_choose_server_version(SSL *s);
__owur int ssl_choose_server_version(SSL *s, CLIENTHELLO_MSG *hello);
__owur int ssl_choose_client_version(SSL *s, int version);
int ssl_get_client_min_max_version(const SSL *s, int *min_version,
                                   int *max_version);
@@ -2020,7 +2042,7 @@ __owur int tls1_shared_list(SSL *s,
                            const unsigned char *l2, size_t l2len, int nmatch);
__owur int ssl_add_clienthello_tlsext(SSL *s, WPACKET *pkt, int *al);
__owur int ssl_add_serverhello_tlsext(SSL *s, WPACKET *pkt, int *al);
__owur int ssl_parse_clienthello_tlsext(SSL *s, PACKET *pkt);
__owur int ssl_parse_clienthello_tlsext(SSL *s, CLIENTHELLO_MSG *hello);
void ssl_set_default_md(SSL *s);
__owur int tls1_set_server_sigalgs(SSL *s);
__owur int ssl_check_clienthello_tlsext_late(SSL *s, int *al);
@@ -2034,9 +2056,9 @@ __owur int dtls1_process_heartbeat(SSL *s, unsigned char *p,
                                   size_t length);
#  endif

__owur int tls_check_serverhello_tlsext_early(SSL *s, const PACKET *ext,
                                              const PACKET *session_id,
__owur int tls_get_ticket_from_client(SSL *s, CLIENTHELLO_MSG *hello,
                                      SSL_SESSION **ret);
__owur int tls_check_client_ems_support(SSL *s, CLIENTHELLO_MSG *hello);

__owur int tls12_get_sigandhash(WPACKET *pkt, const EVP_PKEY *pk,
                                const EVP_MD *md);
+11 −12
Original line number Diff line number Diff line
@@ -445,7 +445,7 @@ int ssl_get_new_session(SSL *s, int session)
 *   - Both for new and resumed sessions, s->tlsext_ticket_expected is set to 1
 *     if the server should issue a new session ticket (to 0 otherwise).
 */
int ssl_get_prev_session(SSL *s, const PACKET *ext, const PACKET *session_id)
int ssl_get_prev_session(SSL *s, CLIENTHELLO_MSG *hello)
{
    /* This is used only by servers. */

@@ -454,11 +454,11 @@ int ssl_get_prev_session(SSL *s, const PACKET *ext, const PACKET *session_id)
    int try_session_cache = 1;
    int r;

    if (PACKET_remaining(session_id) == 0)
    if (hello->session_id_len == 0)
        try_session_cache = 0;

    /* sets s->tlsext_ticket_expected and extended master secret flag */
    r = tls_check_serverhello_tlsext_early(s, ext, session_id, &ret);
    /* sets s->tlsext_ticket_expected */
    r = tls_get_ticket_from_client(s, hello, &ret);
    switch (r) {
    case -1:                   /* Error during processing */
        fatal = 1;
@@ -479,14 +479,12 @@ int ssl_get_prev_session(SSL *s, const PACKET *ext, const PACKET *session_id)
        !(s->session_ctx->session_cache_mode &
          SSL_SESS_CACHE_NO_INTERNAL_LOOKUP)) {
        SSL_SESSION data;
        size_t local_len;

        data.ssl_version = s->version;
        memset(data.session_id, 0, sizeof(data.session_id));
        if (!PACKET_copy_all(session_id, data.session_id,
                             sizeof(data.session_id), &local_len)) {
            goto err;
        }
        data.session_id_length = local_len;
        memcpy(data.session_id, hello->session_id, hello->session_id_len);
        data.session_id_length = hello->session_id_len;

        CRYPTO_THREAD_read_lock(s->session_ctx->lock);
        ret = lh_SSL_SESSION_retrieve(s->session_ctx->sessions, &data);
        if (ret != NULL) {
@@ -501,8 +499,9 @@ int ssl_get_prev_session(SSL *s, const PACKET *ext, const PACKET *session_id)
    if (try_session_cache &&
        ret == NULL && s->session_ctx->get_session_cb != NULL) {
        int copy = 1;
        ret = s->session_ctx->get_session_cb(s, PACKET_data(session_id),
                                             (int)PACKET_remaining(session_id),

        ret = s->session_ctx->get_session_cb(s, hello->session_id,
                                             hello->session_id_len,
                                             &copy);

        if (ret != NULL) {
+92 −2
Original line number Diff line number Diff line
@@ -152,6 +152,94 @@ static void ssl3_take_mac(SSL *s)
}
#endif

static int compare_extensions(const void *p1, const void *p2)
{
    const RAW_EXTENSION *e1 = (const RAW_EXTENSION *)p1;
    const RAW_EXTENSION *e2 = (const RAW_EXTENSION *)p2;
    if (e1->type < e2->type)
        return -1;
    else if (e1->type > e2->type)
        return 1;
    else
        return 0;
}

/*
 * Gather a list of all the extensions. We don't actually process the content
 * of the extensions yet, except to check their types.
 *
 * Per http://tools.ietf.org/html/rfc5246#section-7.4.1.4, there may not be
 * more than one extension of the same type in a ClientHello or ServerHello.
 * This function returns 1 if all extensions are unique and we have parsed their
 * types, and 0 if the extensions contain duplicates, could not be successfully
 * parsed, or an internal error occurred.
 */
int tls_parse_raw_extensions(PACKET *packet, RAW_EXTENSION **res,
                             size_t *numfound, int *ad)
{
    PACKET extensions = *packet;
    size_t num_extensions = 0, i = 0;
    RAW_EXTENSION *raw_extensions = NULL;

    /* First pass: count the extensions. */
    while (PACKET_remaining(&extensions) > 0) {
        unsigned int type;
        PACKET extension;
        if (!PACKET_get_net_2(&extensions, &type) ||
            !PACKET_get_length_prefixed_2(&extensions, &extension)) {
            *ad = SSL_AD_DECODE_ERROR;
            goto done;
        }
        num_extensions++;
    }

    if (num_extensions > 0) {
        raw_extensions = OPENSSL_malloc(sizeof(RAW_EXTENSION) * num_extensions);
        if (raw_extensions == NULL) {
            *ad = SSL_AD_INTERNAL_ERROR;
            SSLerr(SSL_F_TLS_PARSE_RAW_EXTENSIONS, ERR_R_MALLOC_FAILURE);
            goto done;
        }

        /* Second pass: gather the extension types. */
        for (i = 0; i < num_extensions; i++) {
            if (!PACKET_get_net_2(packet, &raw_extensions[i].type) ||
                !PACKET_get_length_prefixed_2(packet,
                                              &raw_extensions[i].data)) {
                /* This should not happen. */
                *ad = SSL_AD_INTERNAL_ERROR;
                SSLerr(SSL_F_TLS_PARSE_RAW_EXTENSIONS, ERR_R_INTERNAL_ERROR);
                goto done;
            }
        }

        if (PACKET_remaining(packet) != 0) {
            *ad = SSL_AD_DECODE_ERROR;
            SSLerr(SSL_F_TLS_PARSE_RAW_EXTENSIONS, SSL_R_LENGTH_MISMATCH);
            goto done;
        }
        /* Sort the extensions and make sure there are no duplicates. */
        qsort(raw_extensions, num_extensions, sizeof(RAW_EXTENSION),
              compare_extensions);
        for (i = 1; i < num_extensions; i++) {
            if (raw_extensions[i - 1].type == raw_extensions[i].type) {
                *ad = SSL_AD_DECODE_ERROR;
                goto done;
            }
        }
    }

    *res = raw_extensions;
    *numfound = num_extensions;
    return 1;

 done:
    OPENSSL_free(raw_extensions);
    return 0;
}



MSG_PROCESS_RETURN tls_process_change_cipher_spec(SSL *s, PACKET *pkt)
{
    int al;
@@ -875,7 +963,7 @@ int ssl_set_version_bound(int method_version, int version, int *bound)
 *
 * Returns 0 on success or an SSL error reason number on failure.
 */
int ssl_choose_server_version(SSL *s)
int ssl_choose_server_version(SSL *s, CLIENTHELLO_MSG *hello)
{
    /*-
     * With version-flexible methods we have an initial state with:
@@ -887,11 +975,13 @@ int ssl_choose_server_version(SSL *s)
     * handle version.
     */
    int server_version = s->method->version;
    int client_version = s->client_version;
    int client_version = hello->version;
    const version_info *vent;
    const version_info *table;
    int disabled = 0;

    s->client_version = client_version;

    switch (server_version) {
    default:
        if (version_cmp(s, client_version, s->version) < 0)
Loading