Commit c13d2a5b authored by Matt Caswell's avatar Matt Caswell
Browse files

Convert ServerKeyExchange construction to WPACKET



Reviewed-by: default avatarRich Salz <rsalz@openssl.org>
parent 1ff84340
Loading
Loading
Loading
Loading
+97 −120
Original line number Diff line number Diff line
@@ -1591,7 +1591,6 @@ int tls_construct_server_key_exchange(SSL *s)
{
#ifndef OPENSSL_NO_DH
    EVP_PKEY *pkdh = NULL;
    int j;
#endif
#ifndef OPENSSL_NO_EC
    unsigned char *encodedPoint = NULL;
@@ -1600,36 +1599,30 @@ int tls_construct_server_key_exchange(SSL *s)
#endif
    EVP_PKEY *pkey;
    const EVP_MD *md = NULL;
    unsigned char *p, *d;
    int al, i;
    int al = SSL_AD_INTERNAL_ERROR, i;
    unsigned long type;
    int n;
    const BIGNUM *r[4];
    int nr[4], kn;
    BUF_MEM *buf;
    EVP_MD_CTX *md_ctx = EVP_MD_CTX_new();
    WPACKET pkt;
    size_t paramlen, paramoffset;

    if (!WPACKET_init(&pkt, s->init_buf)
            || !ssl_set_handshake_header2(s, &pkt,
                                          SSL3_MT_SERVER_KEY_EXCHANGE)
            || !WPACKET_get_total_written(&pkt, &paramoffset)) {
        SSLerr(SSL_F_TLS_CONSTRUCT_SERVER_HELLO, ERR_R_INTERNAL_ERROR);
        goto f_err;
    }

    if (md_ctx == NULL) {
        SSLerr(SSL_F_TLS_CONSTRUCT_SERVER_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
        al = SSL_AD_INTERNAL_ERROR;
        goto f_err;
    }

    type = s->s3->tmp.new_cipher->algorithm_mkey;

    buf = s->init_buf;

    r[0] = r[1] = r[2] = r[3] = NULL;
    n = 0;
#ifndef OPENSSL_NO_PSK
    if (type & SSL_PSK) {
        /*
         * reserve size for record length and PSK identity hint
         */
        n += 2;
        if (s->cert->psk_identity_hint)
            n += strlen(s->cert->psk_identity_hint);
    }
    /* Plain PSK or RSAPSK nothing to do */
    if (type & (SSL_kPSK | SSL_kRSAPSK)) {
    } else
@@ -1646,7 +1639,6 @@ int tls_construct_server_key_exchange(SSL *s)
            pkdh = EVP_PKEY_new();
            if (pkdh == NULL || dhp == NULL) {
                DH_free(dhp);
                al = SSL_AD_INTERNAL_ERROR;
                SSLerr(SSL_F_TLS_CONSTRUCT_SERVER_KEY_EXCHANGE,
                       ERR_R_INTERNAL_ERROR);
                goto f_err;
@@ -1660,7 +1652,6 @@ int tls_construct_server_key_exchange(SSL *s)
            DH *dhp = s->cert->dh_tmp_cb(s, 0, 1024);
            pkdh = ssl_dh_to_pkey(dhp);
            if (pkdh == NULL) {
                al = SSL_AD_INTERNAL_ERROR;
                SSLerr(SSL_F_TLS_CONSTRUCT_SERVER_KEY_EXCHANGE,
                       ERR_R_INTERNAL_ERROR);
                goto f_err;
@@ -1723,7 +1714,6 @@ int tls_construct_server_key_exchange(SSL *s)
        s->s3->tmp.pkey = ssl_generate_pkey_curve(curve_id);
        /* Generate a new key for this curve */
        if (s->s3->tmp.pkey == NULL) {
            al = SSL_AD_INTERNAL_ERROR;
            SSLerr(SSL_F_TLS_CONSTRUCT_SERVER_KEY_EXCHANGE, ERR_R_EVP_LIB);
            goto f_err;
        }
@@ -1736,13 +1726,6 @@ int tls_construct_server_key_exchange(SSL *s)
            goto err;
        }

        /*
         * We only support named (not generic) curves in ECDH ephemeral key
         * exchanges. In this situation, we need four additional bytes to
         * encode the entire ServerECDHParams structure.
         */
        n += 4 + encodedlen;

        /*
         * We'll generate the serverKeyExchange message explicitly so we
         * can set these to NULLs
@@ -1774,25 +1757,6 @@ int tls_construct_server_key_exchange(SSL *s)
               SSL_R_UNKNOWN_KEY_EXCHANGE_TYPE);
        goto f_err;
    }
    for (i = 0; i < 4 && r[i] != NULL; i++) {
        nr[i] = BN_num_bytes(r[i]);
#ifndef OPENSSL_NO_SRP
        if ((i == 2) && (type & SSL_kSRP))
            n += 1 + nr[i];
        else
#endif
#ifndef OPENSSL_NO_DH
        /*-
         * for interoperability with some versions of the Microsoft TLS
         * stack, we need to zero pad the DHE pub key to the same length
         * as the prime, so use the length of the prime here
         */
        if ((i == 2) && (type & (SSL_kDHE | SSL_kDHEPSK)))
            n += 2 + nr[0];
        else
#endif
            n += 2 + nr[i];
    }

    if (!(s->s3->tmp.new_cipher->algorithm_auth & (SSL_aNULL | SSL_aSRP))
        && !(s->s3->tmp.new_cipher->algorithm_mkey & SSL_PSK)) {
@@ -1801,53 +1765,46 @@ int tls_construct_server_key_exchange(SSL *s)
            al = SSL_AD_DECODE_ERROR;
            goto f_err;
        }
        kn = EVP_PKEY_size(pkey);
        /* Allow space for signature algorithm */
        if (SSL_USE_SIGALGS(s))
            kn += 2;
        /* Allow space for signature length */
        kn += 2;
    } else {
        pkey = NULL;
        kn = 0;
    }

    if (!BUF_MEM_grow_clean(buf, n + SSL_HM_HEADER_LENGTH(s) + kn)) {
        SSLerr(SSL_F_TLS_CONSTRUCT_SERVER_KEY_EXCHANGE, ERR_LIB_BUF);
        goto err;
    }
    d = p = ssl_handshake_start(s);

#ifndef OPENSSL_NO_PSK
    if (type & SSL_PSK) {
        /* copy PSK identity hint */
        if (s->cert->psk_identity_hint) {
            size_t len = strlen(s->cert->psk_identity_hint);
            if (len > PSK_MAX_IDENTITY_LEN) {
        size_t len = (s->cert->psk_identity_hint == NULL)
                        ? 0 : strlen(s->cert->psk_identity_hint);

        /*
                 * Should not happen - we already checked this when we set
                 * the identity hint
         * It should not happen that len > PSK_MAX_IDENTITY_LEN - we already
         * checked this when we set the identity hint - but just in case
         */
        if (len > PSK_MAX_IDENTITY_LEN
                || !WPACKET_sub_memcpy_u16(&pkt, s->cert->psk_identity_hint,
                                           len)) {
            SSLerr(SSL_F_TLS_CONSTRUCT_SERVER_KEY_EXCHANGE,
                   ERR_R_INTERNAL_ERROR);
                goto err;
            }
            s2n(len, p);
            memcpy(p, s->cert->psk_identity_hint, len);
            p += len;
        } else {
            s2n(0, p);
            goto f_err;
        }
    }
#endif

    for (i = 0; i < 4 && r[i] != NULL; i++) {
        unsigned char *binval;
        int res;

#ifndef OPENSSL_NO_SRP
        if ((i == 2) && (type & SSL_kSRP)) {
            *p = nr[i];
            p++;
            res = WPACKET_start_sub_packet_u8(&pkt);
        } else
#endif
            res = WPACKET_start_sub_packet_u16(&pkt);

        if (!res) {
            SSLerr(SSL_F_TLS_CONSTRUCT_SERVER_KEY_EXCHANGE,
                   ERR_R_INTERNAL_ERROR);
            goto f_err;
        }

#ifndef OPENSSL_NO_DH
        /*-
         * for interoperability with some versions of the Microsoft TLS
@@ -1855,38 +1812,45 @@ int tls_construct_server_key_exchange(SSL *s)
         * as the prime
         */
        if ((i == 2) && (type & (SSL_kDHE | SSL_kDHEPSK))) {
            s2n(nr[0], p);
            for (j = 0; j < (nr[0] - nr[2]); ++j) {
                *p = 0;
                ++p;
            size_t len = BN_num_bytes(r[0]) - BN_num_bytes(r[2]);
            if (len > 0) {
                if (!WPACKET_allocate_bytes(&pkt, len, &binval)) {
                    SSLerr(SSL_F_TLS_CONSTRUCT_SERVER_KEY_EXCHANGE,
                           ERR_R_INTERNAL_ERROR);
                    goto f_err;
                }
                memset(binval, 0, len);
            }
        }
        } else
#endif
            s2n(nr[i], p);
        BN_bn2bin(r[i], p);
        p += nr[i];
        if (!WPACKET_allocate_bytes(&pkt, BN_num_bytes(r[i]), &binval)
                || !WPACKET_close(&pkt)) {
            SSLerr(SSL_F_TLS_CONSTRUCT_SERVER_KEY_EXCHANGE,
                   ERR_R_INTERNAL_ERROR);
            goto f_err;
        }

        BN_bn2bin(r[i], binval);
    }

#ifndef OPENSSL_NO_EC
    if (type & (SSL_kECDHE | SSL_kECDHEPSK)) {
        /*
         * XXX: For now, we only support named (not generic) curves. In
         * this situation, the serverKeyExchange message has: [1 byte
         * CurveType], [2 byte CurveName] [1 byte length of encoded
         * point], followed by the actual encoded point itself
         */
        *p = NAMED_CURVE_TYPE;
        p += 1;
        *p = 0;
        p += 1;
        *p = curve_id;
        p += 1;
        *p = encodedlen;
        p += 1;
        memcpy(p, encodedPoint, encodedlen);
         * We only support named (not generic) curves. In this situation, the
         * ServerKeyExchange message has: [1 byte CurveType], [2 byte CurveName]
         * [1 byte length of encoded point], followed by the actual encoded
         * point itself
         */
        if (!WPACKET_put_bytes_u8(&pkt, NAMED_CURVE_TYPE)
                || !WPACKET_put_bytes_u8(&pkt, 0)
                || !WPACKET_put_bytes_u8(&pkt, curve_id)
                || !WPACKET_sub_memcpy_u8(&pkt, encodedPoint, encodedlen)) {
            SSLerr(SSL_F_TLS_CONSTRUCT_SERVER_KEY_EXCHANGE,
                   ERR_R_INTERNAL_ERROR);
            goto f_err;
        }
        OPENSSL_free(encodedPoint);
        encodedPoint = NULL;
        p += encodedlen;
    }
#endif

@@ -1897,36 +1861,49 @@ int tls_construct_server_key_exchange(SSL *s)
         * points to the space at the end.
         */
        if (md) {
            unsigned char *sigbytes1, *sigbytes2;
            unsigned int siglen;

            /* Get length of the parameters we have written above */
            if (!WPACKET_get_length(&pkt, &paramlen)) {
                SSLerr(SSL_F_TLS_CONSTRUCT_SERVER_KEY_EXCHANGE,
                       ERR_R_INTERNAL_ERROR);
                goto f_err;
            }
            /* send signature algorithm */
            if (SSL_USE_SIGALGS(s)) {
                if (!tls12_get_sigandhash_old(p, pkey, md)) {
                if (!tls12_get_sigandhash(&pkt, pkey, md)) {
                    /* Should never happen */
                    al = SSL_AD_INTERNAL_ERROR;
                    SSLerr(SSL_F_TLS_CONSTRUCT_SERVER_KEY_EXCHANGE,
                           ERR_R_INTERNAL_ERROR);
                    goto f_err;
                }
                p += 2;
            }
#ifdef SSL_DEBUG
            fprintf(stderr, "Using hash %s\n", EVP_MD_name(md));
#endif
            if (EVP_SignInit_ex(md_ctx, md, NULL) <= 0
            /*
             * Create the signature. We don't know the actual length of the sig
             * until after we've created it, so we reserve enough bytes for it
             * up front, and then properly allocate them in the WPACKET
             * afterwards.
             */
            if (!WPACKET_sub_reserve_bytes_u16(&pkt, EVP_PKEY_size(pkey),
                                               &sigbytes1)
                    || EVP_SignInit_ex(md_ctx, md, NULL) <= 0
                    || EVP_SignUpdate(md_ctx, &(s->s3->client_random[0]),
                                      SSL3_RANDOM_SIZE) <= 0
                    || EVP_SignUpdate(md_ctx, &(s->s3->server_random[0]),
                                      SSL3_RANDOM_SIZE) <= 0
                || EVP_SignUpdate(md_ctx, d, n) <= 0
                || EVP_SignFinal(md_ctx, &(p[2]),
                                 (unsigned int *)&i, pkey) <= 0) {
                SSLerr(SSL_F_TLS_CONSTRUCT_SERVER_KEY_EXCHANGE, ERR_LIB_EVP);
                al = SSL_AD_INTERNAL_ERROR;
                    || EVP_SignUpdate(md_ctx, s->init_buf->data + paramoffset,
                                      paramlen) <= 0
                    || EVP_SignFinal(md_ctx, sigbytes1, &siglen, pkey) <= 0
                    || !WPACKET_sub_allocate_bytes_u16(&pkt, siglen, &sigbytes2)
                    || sigbytes1 != sigbytes2) {
                SSLerr(SSL_F_TLS_CONSTRUCT_SERVER_KEY_EXCHANGE,
                       ERR_R_INTERNAL_ERROR);
                goto f_err;
            }
            s2n(i, p);
            n += i + 2;
            if (SSL_USE_SIGALGS(s))
                n += 2;
        } else {
            /* Is this error check actually needed? */
            al = SSL_AD_HANDSHAKE_FAILURE;
@@ -1936,8 +1913,7 @@ int tls_construct_server_key_exchange(SSL *s)
        }
    }

    if (!ssl_set_handshake_header(s, SSL3_MT_SERVER_KEY_EXCHANGE, n)) {
        al = SSL_AD_HANDSHAKE_FAILURE;
    if (!ssl_close_construct_packet(s, &pkt)) {
        SSLerr(SSL_F_TLS_CONSTRUCT_SERVER_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
        goto f_err;
    }
@@ -1955,6 +1931,7 @@ int tls_construct_server_key_exchange(SSL *s)
#endif
    EVP_MD_CTX_free(md_ctx);
    ossl_statem_set_error(s);
    WPACKET_cleanup(&pkt);
    return 0;
}