Commit 0bce0b02 authored by Matt Caswell's avatar Matt Caswell
Browse files

Narrow the scope of local variables in tls_construct_client_key_exchange()



This is in preparation for splitting up this over long function.

Reviewed-by: default avatarRichard Levitte <levitte@openssl.org>
parent d4450e4b
Loading
Loading
Loading
Loading
+58 −52
Original line number Diff line number Diff line
@@ -2020,20 +2020,7 @@ int tls_construct_client_key_exchange(SSL *s)
    size_t pskhdrlen = 0;
#endif
    unsigned long alg_k;
#ifndef OPENSSL_NO_RSA
    unsigned char *q;
    EVP_PKEY *pkey = NULL;
    EVP_PKEY_CTX *pctx = NULL;
#endif
#if !defined(OPENSSL_NO_EC) || !defined(OPENSSL_NO_DH)
    EVP_PKEY *ckey = NULL, *skey = NULL;
#endif
#ifndef OPENSSL_NO_EC
    unsigned char *encodedPoint = NULL;
    int encoded_pt_len = 0;
#endif
    unsigned char *pms = NULL;
    size_t pmslen = 0;

    alg_k = s->s3->tmp.new_cipher->algorithm_mkey;

    p = ssl_handshake_start(s);
@@ -2050,6 +2037,8 @@ int tls_construct_client_key_exchange(SSL *s)
        char identity[PSK_MAX_IDENTITY_LEN + 1];
        size_t identitylen;
        unsigned char psk[PSK_MAX_PSK_LEN];
        unsigned char *tmppsk;
        char *tmpidentity;
        size_t psklen;

        if (s->psk_client_callback == NULL) {
@@ -2073,35 +2062,36 @@ int tls_construct_client_key_exchange(SSL *s)
                   SSL_R_PSK_IDENTITY_NOT_FOUND);
            goto psk_err;
        }
        OPENSSL_free(s->s3->tmp.psk);
        s->s3->tmp.psk = OPENSSL_memdup(psk, psklen);
        OPENSSL_cleanse(psk, psklen);

        if (s->s3->tmp.psk == NULL) {
            OPENSSL_cleanse(identity, sizeof(identity));
            goto memerr;
        }

        s->s3->tmp.psklen = psklen;
        identitylen = strlen(identity);
        if (identitylen > PSK_MAX_IDENTITY_LEN) {
            SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
                   ERR_R_INTERNAL_ERROR);
            goto psk_err;
        }
        OPENSSL_free(s->session->psk_identity);
        s->session->psk_identity = OPENSSL_strdup(identity);
        if (s->session->psk_identity == NULL) {

        tmppsk = OPENSSL_memdup(psk, psklen);
        tmpidentity = OPENSSL_strdup(identity);
        if (tmppsk == NULL || tmpidentity == NULL) {
            OPENSSL_cleanse(identity, sizeof(identity));
            OPENSSL_cleanse(psk, psklen);
            OPENSSL_clear_free(tmppsk, psklen);
            OPENSSL_clear_free(tmpidentity, identitylen);
            goto memerr;
        }

        OPENSSL_free(s->s3->tmp.psk);
        s->s3->tmp.psk = tmppsk;
        s->s3->tmp.psklen = psklen;
        OPENSSL_free(s->session->psk_identity);
        s->session->psk_identity = tmpidentity;
        s2n(identitylen, p);
        memcpy(p, identity, identitylen);
        pskhdrlen = 2 + identitylen;
        p += identitylen;
        psk_err = 0;
psk_err:
        OPENSSL_cleanse(psk, psklen);
        OPENSSL_cleanse(identity, sizeof(identity));
        if (psk_err != 0) {
            ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
@@ -2118,11 +2108,12 @@ psk_err:
    }
#ifndef OPENSSL_NO_RSA
    else if (alg_k & (SSL_kRSA | SSL_kRSAPSK)) {
        unsigned char *q;
        EVP_PKEY *pkey = NULL;
        EVP_PKEY_CTX *pctx = NULL;
        size_t enclen;
        pmslen = SSL_MAX_MASTER_KEY_LENGTH;
        pms = OPENSSL_malloc(pmslen);
        if (pms == NULL)
            goto memerr;
        unsigned char *pms = NULL;
        size_t pmslen = 0;

        if (s->session->peer == NULL) {
            /*
@@ -2140,10 +2131,17 @@ psk_err:
            goto err;
        }

        pmslen = SSL_MAX_MASTER_KEY_LENGTH;
        pms = OPENSSL_malloc(pmslen);
        if (pms == NULL)
            goto memerr;

        pms[0] = s->client_version >> 8;
        pms[1] = s->client_version & 0xff;
        if (RAND_bytes(pms + 2, pmslen - 2) <= 0)
        if (RAND_bytes(pms + 2, pmslen - 2) <= 0) {
            OPENSSL_clear_free(pms, pmslen);
            goto err;
        }

        q = p;
        /* Fix buf for TLS and beyond */
@@ -2152,11 +2150,15 @@ psk_err:
        pctx = EVP_PKEY_CTX_new(pkey, NULL);
        if (pctx == NULL || EVP_PKEY_encrypt_init(pctx) <= 0
            || EVP_PKEY_encrypt(pctx, NULL, &enclen, pms, pmslen) <= 0) {
            OPENSSL_clear_free(pms, pmslen);
            EVP_PKEY_CTX_free(pctx);
            SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
                   ERR_R_EVP_LIB);
            goto err;
        }
        if (EVP_PKEY_encrypt(pctx, p, &enclen, pms, pmslen) <= 0) {
            OPENSSL_clear_free(pms, pmslen);
            EVP_PKEY_CTX_free(pctx);
            SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
                   SSL_R_BAD_RSA_ENCRYPT);
            goto err;
@@ -2176,12 +2178,17 @@ psk_err:
            s2n(n, q);
            n += 2;
        }

        s->s3->tmp.pms = pms;
        s->s3->tmp.pmslen = pmslen;
    }
#endif
#ifndef OPENSSL_NO_DH
    else if (alg_k & (SSL_kDHE | SSL_kDHEPSK)) {
        DH *dh_clnt = NULL;
        const BIGNUM *pub_key;
        EVP_PKEY *ckey = NULL, *skey = NULL;

        skey = s->s3->peer_tmp;
        if (skey == NULL) {
            SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
@@ -2194,6 +2201,7 @@ psk_err:
        if (dh_clnt == NULL || ssl_derive(s, ckey, skey) == 0) {
            SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
                   ERR_R_INTERNAL_ERROR);
            EVP_PKEY_free(ckey);
            goto err;
        }

@@ -2211,6 +2219,9 @@ psk_err:

#ifndef OPENSSL_NO_EC
    else if (alg_k & (SSL_kECDHE | SSL_kECDHEPSK)) {
        unsigned char *encodedPoint = NULL;
        int encoded_pt_len = 0;
        EVP_PKEY *ckey = NULL, *skey = NULL;

        skey = s->s3->peer_tmp;
        if ((skey == NULL) || EVP_PKEY_get0_EC_KEY(skey) == NULL) {
@@ -2223,6 +2234,7 @@ psk_err:

        if (ssl_derive(s, ckey, skey) == 0) {
            SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, ERR_R_EVP_LIB);
            EVP_PKEY_free(ckey);
            goto err;
        }

@@ -2233,6 +2245,7 @@ psk_err:

        if (encoded_pt_len == 0) {
            SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, ERR_R_EC_LIB);
            EVP_PKEY_free(ckey);
            goto err;
        }

@@ -2263,15 +2276,13 @@ psk_err:
        unsigned char shared_ukm[32], tmp[256];
        EVP_MD_CTX *ukm_hash;
        int dgst_nid = NID_id_GostR3411_94;
        unsigned char *pms = NULL;
        size_t pmslen = 0;

        if ((s->s3->tmp.new_cipher->algorithm_auth & SSL_aGOST12) != 0)
            dgst_nid = NID_id_GostR3411_2012_256;


        pmslen = 32;
        pms = OPENSSL_malloc(pmslen);
        if (pms == NULL)
            goto memerr;

        /*
         * Get server sertificate PKEY and create ctx from it
         */
@@ -2295,12 +2306,17 @@ psk_err:
         */

        /* Otherwise, generate ephemeral key pair */
        pmslen = 32;
        pms = OPENSSL_malloc(pmslen);
        if (pms == NULL)
            goto memerr;

        if (pkey_ctx == NULL
                || EVP_PKEY_encrypt_init(pkey_ctx) <= 0
                /* Generate session key */
                || RAND_bytes(pms, pmslen) <= 0) {
            EVP_PKEY_CTX_free(pkey_ctx);
            OPENSSL_clear_free(pms, pmslen);
            SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
                   ERR_R_INTERNAL_ERROR);
            goto err;
@@ -2331,6 +2347,7 @@ psk_err:
                                    SSL3_RANDOM_SIZE) <= 0
                || EVP_DigestFinal_ex(ukm_hash, shared_ukm, &md_len) <= 0) {
            EVP_MD_CTX_free(ukm_hash);
            OPENSSL_clear_free(pms, pmslen);
            SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
                   ERR_R_INTERNAL_ERROR);
            goto err;
@@ -2339,6 +2356,7 @@ psk_err:
        if (EVP_PKEY_CTX_ctrl
            (pkey_ctx, -1, EVP_PKEY_OP_ENCRYPT, EVP_PKEY_CTRL_SET_IV, 8,
             shared_ukm) < 0) {
            OPENSSL_clear_free(pms, pmslen);
            SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
                   SSL_R_LIBRARY_BUG);
            goto err;
@@ -2350,6 +2368,7 @@ psk_err:
        *(p++) = V_ASN1_SEQUENCE | V_ASN1_CONSTRUCTED;
        msglen = 255;
        if (EVP_PKEY_encrypt(pkey_ctx, tmp, &msglen, pms, pmslen) <= 0) {
            OPENSSL_clear_free(pms, pmslen);
            SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
                   SSL_R_LIBRARY_BUG);
            goto err;
@@ -2370,7 +2389,8 @@ psk_err:
            s->s3->flags |= TLS1_FLAGS_SKIP_CERT_VERIFY;
        }
        EVP_PKEY_CTX_free(pkey_ctx);

        s->s3->tmp.pms = pms;
        s->s3->tmp.pmslen = pmslen;
    }
#endif
#ifndef OPENSSL_NO_SRP
@@ -2411,27 +2431,13 @@ psk_err:
        goto err;
    }

    if (pms != NULL) {
        s->s3->tmp.pms = pms;
        s->s3->tmp.pmslen = pmslen;
    }

    return 1;
 memerr:
    ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
    SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
 err:
    OPENSSL_clear_free(pms, pmslen);
    OPENSSL_clear_free(s->s3->tmp.pms, s->s3->tmp.pmslen);
    s->s3->tmp.pms = NULL;
#ifndef OPENSSL_NO_RSA
    EVP_PKEY_CTX_free(pctx);
#endif
#ifndef OPENSSL_NO_EC
    OPENSSL_free(encodedPoint);
#endif
#if !defined(OPENSSL_NO_EC) || !defined(OPENSSL_NO_DH)
    EVP_PKEY_free(ckey);
#endif
#ifndef OPENSSL_NO_PSK
    OPENSSL_clear_free(s->s3->tmp.psk, s->s3->tmp.psklen);
    s->s3->tmp.psk = NULL;