Commit 43ae5eed authored by Matt Caswell's avatar Matt Caswell
Browse files

Implement a new custom extensions API



The old custom extensions API was not TLSv1.3 aware. Extensions are used
extensively in TLSv1.3 and they can appear in many different types of
messages. Therefore we need a new API to be able to cope with that.

Reviewed-by: default avatarRich Salz <rsalz@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/3139)
parent fe874d27
Loading
Loading
Loading
Loading
+36 −9
Original line number Diff line number Diff line
@@ -259,19 +259,21 @@ typedef int (*tls_session_secret_cb_fn) (SSL *s, void *secret,
#define SSL_EXT_TLS_IMPLEMENTATION_ONLY         0x0004
/* Most extensions are not defined for SSLv3 but EXT_TYPE_renegotiate is */
#define SSL_EXT_SSL3_ALLOWED                    0x0008
/* Extension is only defined for TLS1.2 and above */
/* Extension is only defined for TLS1.2 and below */
#define SSL_EXT_TLS1_2_AND_BELOW_ONLY           0x0010
/* Extension is only defined for TLS1.3 and above */
#define SSL_EXT_TLS1_3_ONLY                     0x0020
#define SSL_EXT_CLIENT_HELLO                    0x0040
/* Ignore this extension during parsing if we are resuming */
#define SSL_EXT_IGNORE_ON_RESUMPTION            0x0040
#define SSL_EXT_CLIENT_HELLO                    0x0080
/* Really means TLS1.2 or below */
#define SSL_EXT_TLS1_2_SERVER_HELLO             0x0080
#define SSL_EXT_TLS1_3_SERVER_HELLO             0x0100
#define SSL_EXT_TLS1_3_ENCRYPTED_EXTENSIONS     0x0200
#define SSL_EXT_TLS1_3_HELLO_RETRY_REQUEST      0x0400
#define SSL_EXT_TLS1_3_CERTIFICATE              0x0800
#define SSL_EXT_TLS1_3_NEW_SESSION_TICKET       0x1000
#define SSL_EXT_TLS1_3_CERTIFICATE_REQUEST      0x2000
#define SSL_EXT_TLS1_2_SERVER_HELLO             0x0100
#define SSL_EXT_TLS1_3_SERVER_HELLO             0x0200
#define SSL_EXT_TLS1_3_ENCRYPTED_EXTENSIONS     0x0400
#define SSL_EXT_TLS1_3_HELLO_RETRY_REQUEST      0x0800
#define SSL_EXT_TLS1_3_CERTIFICATE              0x1000
#define SSL_EXT_TLS1_3_NEW_SESSION_TICKET       0x2000
#define SSL_EXT_TLS1_3_CERTIFICATE_REQUEST      0x4000

/* Typedefs for handling custom extensions */

@@ -286,6 +288,23 @@ typedef int (*custom_ext_parse_cb) (SSL *s, unsigned int ext_type,
                                    const unsigned char *in,
                                    size_t inlen, int *al, void *parse_arg);


typedef int (*custom_ext_add_cb_ex) (SSL *s, unsigned int ext_type,
                                     unsigned int context,
                                     const unsigned char **out,
                                     size_t *outlen, X509 *x, size_t chainidx,
                                     int *al, void *add_arg);

typedef void (*custom_ext_free_cb_ex) (SSL *s, unsigned int ext_type,
                                       unsigned int context,
                                       const unsigned char *out, void *add_arg);

typedef int (*custom_ext_parse_cb_ex) (SSL *s, unsigned int ext_type,
                                       unsigned int context,
                                       const unsigned char *in,
                                       size_t inlen, X509 *x, size_t chainidx,
                                       int *al, void *parse_arg);

/* Typedef for verification callback */
typedef int (*SSL_verify_cb)(int preverify_ok, X509_STORE_CTX *x509_ctx);

@@ -779,6 +798,14 @@ __owur int SSL_CTX_add_server_custom_ext(SSL_CTX *ctx, unsigned int ext_type,
                                  custom_ext_parse_cb parse_cb,
                                  void *parse_arg);

__owur int SSL_CTX_add_custom_ext(SSL_CTX *ctx, unsigned int ext_type,
                                  unsigned int context,
                                  custom_ext_add_cb_ex add_cb,
                                  custom_ext_free_cb_ex free_cb,
                                  void *add_arg,
                                  custom_ext_parse_cb_ex parse_cb,
                                  void *parse_arg);

__owur int SSL_extension_supported(unsigned int ext_type);

# define SSL_NOTHING            1
+2 −5
Original line number Diff line number Diff line
@@ -190,9 +190,7 @@ CERT *ssl_cert_dup(CERT *cert)
    ret->sec_level = cert->sec_level;
    ret->sec_ex = cert->sec_ex;

    if (!custom_exts_copy(&ret->cli_ext, &cert->cli_ext))
        goto err;
    if (!custom_exts_copy(&ret->srv_ext, &cert->srv_ext))
    if (!custom_exts_copy(&ret->custext, &cert->custext))
        goto err;
#ifndef OPENSSL_NO_PSK
    if (cert->psk_identity_hint) {
@@ -254,8 +252,7 @@ void ssl_cert_free(CERT *c)
    OPENSSL_free(c->ctype);
    X509_STORE_free(c->verify_store);
    X509_STORE_free(c->chain_store);
    custom_exts_free(&c->cli_ext);
    custom_exts_free(&c->srv_ext);
    custom_exts_free(&c->custext);
#ifndef OPENSSL_NO_PSK
    OPENSSL_free(c->psk_identity_hint);
#endif
+20 −11
Original line number Diff line number Diff line
@@ -1618,15 +1618,22 @@ struct cert_pkey_st {

typedef struct {
    unsigned short ext_type;
    /*
     * Set to 1 if this is only for the server side, 0 if it is only for the
     * client side, or -1 if it is either.
     */
    int server;
    /* The context which this extension applies to */
    unsigned int context;
    /*
     * Per-connection flags relating to this extension type: not used if
     * part of an SSL_CTX structure.
     */
    uint32_t ext_flags;
    custom_ext_add_cb add_cb;
    custom_ext_free_cb free_cb;
    custom_ext_add_cb_ex add_cb;
    custom_ext_free_cb_ex free_cb;
    void *add_arg;
    custom_ext_parse_cb parse_cb;
    custom_ext_parse_cb_ex parse_cb;
    void *parse_arg;
} custom_ext_method;

@@ -1706,9 +1713,8 @@ typedef struct cert_st {
     */
    X509_STORE *chain_store;
    X509_STORE *verify_store;
    /* Custom extension methods for server and client */
    custom_ext_methods cli_ext;
    custom_ext_methods srv_ext;
    /* Custom extensions */
    custom_ext_methods custext;
    /* Security callback */
    int (*sec_cb) (const SSL *s, const SSL_CTX *ctx, int op, int bits, int nid,
                   void *other, void *ex);
@@ -2436,15 +2442,18 @@ __owur int srp_generate_server_master_secret(SSL *s);
__owur int srp_generate_client_master_secret(SSL *s);
__owur int srp_verify_server_param(SSL *s, int *al);

/* t1_ext.c */
/* statem/extensions_cust.c */

custom_ext_method *custom_ext_find(const custom_ext_methods *exts, int server,
                                   unsigned int ext_type, size_t *idx);

void custom_ext_init(custom_ext_methods *meths);

__owur int custom_ext_parse(SSL *s, int server,
                            unsigned int ext_type,
__owur int custom_ext_parse(SSL *s, unsigned int context, unsigned int ext_type,
                            const unsigned char *ext_data, size_t ext_size,
                            int *al);
__owur int custom_ext_add(SSL *s, int server, WPACKET *pkt, int *al);
                            X509 *x, size_t chainidx, int *al);
__owur int custom_ext_add(SSL *s, int context, WPACKET *pkt, X509 *x,
                          size_t chainidx, int maxversion, int *al);

__owur int custom_exts_copy(custom_ext_methods *dst,
                            const custom_ext_methods *src);
+9 −20
Original line number Diff line number Diff line
@@ -797,26 +797,15 @@ static int serverinfo_process_buffer(const unsigned char *serverinfo,

        /* Register callbacks for extensions */
        ext_type = (serverinfo[0] << 8) + serverinfo[1];
        if (ctx) {
            int have_ext_cbs = 0;
            size_t i;
            custom_ext_methods *exts = &ctx->cert->srv_ext;
            custom_ext_method *meth = exts->meths;

            for (i = 0; i < exts->meths_count; i++, meth++) {
                if (ext_type == meth->ext_type) {
                    have_ext_cbs = 1;
                    break;
                }
            }

            if (!have_ext_cbs && !SSL_CTX_add_server_custom_ext(ctx, ext_type,
        if (ctx != NULL
                && custom_ext_find(&ctx->cert->custext, 1, ext_type, NULL)
                   == NULL
                && !SSL_CTX_add_server_custom_ext(ctx, ext_type,
                                                  serverinfo_srv_add_cb,
                                                  NULL, NULL,
                                                  serverinfo_srv_parse_cb,
                                                  NULL))
            return 0;
        }

        serverinfo += 2;
        serverinfo_length -= 2;
+75 −85
Original line number Diff line number Diff line
@@ -329,6 +329,23 @@ static const EXTENSION_DEFINITION ext_defs[] = {
    }
};

/* Check whether an extension's context matches the current context */
static int validate_context(SSL *s, unsigned int extctx, unsigned int thisctx)
{
    /* Check we're allowed to use this extension in this context */
    if ((thisctx & extctx) == 0)
        return 0;

    if (SSL_IS_DTLS(s)) {
        if ((extctx & SSL_EXT_TLS_ONLY) != 0)
            return 0;
    } else if ((extctx & SSL_EXT_DTLS_ONLY) != 0) {
        return 0;
    }

    return 1;
}

/*
 * Verify whether we are allowed to use the extension |type| in the current
 * |context|. Returns 1 to indicate the extension is allowed or unknown or 0 to
@@ -345,40 +362,33 @@ static int verify_extension(SSL *s, unsigned int context, unsigned int type,

    for (i = 0, thisext = ext_defs; i < builtin_num; i++, thisext++) {
        if (type == thisext->type) {
            /* Check we're allowed to use this extension in this context */
            if ((context & thisext->context) == 0)
            if (!validate_context(s, thisext->context, context))
                return 0;

            if (SSL_IS_DTLS(s)) {
                if ((thisext->context & SSL_EXT_TLS_ONLY) != 0)
                    return 0;
            } else if ((thisext->context & SSL_EXT_DTLS_ONLY) != 0) {
                    return 0;
            }

            *found = &rawexlist[i];
            return 1;
        }
    }

    if ((context & (SSL_EXT_CLIENT_HELLO | SSL_EXT_TLS1_2_SERVER_HELLO)) == 0) {
        /*
         * Custom extensions only apply to <=TLS1.2. This extension is unknown
         * in this context - we allow it
         */
        *found = NULL;
        return 1;
    }

    /* Check the custom extensions */
    if (meths != NULL) {
        for (i = builtin_num; i < builtin_num + meths->meths_count; i++) {
            if (meths->meths[i - builtin_num].ext_type == type) {
                *found = &rawexlist[i];
        size_t offset = 0;
        int server = -1;
        custom_ext_method *meth = NULL;

        if ((context & SSL_EXT_CLIENT_HELLO) != 0)
            server = 1;
        else if ((context & SSL_EXT_TLS1_2_SERVER_HELLO) != 0)
            server = 0;

        meth = custom_ext_find(meths, server, type, &offset);
        if (meth != NULL) {
            if (!validate_context(s, meth->context, context))
                return 0;
            *found = &rawexlist[offset + builtin_num];
            return 1;
        }
    }
    }

    /* Unknown extension. We allow it */
    *found = NULL;
@@ -390,8 +400,7 @@ static int verify_extension(SSL *s, unsigned int context, unsigned int type,
 * the extension is relevant for the current context |thisctx| or not. Returns
 * 1 if the extension is relevant for this context, and 0 otherwise
 */
static int extension_is_relevant(SSL *s, unsigned int extctx,
                                 unsigned int thisctx)
int extension_is_relevant(SSL *s, unsigned int extctx, unsigned int thisctx)
{
    if ((SSL_IS_DTLS(s)
                && (extctx & SSL_EXT_TLS_IMPLEMENTATION_ONLY) != 0)
@@ -399,7 +408,8 @@ static int extension_is_relevant(SSL *s, unsigned int extctx,
                    && (extctx & SSL_EXT_SSL3_ALLOWED) == 0)
            || (SSL_IS_TLS13(s)
                && (extctx & SSL_EXT_TLS1_2_AND_BELOW_ONLY) != 0)
            || (!SSL_IS_TLS13(s) && (extctx & SSL_EXT_TLS1_3_ONLY) != 0))
            || (!SSL_IS_TLS13(s) && (extctx & SSL_EXT_TLS1_3_ONLY) != 0)
            || (s->hit && (extctx & SSL_EXT_IGNORE_ON_RESUMPTION) != 0))
        return 0;

    return 1;
@@ -427,7 +437,7 @@ int tls_collect_extensions(SSL *s, PACKET *packet, unsigned int context,
    PACKET extensions = *packet;
    size_t i = 0;
    size_t num_exts;
    custom_ext_methods *exts = NULL;
    custom_ext_methods *exts = &s->cert->custext;
    RAW_EXTENSION *raw_extensions = NULL;
    const EXTENSION_DEFINITION *thisexd;

@@ -437,12 +447,8 @@ int tls_collect_extensions(SSL *s, PACKET *packet, unsigned int context,
     * Initialise server side custom extensions. Client side is done during
     * construction of extensions for the ClientHello.
     */
    if ((context & SSL_EXT_CLIENT_HELLO) != 0) {
        exts = &s->cert->srv_ext;
        custom_ext_init(&s->cert->srv_ext);
    } else if ((context & SSL_EXT_TLS1_2_SERVER_HELLO) != 0) {
        exts = &s->cert->cli_ext;
    }
    if ((context & SSL_EXT_CLIENT_HELLO) != 0)
        custom_ext_init(&s->cert->custext);

    num_exts = OSSL_NELEM(ext_defs) + (exts != NULL ? exts->meths_count : 0);
    raw_extensions = OPENSSL_zalloc(num_exts * sizeof(*raw_extensions));
@@ -560,21 +566,11 @@ int tls_parse_extension(SSL *s, TLSEXT_INDEX idx, int context,
         */
    }

    /*
     * This is a custom extension. We only allow this if it is a non
     * resumed session on the server side.
     *chain
     * TODO(TLS1.3): We only allow old style <=TLS1.2 custom extensions.
     * We're going to need a new mechanism for TLS1.3 to specify which
     * messages to add the custom extensions to.
     */
    if ((!s->hit || !s->server)
            && (context
                & (SSL_EXT_CLIENT_HELLO | SSL_EXT_TLS1_2_SERVER_HELLO)) != 0
            && custom_ext_parse(s, s->server, currext->type,
    /* Parse custom extensions */
    if (custom_ext_parse(s, context, currext->type,
                         PACKET_data(&currext->data),
                         PACKET_remaining(&currext->data),
                                al) <= 0)
                         x, chainidx, al) <= 0)
        return 0;

    return 1;
@@ -595,11 +591,7 @@ int tls_parse_all_extensions(SSL *s, int context, RAW_EXTENSION *exts, X509 *x,
    const EXTENSION_DEFINITION *thisexd;

    /* Calculate the number of extensions in the extensions list */
    if ((context & SSL_EXT_CLIENT_HELLO) != 0) {
        numexts += s->cert->srv_ext.meths_count;
    } else if ((context & SSL_EXT_TLS1_2_SERVER_HELLO) != 0) {
        numexts += s->cert->cli_ext.meths_count;
    }
    numexts += s->cert->custext.meths_count;

    /* Parse each extension in turn */
    for (i = 0; i < numexts; i++) {
@@ -621,6 +613,30 @@ int tls_parse_all_extensions(SSL *s, int context, RAW_EXTENSION *exts, X509 *x,
    return 1;
}

int should_add_extension(SSL *s, unsigned int extctx, unsigned int thisctx,
                         int max_version)
{
    /* Skip if not relevant for our context */
    if ((extctx & thisctx) == 0)
        return 0;

    /* Check if this extension is defined for our protocol. If not, skip */
    if ((SSL_IS_DTLS(s) && (extctx & SSL_EXT_TLS_IMPLEMENTATION_ONLY) != 0)
            || (s->version == SSL3_VERSION
                    && (extctx & SSL_EXT_SSL3_ALLOWED) == 0)
            || (SSL_IS_TLS13(s)
                && (extctx & SSL_EXT_TLS1_2_AND_BELOW_ONLY) != 0)
            || (!SSL_IS_TLS13(s)
                && (extctx & SSL_EXT_TLS1_3_ONLY) != 0
                && (thisctx & SSL_EXT_CLIENT_HELLO) == 0)
            || ((extctx & SSL_EXT_TLS1_3_ONLY) != 0
                && (thisctx & SSL_EXT_CLIENT_HELLO) != 0
                && (SSL_IS_DTLS(s) || max_version < TLS1_3_VERSION)))
        return 0;

    return 1;
}

/*
 * Construct all the extensions relevant to the current |context| and write
 * them to |pkt|. If this is an extension for a Certificate in a Certificate
@@ -634,7 +650,7 @@ int tls_construct_extensions(SSL *s, WPACKET *pkt, unsigned int context,
                             X509 *x, size_t chainidx, int *al)
{
    size_t i;
    int addcustom = 0, min_version, max_version = 0, reason, tmpal;
    int min_version, max_version = 0, reason, tmpal;
    const EXTENSION_DEFINITION *thisexd;

    /*
@@ -667,21 +683,10 @@ int tls_construct_extensions(SSL *s, WPACKET *pkt, unsigned int context,

    /* Add custom extensions first */
    if ((context & SSL_EXT_CLIENT_HELLO) != 0) {
        custom_ext_init(&s->cert->cli_ext);
        addcustom = 1;
    } else if ((context & SSL_EXT_TLS1_2_SERVER_HELLO) != 0) {
        /*
         * We already initialised the custom extensions during ClientHello
         * parsing.
         *
         * TODO(TLS1.3): We're going to need a new custom extension mechanism
         * for TLS1.3, so that custom extensions can specify which of the
         * multiple message they wish to add themselves to.
         */
        addcustom = 1;
        /* On the server side with initiase during ClientHello parsing */
        custom_ext_init(&s->cert->custext);
    }

    if (addcustom && !custom_ext_add(s, s->server, pkt, &tmpal)) {
    if (!custom_ext_add(s, context, pkt, x, chainidx, max_version, &tmpal)) {
        SSLerr(SSL_F_TLS_CONSTRUCT_EXTENSIONS, ERR_R_INTERNAL_ERROR);
        goto err;
    }
@@ -691,28 +696,13 @@ int tls_construct_extensions(SSL *s, WPACKET *pkt, unsigned int context,
                         size_t chainidx, int *al);

        /* Skip if not relevant for our context */
        if ((thisexd->context & context) == 0)
        if (!should_add_extension(s, thisexd->context, context, max_version))
            continue;

        construct = s->server ? thisexd->construct_stoc
                              : thisexd->construct_ctos;

        /* Check if this extension is defined for our protocol. If not, skip */
        if ((SSL_IS_DTLS(s)
                    && (thisexd->context & SSL_EXT_TLS_IMPLEMENTATION_ONLY)
                       != 0)
                || (s->version == SSL3_VERSION
                        && (thisexd->context & SSL_EXT_SSL3_ALLOWED) == 0)
                || (SSL_IS_TLS13(s)
                    && (thisexd->context & SSL_EXT_TLS1_2_AND_BELOW_ONLY)
                       != 0)
                || (!SSL_IS_TLS13(s)
                    && (thisexd->context & SSL_EXT_TLS1_3_ONLY) != 0
                    && (context & SSL_EXT_CLIENT_HELLO) == 0)
                || ((thisexd->context & SSL_EXT_TLS1_3_ONLY) != 0
                    && (context & SSL_EXT_CLIENT_HELLO) != 0
                    && (SSL_IS_DTLS(s) || max_version < TLS1_3_VERSION))
                || construct == NULL)
        if (construct == NULL)
            continue;

        if (!construct(s, pkt, context, x, chainidx, &tmpal))
Loading