Commit 13c0ec4a authored by Matt Caswell's avatar Matt Caswell
Browse files

Split out CKE construction PSK pre-amble and RSA into a separate function



The tls_construct_client_key_exchange() function is too long. This splits
out the construction of the PSK pre-amble into a separate function as well
as the RSA construction.

Reviewed-by: default avatarRichard Levitte <levitte@openssl.org>
parent 0bce0b02
Loading
Loading
Loading
Loading
+178 −150
Original line number Diff line number Diff line
@@ -2012,38 +2012,27 @@ MSG_PROCESS_RETURN tls_process_server_done(SSL *s, PACKET *pkt)
        return MSG_PROCESS_FINISHED_READING;
}

int tls_construct_client_key_exchange(SSL *s)
static int tls_construct_cke_psk_preamble(SSL *s, unsigned char **p,
                                          size_t *pskhdrlen, int *al)
{
    unsigned char *p;
    int n;
#ifndef OPENSSL_NO_PSK
    size_t pskhdrlen = 0;
#endif
    unsigned long alg_k;

    alg_k = s->s3->tmp.new_cipher->algorithm_mkey;

    p = ssl_handshake_start(s);


#ifndef OPENSSL_NO_PSK
    if (alg_k & SSL_PSK) {
        int psk_err = 1;
    int ret = 0;
    /*
     * The callback needs PSK_MAX_IDENTITY_LEN + 1 bytes to return a
     * \0-terminated identity. The last byte is for us for simulating
     * strnlen.
     */
    char identity[PSK_MAX_IDENTITY_LEN + 1];
        size_t identitylen;
    size_t identitylen = 0;
    unsigned char psk[PSK_MAX_PSK_LEN];
        unsigned char *tmppsk;
        char *tmpidentity;
        size_t psklen;
    unsigned char *tmppsk = NULL;
    char *tmpidentity = NULL;
    size_t psklen = 0;

    if (s->psk_client_callback == NULL) {
        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
               SSL_R_PSK_NO_CLIENT_CB);
        *al = SSL_AD_INTERNAL_ERROR;
        goto err;
    }

@@ -2056,58 +2045,62 @@ int tls_construct_client_key_exchange(SSL *s)
    if (psklen > PSK_MAX_PSK_LEN) {
        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
               ERR_R_INTERNAL_ERROR);
            goto psk_err;
        *al = SSL_AD_HANDSHAKE_FAILURE;
        goto err;
    } else if (psklen == 0) {
        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
               SSL_R_PSK_IDENTITY_NOT_FOUND);
            goto psk_err;
        *al = SSL_AD_HANDSHAKE_FAILURE;
        goto err;
    }

    identitylen = strlen(identity);
    if (identitylen > PSK_MAX_IDENTITY_LEN) {
        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
               ERR_R_INTERNAL_ERROR);
            goto psk_err;
        *al = SSL_AD_HANDSHAKE_FAILURE;
        goto err;
    }

    tmppsk = OPENSSL_memdup(psk, psklen);
    tmpidentity = OPENSSL_strdup(identity);
    if (tmppsk == NULL || tmpidentity == NULL) {
            OPENSSL_cleanse(identity, sizeof(identity));
            OPENSSL_cleanse(psk, psklen);
            OPENSSL_clear_free(tmppsk, psklen);
            OPENSSL_clear_free(tmpidentity, identitylen);
            goto memerr;
        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
        *al = SSL_AD_INTERNAL_ERROR;
        goto err;
    }

    OPENSSL_free(s->s3->tmp.psk);
    s->s3->tmp.psk = tmppsk;
    s->s3->tmp.psklen = psklen;
    tmppsk = NULL;
    OPENSSL_free(s->session->psk_identity);
    s->session->psk_identity = tmpidentity;
        s2n(identitylen, p);
        memcpy(p, identity, identitylen);
        pskhdrlen = 2 + identitylen;
        p += identitylen;
        psk_err = 0;
psk_err:
    tmpidentity = NULL;
    s2n(identitylen, *p);
    memcpy(*p, identity, identitylen);
    *pskhdrlen = 2 + identitylen;
    *p += identitylen;

    ret = 1;

 err:
    OPENSSL_cleanse(psk, psklen);
    OPENSSL_cleanse(identity, sizeof(identity));
        if (psk_err != 0) {
            ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
            goto err;
        }
    }
    if (alg_k & SSL_kPSK) {
        n = 0;
    } else
#endif
    OPENSSL_clear_free(tmppsk, psklen);
    OPENSSL_clear_free(tmpidentity, identitylen);

    /* Fool emacs indentation */
    if (0) {
    return ret;
#else
    SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
    *al = SSL_AD_INTERNAL_ERROR;
    return 0;
#endif
}

static int tls_construct_cke_rsa(SSL *s, unsigned char **p, int *len, int *al)
{
#ifndef OPENSSL_NO_RSA
    else if (alg_k & (SSL_kRSA | SSL_kRSAPSK)) {
    unsigned char *q;
    EVP_PKEY *pkey = NULL;
    EVP_PKEY_CTX *pctx = NULL;
@@ -2121,68 +2114,103 @@ psk_err:
         */
        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
               ERR_R_INTERNAL_ERROR);
            goto err;
        return 0;
    }

    pkey = X509_get0_pubkey(s->session->peer);
    if (EVP_PKEY_get0_RSA(pkey) == NULL) {
        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
               ERR_R_INTERNAL_ERROR);
            goto err;
        return 0;
    }

    pmslen = SSL_MAX_MASTER_KEY_LENGTH;
    pms = OPENSSL_malloc(pmslen);
        if (pms == NULL)
            goto memerr;
    if (pms == NULL) {
        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
               ERR_R_MALLOC_FAILURE);
        *al = SSL_AD_INTERNAL_ERROR;
        return 0;
    }

    pms[0] = s->client_version >> 8;
    pms[1] = s->client_version & 0xff;
    if (RAND_bytes(pms + 2, pmslen - 2) <= 0) {
            OPENSSL_clear_free(pms, pmslen);
        goto err;
    }

        q = p;
    q = *p;
    /* Fix buf for TLS and beyond */
    if (s->version > SSL3_VERSION)
            p += 2;
        *p += 2;
    pctx = EVP_PKEY_CTX_new(pkey, NULL);
    if (pctx == NULL || EVP_PKEY_encrypt_init(pctx) <= 0
        || EVP_PKEY_encrypt(pctx, NULL, &enclen, pms, pmslen) <= 0) {
            OPENSSL_clear_free(pms, pmslen);
            EVP_PKEY_CTX_free(pctx);
        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
               ERR_R_EVP_LIB);
        goto err;
    }
        if (EVP_PKEY_encrypt(pctx, p, &enclen, pms, pmslen) <= 0) {
            OPENSSL_clear_free(pms, pmslen);
            EVP_PKEY_CTX_free(pctx);
    if (EVP_PKEY_encrypt(pctx, *p, &enclen, pms, pmslen) <= 0) {
        SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
               SSL_R_BAD_RSA_ENCRYPT);
        goto err;
    }
        n = enclen;
    *len = enclen;
    EVP_PKEY_CTX_free(pctx);
    pctx = NULL;
# ifdef PKCS1_CHECK
    if (s->options & SSL_OP_PKCS1_CHECK_1)
            p[1]++;
        (*p)[1]++;
    if (s->options & SSL_OP_PKCS1_CHECK_2)
        tmp_buf[0] = 0x70;
# endif

    /* Fix buf for TLS and beyond */
    if (s->version > SSL3_VERSION) {
            s2n(n, q);
            n += 2;
        s2n(*len, q);
        *len += 2;
    }

    s->s3->tmp.pms = pms;
    s->s3->tmp.pmslen = pmslen;
    }

    return 1;
 err:
    OPENSSL_clear_free(pms, pmslen);
    EVP_PKEY_CTX_free(pctx);

    return 0;
#else
    SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
    *al = SSL_AD_INTERNAL_ERROR;
    return 0;
#endif
}

int tls_construct_client_key_exchange(SSL *s)
{
    unsigned char *p;
    int n;
    size_t pskhdrlen = 0;
    unsigned long alg_k;
    int al = -1;

    alg_k = s->s3->tmp.new_cipher->algorithm_mkey;

    p = ssl_handshake_start(s);



    if ((alg_k & SSL_PSK)
            && !tls_construct_cke_psk_preamble(s, &p, &pskhdrlen, &al))
        goto err;

    if (alg_k & SSL_kPSK) {
        n = 0;
    } else if (alg_k & (SSL_kRSA | SSL_kRSAPSK)) {
        if (!tls_construct_cke_rsa(s, &p, &n, &al))
            goto err;
    }
#ifndef OPENSSL_NO_DH
    else if (alg_k & (SSL_kDHE | SSL_kDHEPSK)) {
        DH *dh_clnt = NULL;
@@ -2421,9 +2449,7 @@ psk_err:
        goto err;
    }

#ifndef OPENSSL_NO_PSK
    n += pskhdrlen;
#endif

    if (!ssl_set_handshake_header(s, SSL3_MT_CLIENT_KEY_EXCHANGE, n)) {
        ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
@@ -2433,9 +2459,11 @@ psk_err:

    return 1;
 memerr:
    ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
    SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
    al = SSL_AD_INTERNAL_ERROR;
 err:
    if (al != -1)
        ssl3_send_alert(s, SSL3_AL_FATAL, al);
    OPENSSL_clear_free(s->s3->tmp.pms, s->s3->tmp.pmslen);
    s->s3->tmp.pms = NULL;
#ifndef OPENSSL_NO_PSK