Commit 6745a1ff authored by Dr. Stephen Henson's avatar Dr. Stephen Henson
Browse files

Cache maskHash parameter



Store hash algorithm used for MGF1 masks in PSS and OAEP modes in PSS and
OAEP parameter structure: this avoids the need to decode part of the ASN.1
structure every time it is used.

Reviewed-by: default avatarRich Salz <rsalz@openssl.org>
Reviewed-by: default avatarMatt Caswell <matt@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/2177)
parent 3b72dcd5
Loading
Loading
Loading
Loading
+30 −58
Original line number Diff line number Diff line
@@ -188,37 +188,36 @@ static int rsa_priv_print(BIO *bp, const EVP_PKEY *pkey, int indent,
    return do_rsa_print(bp, pkey->pkey.rsa, indent, 1);
}

/* Given an MGF1 Algorithm ID decode to an Algorithm Identifier */
static X509_ALGOR *rsa_mgf1_decode(X509_ALGOR *alg)
{
    if (alg == NULL)
        return NULL;
    if (OBJ_obj2nid(alg->algorithm) != NID_mgf1)
        return NULL;
    return ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(X509_ALGOR),
                                     alg->parameter);
}

static RSA_PSS_PARAMS *rsa_pss_decode(const X509_ALGOR *alg,
                                      X509_ALGOR **pmaskHash)
static RSA_PSS_PARAMS *rsa_pss_decode(const X509_ALGOR *alg)
{
    RSA_PSS_PARAMS *pss;

    *pmaskHash = NULL;

    pss = ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(RSA_PSS_PARAMS),
                                    alg->parameter);

    if (!pss)
    if (pss == NULL)
        return NULL;

    *pmaskHash = rsa_mgf1_decode(pss->maskGenAlgorithm);
    if (pss->maskGenAlgorithm != NULL) {
        pss->maskHash = rsa_mgf1_decode(pss->maskGenAlgorithm);
        if (pss->maskHash == NULL) {
            RSA_PSS_PARAMS_free(pss);
            return NULL;
        }
    }

    return pss;
}

static int rsa_pss_param_print(BIO *bp, RSA_PSS_PARAMS *pss,
                               X509_ALGOR *maskHash, int indent)
static int rsa_pss_param_print(BIO *bp, RSA_PSS_PARAMS *pss, int indent)
{
    int rv = 0;
    if (!pss) {
@@ -252,8 +251,8 @@ static int rsa_pss_param_print(BIO *bp, RSA_PSS_PARAMS *pss,
            goto err;
        if (BIO_puts(bp, " with ") <= 0)
            goto err;
        if (maskHash) {
            if (i2a_ASN1_OBJECT(bp, maskHash->algorithm) <= 0)
        if (pss->maskHash) {
            if (i2a_ASN1_OBJECT(bp, pss->maskHash->algorithm) <= 0)
                goto err;
        } else if (BIO_puts(bp, "INVALID") <= 0)
            goto err;
@@ -296,11 +295,9 @@ static int rsa_sig_print(BIO *bp, const X509_ALGOR *sigalg,
    if (OBJ_obj2nid(sigalg->algorithm) == NID_rsassaPss) {
        int rv;
        RSA_PSS_PARAMS *pss;
        X509_ALGOR *maskHash;
        pss = rsa_pss_decode(sigalg, &maskHash);
        rv = rsa_pss_param_print(bp, pss, maskHash, indent);
        pss = rsa_pss_decode(sigalg);
        rv = rsa_pss_param_print(bp, pss, indent);
        RSA_PSS_PARAMS_free(pss);
        X509_ALGOR_free(maskHash);
        if (!rv)
            return 0;
    } else if (!sig && BIO_puts(bp, "\n") <= 0)
@@ -410,29 +407,6 @@ static const EVP_MD *rsa_algor_to_md(X509_ALGOR *alg)
    return md;
}

/* convert MGF1 algorithm ID to EVP_MD, default SHA1 */
static const EVP_MD *rsa_mgf1_to_md(X509_ALGOR *alg, X509_ALGOR *maskHash)
{
    const EVP_MD *md;
    if (!alg)
        return EVP_sha1();
    /* Check mask and lookup mask hash algorithm */
    if (OBJ_obj2nid(alg->algorithm) != NID_mgf1) {
        RSAerr(RSA_F_RSA_MGF1_TO_MD, RSA_R_UNSUPPORTED_MASK_ALGORITHM);
        return NULL;
    }
    if (!maskHash) {
        RSAerr(RSA_F_RSA_MGF1_TO_MD, RSA_R_UNSUPPORTED_MASK_PARAMETER);
        return NULL;
    }
    md = EVP_get_digestbyobj(maskHash->algorithm);
    if (md == NULL) {
        RSAerr(RSA_F_RSA_MGF1_TO_MD, RSA_R_UNKNOWN_MASK_DIGEST);
        return NULL;
    }
    return md;
}

/*
 * Convert EVP_PKEY_CTX is PSS mode into corresponding algorithm parameter,
 * suitable for setting an AlgorithmIdentifier.
@@ -497,20 +471,19 @@ static int rsa_pss_to_ctx(EVP_MD_CTX *ctx, EVP_PKEY_CTX *pkctx,
    int saltlen;
    const EVP_MD *mgf1md = NULL, *md = NULL;
    RSA_PSS_PARAMS *pss;
    X509_ALGOR *maskHash;
    /* Sanity check: make sure it is PSS */
    if (OBJ_obj2nid(sigalg->algorithm) != NID_rsassaPss) {
        RSAerr(RSA_F_RSA_PSS_TO_CTX, RSA_R_UNSUPPORTED_SIGNATURE_TYPE);
        return -1;
    }
    /* Decode PSS parameters */
    pss = rsa_pss_decode(sigalg, &maskHash);
    pss = rsa_pss_decode(sigalg);

    if (pss == NULL) {
        RSAerr(RSA_F_RSA_PSS_TO_CTX, RSA_R_INVALID_PSS_PARAMETERS);
        goto err;
    }
    mgf1md = rsa_mgf1_to_md(pss->maskGenAlgorithm, maskHash);
    mgf1md = rsa_algor_to_md(pss->maskHash);
    if (!mgf1md)
        goto err;
    md = rsa_algor_to_md(pss->hashAlgorithm);
@@ -568,7 +541,6 @@ static int rsa_pss_to_ctx(EVP_MD_CTX *ctx, EVP_PKEY_CTX *pkctx,

 err:
    RSA_PSS_PARAMS_free(pss);
    X509_ALGOR_free(maskHash);
    return rv;
}

@@ -674,22 +646,24 @@ static int rsa_item_sign(EVP_MD_CTX *ctx, const ASN1_ITEM *it, void *asn,
}

#ifndef OPENSSL_NO_CMS
static RSA_OAEP_PARAMS *rsa_oaep_decode(const X509_ALGOR *alg,
                                        X509_ALGOR **pmaskHash)
static RSA_OAEP_PARAMS *rsa_oaep_decode(const X509_ALGOR *alg)
{
    RSA_OAEP_PARAMS *pss;

    *pmaskHash = NULL;
    RSA_OAEP_PARAMS *oaep;

    pss = ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(RSA_OAEP_PARAMS),
    oaep = ASN1_TYPE_unpack_sequence(ASN1_ITEM_rptr(RSA_OAEP_PARAMS),
                                    alg->parameter);

    if (!pss)
    if (oaep == NULL)
        return NULL;

    *pmaskHash = rsa_mgf1_decode(pss->maskGenFunc);

    return pss;
    if (oaep->maskGenFunc != NULL) {
        oaep->maskHash = rsa_mgf1_decode(oaep->maskGenFunc);
        if (oaep->maskHash == NULL) {
            RSA_OAEP_PARAMS_free(oaep);
            return NULL;
        }
    }
    return oaep;
}

static int rsa_cms_decrypt(CMS_RecipientInfo *ri)
@@ -702,7 +676,6 @@ static int rsa_cms_decrypt(CMS_RecipientInfo *ri)
    int labellen = 0;
    const EVP_MD *mgf1md = NULL, *md = NULL;
    RSA_OAEP_PARAMS *oaep;
    X509_ALGOR *maskHash;
    pkctx = CMS_RecipientInfo_get0_pkey_ctx(ri);
    if (!pkctx)
        return 0;
@@ -716,14 +689,14 @@ static int rsa_cms_decrypt(CMS_RecipientInfo *ri)
        return -1;
    }
    /* Decode OAEP parameters */
    oaep = rsa_oaep_decode(cmsalg, &maskHash);
    oaep = rsa_oaep_decode(cmsalg);

    if (oaep == NULL) {
        RSAerr(RSA_F_RSA_CMS_DECRYPT, RSA_R_INVALID_OAEP_PARAMETERS);
        goto err;
    }

    mgf1md = rsa_mgf1_to_md(oaep->maskGenFunc, maskHash);
    mgf1md = rsa_algor_to_md(oaep->maskHash);
    if (!mgf1md)
        goto err;
    md = rsa_algor_to_md(oaep->hashFunc);
@@ -760,7 +733,6 @@ static int rsa_cms_decrypt(CMS_RecipientInfo *ri)

 err:
    RSA_OAEP_PARAMS_free(oaep);
    X509_ALGOR_free(maskHash);
    return rv;
}

+26 −4
Original line number Diff line number Diff line
@@ -49,20 +49,42 @@ ASN1_SEQUENCE_cb(RSAPublicKey, rsa_cb) = {
        ASN1_SIMPLE(RSA, e, BIGNUM),
} ASN1_SEQUENCE_END_cb(RSA, RSAPublicKey)

ASN1_SEQUENCE(RSA_PSS_PARAMS) = {
/* Free up maskHash */
static int rsa_pss_cb(int operation, ASN1_VALUE **pval, const ASN1_ITEM *it,
                      void *exarg)
{
    if (operation == ASN1_OP_FREE_PRE) {
        RSA_PSS_PARAMS *pss = (RSA_PSS_PARAMS *)*pval;
        X509_ALGOR_free(pss->maskHash);
    }
    return 1;
}

ASN1_SEQUENCE_cb(RSA_PSS_PARAMS, rsa_pss_cb) = {
        ASN1_EXP_OPT(RSA_PSS_PARAMS, hashAlgorithm, X509_ALGOR,0),
        ASN1_EXP_OPT(RSA_PSS_PARAMS, maskGenAlgorithm, X509_ALGOR,1),
        ASN1_EXP_OPT(RSA_PSS_PARAMS, saltLength, ASN1_INTEGER,2),
        ASN1_EXP_OPT(RSA_PSS_PARAMS, trailerField, ASN1_INTEGER,3)
} ASN1_SEQUENCE_END(RSA_PSS_PARAMS)
} ASN1_SEQUENCE_END_cb(RSA_PSS_PARAMS, RSA_PSS_PARAMS)

IMPLEMENT_ASN1_FUNCTIONS(RSA_PSS_PARAMS)

ASN1_SEQUENCE(RSA_OAEP_PARAMS) = {
/* Free up maskHash */
static int rsa_oaep_cb(int operation, ASN1_VALUE **pval, const ASN1_ITEM *it,
                       void *exarg)
{
    if (operation == ASN1_OP_FREE_PRE) {
        RSA_OAEP_PARAMS *oaep = (RSA_OAEP_PARAMS *)*pval;
        X509_ALGOR_free(oaep->maskHash);
    }
    return 1;
}

ASN1_SEQUENCE_cb(RSA_OAEP_PARAMS, rsa_oaep_cb) = {
        ASN1_EXP_OPT(RSA_OAEP_PARAMS, hashFunc, X509_ALGOR, 0),
        ASN1_EXP_OPT(RSA_OAEP_PARAMS, maskGenFunc, X509_ALGOR, 1),
        ASN1_EXP_OPT(RSA_OAEP_PARAMS, pSourceFunc, X509_ALGOR, 2),
} ASN1_SEQUENCE_END(RSA_OAEP_PARAMS)
} ASN1_SEQUENCE_END_cb(RSA_OAEP_PARAMS, RSA_OAEP_PARAMS)

IMPLEMENT_ASN1_FUNCTIONS(RSA_OAEP_PARAMS)

+4 −0
Original line number Diff line number Diff line
@@ -239,6 +239,8 @@ typedef struct rsa_pss_params_st {
    X509_ALGOR *maskGenAlgorithm;
    ASN1_INTEGER *saltLength;
    ASN1_INTEGER *trailerField;
    /* Decoded hash algorithm from maskGenAlgorithm */
    X509_ALGOR *maskHash;
} RSA_PSS_PARAMS;

DECLARE_ASN1_FUNCTIONS(RSA_PSS_PARAMS)
@@ -247,6 +249,8 @@ typedef struct rsa_oaep_params_st {
    X509_ALGOR *hashFunc;
    X509_ALGOR *maskGenFunc;
    X509_ALGOR *pSourceFunc;
    /* Decoded hash algorithm from maskGenFunc */
    X509_ALGOR *maskHash;
} RSA_OAEP_PARAMS;

DECLARE_ASN1_FUNCTIONS(RSA_OAEP_PARAMS)