Commit 3a7c56b2 authored by Matt Caswell's avatar Matt Caswell
Browse files

Add TLSv1.3 server side external PSK support

parent 2556aec5
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -763,6 +763,10 @@ typedef unsigned int (*SSL_psk_server_cb_func)(SSL *ssl,
                                               const char *identity,
                                               unsigned char *psk,
                                               unsigned int max_psk_len);
typedef int (*SSL_psk_find_session_cb_func)(SSL *ssl,
                                            const unsigned char *identity,
                                            size_t identity_len,
                                            SSL_SESSION **sess);
void SSL_CTX_set_psk_server_callback(SSL_CTX *ctx, SSL_psk_server_cb_func cb);
void SSL_set_psk_server_callback(SSL *ssl, SSL_psk_server_cb_func cb);

+2 −0
Original line number Diff line number Diff line
@@ -952,6 +952,7 @@ struct ssl_ctx_st {
    SSL_psk_client_cb_func psk_client_callback;
    SSL_psk_server_cb_func psk_server_callback;
# endif
    SSL_psk_find_session_cb_func psk_find_session_cb;

# ifndef OPENSSL_NO_SRP
    SRP_CTX srp_ctx;            /* ctx for SRP authentication */
@@ -1122,6 +1123,7 @@ struct ssl_st {
    SSL_psk_client_cb_func psk_client_callback;
    SSL_psk_server_cb_func psk_server_callback;
# endif
    SSL_psk_find_session_cb_func psk_find_session_cb;
    SSL_CTX *ctx;
    /* Verified chain of peer */
    STACK_OF(X509) *verified_chain;
+15 −7
Original line number Diff line number Diff line
@@ -1227,17 +1227,27 @@ static int init_psk_kex_modes(SSL *s, unsigned int context)

int tls_psk_do_binder(SSL *s, const EVP_MD *md, const unsigned char *msgstart,
                      size_t binderoffset, const unsigned char *binderin,
                      unsigned char *binderout,
                      SSL_SESSION *sess, int sign)
                      unsigned char *binderout, SSL_SESSION *sess, int sign,
                      int external)
{
    EVP_PKEY *mackey = NULL;
    EVP_MD_CTX *mctx = NULL;
    unsigned char hash[EVP_MAX_MD_SIZE], binderkey[EVP_MAX_MD_SIZE];
    unsigned char finishedkey[EVP_MAX_MD_SIZE], tmpbinder[EVP_MAX_MD_SIZE];
    const char resumption_label[] = "res binder";
    size_t bindersize, hashsize = EVP_MD_size(md);
    const char external_label[] = "ext binder";
    const char *label;
    size_t bindersize, labelsize, hashsize = EVP_MD_size(md);
    int ret = -1;

    if (external) {
        label = external_label;
        labelsize = sizeof(external_label) - 1;
    } else {
        label = resumption_label;
        labelsize = sizeof(resumption_label) - 1;
    }

    /* Generate the early_secret */
    if (!tls13_generate_secret(s, md, NULL, sess->master_key,
                               sess->master_key_length,
@@ -1259,10 +1269,8 @@ int tls_psk_do_binder(SSL *s, const EVP_MD *md, const unsigned char *msgstart,
    }

    /* Generate the binder key */
    if (!tls13_hkdf_expand(s, md, s->early_secret,
                           (unsigned char *)resumption_label,
                           sizeof(resumption_label) - 1, hash, binderkey,
                           hashsize)) {
    if (!tls13_hkdf_expand(s, md, s->early_secret, (unsigned char *)label,
                           labelsize, hash, binderkey, hashsize)) {
        SSLerr(SSL_F_TLS_PSK_DO_BINDER, ERR_R_INTERNAL_ERROR);
        goto err;
    }
+1 −1
Original line number Diff line number Diff line
@@ -894,7 +894,7 @@ EXT_RETURN tls_construct_ctos_psk(SSL *s, WPACKET *pkt, unsigned int context,
    msgstart = WPACKET_get_curr(pkt) - msglen;

    if (tls_psk_do_binder(s, md, msgstart, binderoffset, NULL, binder,
                          s->session, 1) != 1) {
                          s->session, 1, 0) != 1) {
        SSLerr(SSL_F_TLS_CONSTRUCT_CTOS_PSK, ERR_R_INTERNAL_ERROR);
        goto err;
    }
+58 −43
Original line number Diff line number Diff line
@@ -686,9 +686,8 @@ int tls_parse_ctos_psk(SSL *s, PACKET *pkt, unsigned int context, X509 *x,
    PACKET identities, binders, binder;
    size_t binderoffset, hashsize;
    SSL_SESSION *sess = NULL;
    unsigned int id, i;
    unsigned int id, i, ext = 0;
    const EVP_MD *md = NULL;
    uint32_t ticket_age = 0, now, agesec, agems;

    /*
     * If we have no PSK kex mode that we recognise then we can't resume so
@@ -706,7 +705,6 @@ int tls_parse_ctos_psk(SSL *s, PACKET *pkt, unsigned int context, X509 *x,
    for (id = 0; PACKET_remaining(&identities) != 0; id++) {
        PACKET identity;
        unsigned long ticket_agel;
        int ret;

        if (!PACKET_get_length_prefixed_2(&identities, &identity)
                || !PACKET_get_net_4(&identities, &ticket_agel)) {
@@ -714,17 +712,65 @@ int tls_parse_ctos_psk(SSL *s, PACKET *pkt, unsigned int context, X509 *x,
            return 0;
        }

        ticket_age = (uint32_t)ticket_agel;
        if (s->psk_find_session_cb != NULL
                && s->psk_find_session_cb(s, PACKET_data(&identity),
                                          PACKET_remaining(&identity), &sess)) {
            SSL_SESSION *sesstmp = ssl_session_dup(sess, 0);

            if (sesstmp == NULL) {
                *al = SSL_AD_INTERNAL_ERROR;
                return 0;
            }
            SSL_SESSION_free(sess);
            sess = sesstmp;

        ret = tls_decrypt_ticket(s, PACKET_data(&identity),
                                 PACKET_remaining(&identity), NULL, 0, &sess);
        if (ret == TICKET_FATAL_ERR_MALLOC || ret == TICKET_FATAL_ERR_OTHER) {
            /*
             * We've just been told to use this session for this context so
             * make sure the sid_ctx matches up.
             */
            memcpy(sess->sid_ctx, s->sid_ctx, s->sid_ctx_length);
            sess->sid_ctx_length = s->sid_ctx_length;
            ext = 1;
        } else {
            uint32_t ticket_age = 0, now, agesec, agems;
            int ret = tls_decrypt_ticket(s, PACKET_data(&identity),
                                         PACKET_remaining(&identity), NULL, 0,
                                         &sess);

            if (ret == TICKET_FATAL_ERR_MALLOC
                    || ret == TICKET_FATAL_ERR_OTHER) {
                *al = SSL_AD_INTERNAL_ERROR;
                return 0;
            }
            if (ret == TICKET_NO_DECRYPT)
                continue;

            ticket_age = (uint32_t)ticket_agel;
            now = (uint32_t)time(NULL);
            agesec = now - (uint32_t)sess->time;
            agems = agesec * (uint32_t)1000;
            ticket_age -= sess->ext.tick_age_add;

            /*
             * For simplicity we do our age calculations in seconds. If the
             * client does it in ms then it could appear that their ticket age
             * is longer than ours (our ticket age calculation should always be
             * slightly longer than the client's due to the network latency).
             * Therefore we add 1000ms to our age calculation to adjust for
             * rounding errors.
             */
            if (sess->timeout >= (long)agesec
                    && agems / (uint32_t)1000 == agesec
                    && ticket_age <= agems + 1000
                    && ticket_age + TICKET_AGE_ALLOWANCE >= agems + 1000) {
                /*
                 * Ticket age is within tolerance and not expired. We allow it
                 * for early data
                 */
                s->ext.early_data_ok = 1;
            }
        }

        md = ssl_md(sess->cipher->algorithm2);
        if (md != ssl_md(s->s3->tmp.new_cipher->algorithm2)) {
            /* The ciphersuite is not compatible with this session. */
@@ -732,12 +778,6 @@ int tls_parse_ctos_psk(SSL *s, PACKET *pkt, unsigned int context, X509 *x,
            sess = NULL;
            continue;
        }

        /*
         * TODO(TLS1.3): Somehow we need to handle the case of a ticket renewal.
         * Ignored for now
         */

        break;
    }

@@ -763,7 +803,7 @@ int tls_parse_ctos_psk(SSL *s, PACKET *pkt, unsigned int context, X509 *x,
            || tls_psk_do_binder(s, md,
                                 (const unsigned char *)s->init_buf->data,
                                 binderoffset, PACKET_data(&binder), NULL,
                                 sess, 0) != 1) {
                                 sess, 0, ext) != 1) {
        *al = SSL_AD_DECODE_ERROR;
        SSLerr(SSL_F_TLS_PARSE_CTOS_PSK, ERR_R_INTERNAL_ERROR);
        goto err;
@@ -771,31 +811,6 @@ int tls_parse_ctos_psk(SSL *s, PACKET *pkt, unsigned int context, X509 *x,

    sess->ext.tick_identity = id;

    now = (uint32_t)time(NULL);
    agesec = now - (uint32_t)sess->time;
    agems = agesec * (uint32_t)1000;
    ticket_age -= sess->ext.tick_age_add;


    /*
     * For simplicity we do our age calculations in seconds. If the client does
     * it in ms then it could appear that their ticket age is longer than ours
     * (our ticket age calculation should always be slightly longer than the
     * client's due to the network latency). Therefore we add 1000ms to our age
     * calculation to adjust for rounding errors.
     */
    if (sess->timeout >= (long)agesec
            && agems / (uint32_t)1000 == agesec
            && ticket_age <= agems + 1000
            && ticket_age + TICKET_AGE_ALLOWANCE >= agems + 1000) {
        /*
         * Ticket age is within tolerance and not expired. We allow it for early
         * data
         */
        s->ext.early_data_ok = 1;
    }


    SSL_SESSION_free(s->session);
    s->session = sess;
    return 1;
Loading