Commit 5ffff599 authored by Matt Caswell's avatar Matt Caswell
Browse files

Add the ability to set a TLSv1.3 PSK via just the key bytes

parent 5a43d511
Loading
Loading
Loading
Loading
+51 −6
Original line number Diff line number Diff line
@@ -172,25 +172,68 @@ static unsigned int psk_client_cb(SSL *ssl, const char *hint, char *identity,
}
#endif

#define TLS13_AES_128_GCM_SHA256_BYTES  ((const unsigned char *)"\x13\x01")
#define TLS13_AES_256_GCM_SHA384_BYTES  ((const unsigned char *)"\x13\x02")

static int psk_use_session_cb(SSL *s, const EVP_MD *md,
                              const unsigned char **id, size_t *idlen,
                              SSL_SESSION **sess)
{
    const SSL_CIPHER *cipher = SSL_SESSION_get0_cipher(psksess);
    SSL_SESSION *usesess = NULL;
    const SSL_CIPHER *cipher = NULL;

    if (cipher == NULL)
    if (psksess != NULL) {
        SSL_SESSION_up_ref(psksess);
        usesess = psksess;
    } else {
        long key_len;
        unsigned char *key = OPENSSL_hexstr2buf(psk_key, &key_len);

        if (key == NULL) {
            BIO_printf(bio_err, "Could not convert PSK key '%s' to buffer\n",
                       psk_key);
            return 0;
        }

    if (md != NULL && SSL_CIPHER_get_handshake_digest(cipher) != md)
        if (key_len == EVP_MD_size(EVP_sha256()))
            cipher = SSL_CIPHER_find(s, TLS13_AES_128_GCM_SHA256_BYTES);
        else if(key_len == EVP_MD_size(EVP_sha384()))
            cipher = SSL_CIPHER_find(s, TLS13_AES_256_GCM_SHA384_BYTES);

        if (cipher == NULL) {
            /* Doesn't look like a suitable TLSv1.3 key. Ignore it */
            OPENSSL_free(key);
            return 0;
        }
        usesess = SSL_SESSION_new();
        if (usesess == NULL
                || !SSL_SESSION_set1_master_key(usesess, key, key_len)
                || !SSL_SESSION_set_cipher(usesess, cipher)
                || !SSL_SESSION_set_protocol_version(usesess, TLS1_3_VERSION)) {
            OPENSSL_free(key);
            goto err;
        }
        OPENSSL_free(key);
    }

    SSL_SESSION_up_ref(psksess);
    *sess = psksess;
    cipher = SSL_SESSION_get0_cipher(usesess);

    if (cipher == NULL)
        goto err;

    if (md != NULL && SSL_CIPHER_get_handshake_digest(cipher) != md)
        goto err;

    *sess = usesess;

    *id = (unsigned char *)psk_identity;
    *idlen = strlen(psk_identity);

    return 1;

 err:
    SSL_SESSION_free(usesess);
    return 0;
}

/* This is a context that we pass to callbacks */
@@ -1699,8 +1742,10 @@ int s_client_main(int argc, char **argv)
            ERR_print_errors(bio_err);
            goto end;
        }
        SSL_CTX_set_psk_use_session_callback(ctx, psk_use_session_cb);
    }
    if (psk_key != NULL || psksess != NULL)
        SSL_CTX_set_psk_use_session_callback(ctx, psk_use_session_cb);

#ifndef OPENSSL_NO_SRTP
    if (srtp_profiles != NULL) {
        /* Returns 0 on success! */
+46 −3
Original line number Diff line number Diff line
@@ -179,15 +179,55 @@ static unsigned int psk_server_cb(SSL *ssl, const char *identity,
}
#endif

#define TLS13_AES_128_GCM_SHA256_BYTES  ((const unsigned char *)"\x13\x01")
#define TLS13_AES_256_GCM_SHA384_BYTES  ((const unsigned char *)"\x13\x02")

static int psk_find_session_cb(SSL *ssl, const unsigned char *identity,
                               size_t identity_len, SSL_SESSION **sess)
{
    SSL_SESSION *tmpsess = NULL;
    unsigned char *key;
    long key_len;
    const SSL_CIPHER *cipher = NULL;

    if (strlen(psk_identity) != identity_len
            || memcmp(psk_identity, identity, identity_len) != 0)
        return 0;

    if (psksess != NULL) {
        SSL_SESSION_up_ref(psksess);
        *sess = psksess;
        return 1;
    }

    key = OPENSSL_hexstr2buf(psk_key, &key_len);
    if (key == NULL) {
        BIO_printf(bio_err, "Could not convert PSK key '%s' to buffer\n",
                   psk_key);
        return 0;
    }

    if (key_len == EVP_MD_size(EVP_sha256()))
        cipher = SSL_CIPHER_find(ssl, TLS13_AES_128_GCM_SHA256_BYTES);
    else if(key_len == EVP_MD_size(EVP_sha384()))
        cipher = SSL_CIPHER_find(ssl, TLS13_AES_256_GCM_SHA384_BYTES);

    if (cipher == NULL) {
        /* Doesn't look like a suitable TLSv1.3 key. Ignore it */
        OPENSSL_free(key);
        return 0;
    }

    tmpsess = SSL_SESSION_new();
    if (tmpsess == NULL
            || !SSL_SESSION_set1_master_key(tmpsess, key, key_len)
            || !SSL_SESSION_set_cipher(tmpsess, cipher)
            || !SSL_SESSION_set_protocol_version(tmpsess, SSL_version(ssl))) {
        OPENSSL_free(key);
        return 0;
    }
    OPENSSL_free(key);
    *sess = tmpsess;

    return 1;
}
@@ -1974,9 +2014,12 @@ int s_server_main(int argc, char *argv[])
            ERR_print_errors(bio_err);
            goto end;
        }
        SSL_CTX_set_psk_find_session_callback(ctx, psk_find_session_cb);

    }

    if (psk_key != NULL || psksess != NULL)
        SSL_CTX_set_psk_find_session_callback(ctx, psk_find_session_cb);

    SSL_CTX_set_verify(ctx, s_server_verify, verify_callback);
    if (!SSL_CTX_set_session_id_context(ctx,
                                        (void *)&s_server_session_id_context,