Commit 58e754fc authored by Rich Salz's avatar Rich Salz
Browse files

Convert modular exponentiation tests to new framework



Updated due to test framework changes
Updates after code review
Missed some checks

Reviewed-by: default avatarRichard Levitte <levitte@openssl.org>
Reviewed-by: default avatarRich Salz <rsalz@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/3269)
parent 975922fd
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -153,7 +153,7 @@ INCLUDE_MAIN___test_libtestutil_OLB = /INCLUDE=MAIN

  SOURCE[exptest]=exptest.c
  INCLUDE[exptest]=../include
  DEPEND[exptest]=../libcrypto
  DEPEND[exptest]=../libcrypto libtestutil.a

  SOURCE[rsa_test]=rsa_test.c
  INCLUDE[rsa_test]=.. ../include
+109 −160
Original line number Diff line number Diff line
/*
 * Copyright 1995-2016 The OpenSSL Project Authors. All Rights Reserved.
 * Copyright 1995-2017 The OpenSSL Project Authors. All Rights Reserved.
 *
 * Licensed under the OpenSSL license (the "License").  You may not use
 * this file except in compliance with the License.  You can obtain a copy
@@ -18,34 +18,40 @@
#include <openssl/rand.h>
#include <openssl/err.h>

#include "testutil.h"

#define NUM_BITS        (BN_BITS2 * 4)

static const char rnd_seed[] =
    "string to make the random number generator think it has entropy";
#define BN_print_var(v) bn_print_var(#v, v)

static void bn_print_var(const char *var, const BIGNUM *bn)
{
    fprintf(stderr, "%s (%3d) = ", var, BN_num_bits(bn));
    BN_print_fp(stderr, bn);
    fprintf(stderr, "\n");
}

/*
 * Test that r == 0 in test_exp_mod_zero(). Returns one on success,
 * returns zero and prints debug output otherwise.
 */
static int a_is_zero_mod_one(const char *method, const BIGNUM *r,
                             const BIGNUM *a) {
                             const BIGNUM *a)
{
    if (!BN_is_zero(r)) {
        fprintf(stderr, "%s failed:\n", method);
        fprintf(stderr, "a ** 0 mod 1 = r (should be 0)\n");
        fprintf(stderr, "a = ");
        BN_print_fp(stderr, a);
        fprintf(stderr, "\nr = ");
        BN_print_fp(stderr, r);
        fprintf(stderr, "\n");
        BN_print_var(a);
        BN_print_var(r);
        return 0;
    }
    return 1;
}

/*
 * test_exp_mod_zero tests that x**0 mod 1 == 0. It returns zero on success.
 * test_mod_exp_zero tests that x**0 mod 1 == 0. It returns zero on success.
 */
static int test_exp_mod_zero()
static int test_mod_exp_zero()
{
    BIGNUM *a = NULL, *p = NULL, *m = NULL;
    BIGNUM *r = NULL;
@@ -53,77 +59,64 @@ static int test_exp_mod_zero()
    BN_CTX *ctx = BN_CTX_new();
    int ret = 1, failed = 0;

    m = BN_new();
    if (!m)
    if (!TEST_ptr(m = BN_new())
        || !TEST_ptr(a = BN_new())
        || !TEST_ptr(p = BN_new())
        || !TEST_ptr(r = BN_new()))
        goto err;
    BN_one(m);

    a = BN_new();
    if (!a)
        goto err;
    BN_one(m);
    BN_one(a);

    p = BN_new();
    if (!p)
        goto err;
    BN_zero(p);

    r = BN_new();
    if (!r)
        goto err;

    if (!BN_rand(a, 1024, BN_RAND_TOP_ONE, BN_RAND_BOTTOM_ANY))
    if (!TEST_true(BN_rand(a, 1024, BN_RAND_TOP_ONE, BN_RAND_BOTTOM_ANY)))
        goto err;

    if (!BN_mod_exp(r, a, p, m, ctx))
    if (!TEST_true(BN_mod_exp(r, a, p, m, ctx)))
        goto err;

    if (!a_is_zero_mod_one("BN_mod_exp", r, a))
    if (!TEST_true(a_is_zero_mod_one("BN_mod_exp", r, a)))
        failed = 1;

    if (!BN_mod_exp_recp(r, a, p, m, ctx))
    if (!TEST_true(BN_mod_exp_recp(r, a, p, m, ctx)))
        goto err;

    if (!a_is_zero_mod_one("BN_mod_exp_recp", r, a))
    if (!TEST_true(a_is_zero_mod_one("BN_mod_exp_recp", r, a)))
        failed = 1;

    if (!BN_mod_exp_simple(r, a, p, m, ctx))
    if (!TEST_true(BN_mod_exp_simple(r, a, p, m, ctx)))
        goto err;

    if (!a_is_zero_mod_one("BN_mod_exp_simple", r, a))
    if (!TEST_true(a_is_zero_mod_one("BN_mod_exp_simple", r, a)))
        failed = 1;

    if (!BN_mod_exp_mont(r, a, p, m, ctx, NULL))
    if (!TEST_true(BN_mod_exp_mont(r, a, p, m, ctx, NULL)))
        goto err;

    if (!a_is_zero_mod_one("BN_mod_exp_mont", r, a))
    if (!TEST_true(a_is_zero_mod_one("BN_mod_exp_mont", r, a)))
        failed = 1;

    if (!BN_mod_exp_mont_consttime(r, a, p, m, ctx, NULL)) {
    if (!TEST_true(BN_mod_exp_mont_consttime(r, a, p, m, ctx, NULL)))
        goto err;
    }

    if (!a_is_zero_mod_one("BN_mod_exp_mont_consttime", r, a))
    if (!TEST_true(a_is_zero_mod_one("BN_mod_exp_mont_consttime", r, a)))
        failed = 1;

    /*
     * A different codepath exists for single word multiplication
     * in non-constant-time only.
     */
    if (!BN_mod_exp_mont_word(r, one_word, p, m, ctx, NULL))
    if (!TEST_true(BN_mod_exp_mont_word(r, one_word, p, m, ctx, NULL)))
        goto err;

    if (!BN_is_zero(r)) {
    if (!TEST_true(BN_is_zero(r))) {
        fprintf(stderr, "BN_mod_exp_mont_word failed:\n");
        fprintf(stderr, "1 ** 0 mod 1 = r (should be 0)\n");
        fprintf(stderr, "r = ");
        BN_print_fp(stderr, r);
        fprintf(stderr, "\n");
        return 0;
        BN_print_var(r);
        goto err;
    }

    ret = failed;

    ret = !failed;
 err:
    BN_free(r);
    BN_free(a);
@@ -134,38 +127,31 @@ static int test_exp_mod_zero()
    return ret;
}

int main(int argc, char *argv[])
static int test_mod_exp(int round)
{
    BN_CTX *ctx;
    BIO *out = NULL;
    int i, ret;
    unsigned char c;
    BIGNUM *r_mont, *r_mont_const, *r_recp, *r_simple, *a, *b, *m;

    RAND_seed(rnd_seed, sizeof rnd_seed); /* or BN_rand may fail, and we
                                           * don't even check its return
                                           * value (which we should) */

    ctx = BN_CTX_new();
    if (ctx == NULL)
        EXIT(1);
    r_mont = BN_new();
    r_mont_const = BN_new();
    r_recp = BN_new();
    r_simple = BN_new();
    a = BN_new();
    b = BN_new();
    m = BN_new();
    if ((r_mont == NULL) || (r_recp == NULL) || (a == NULL) || (b == NULL))
    int ret = 0;
    BIGNUM *r_mont = NULL;
    BIGNUM *r_mont_const = NULL;
    BIGNUM *r_recp = NULL;
    BIGNUM *r_simple = NULL;
    BIGNUM *a = NULL;
    BIGNUM *b = NULL;
    BIGNUM *m = NULL;

    if (!TEST_ptr(ctx = BN_CTX_new()))
        goto err;

    out = BIO_new(BIO_s_file());

    if (out == NULL)
        EXIT(1);
    BIO_set_fp(out, stdout, BIO_NOCLOSE | BIO_FP_TEXT);
    if (!TEST_ptr(r_mont = BN_new())
        || !TEST_ptr(r_mont_const = BN_new())
        || !TEST_ptr(r_recp = BN_new())
        || !TEST_ptr(r_simple = BN_new())
        || !TEST_ptr(a = BN_new())
        || !TEST_ptr(b = BN_new())
        || !TEST_ptr(m = BN_new()))
        goto err;

    for (i = 0; i < 200; i++) {
    RAND_bytes(&c, 1);
    c = (c % BN_BITS) - BN_BITS2;
    BN_rand(a, NUM_BITS + c, BN_RAND_TOP_ONE, BN_RAND_BOTTOM_ANY);
@@ -181,65 +167,37 @@ int main(int argc, char *argv[])
    BN_mod(a, a, m, ctx);
    BN_mod(b, b, m, ctx);

        ret = BN_mod_exp_mont(r_mont, a, b, m, ctx, NULL);
        if (ret <= 0) {
            printf("BN_mod_exp_mont() problems\n");
            ERR_print_errors(out);
            EXIT(1);
        }

        ret = BN_mod_exp_recp(r_recp, a, b, m, ctx);
        if (ret <= 0) {
            printf("BN_mod_exp_recp() problems\n");
            ERR_print_errors(out);
            EXIT(1);
        }

        ret = BN_mod_exp_simple(r_simple, a, b, m, ctx);
        if (ret <= 0) {
            printf("BN_mod_exp_simple() problems\n");
            ERR_print_errors(out);
            EXIT(1);
        }

        ret = BN_mod_exp_mont_consttime(r_mont_const, a, b, m, ctx, NULL);
        if (ret <= 0) {
            printf("BN_mod_exp_mont_consttime() problems\n");
            ERR_print_errors(out);
            EXIT(1);
        }
    if (!TEST_true(BN_mod_exp_mont(r_mont, a, b, m, ctx, NULL))
        || !TEST_true(BN_mod_exp_recp(r_recp, a, b, m, ctx))
        || !TEST_true(BN_mod_exp_simple(r_simple, a, b, m, ctx))
        || !TEST_true(BN_mod_exp_mont_consttime(r_mont_const, a, b, m, ctx, NULL)))
        goto err;

        if (BN_cmp(r_simple, r_mont) == 0
            && BN_cmp(r_simple, r_recp) == 0
            && BN_cmp(r_simple, r_mont_const) == 0) {
    if (TEST_int_eq(BN_cmp(r_simple, r_mont), 0)
        && TEST_int_eq(BN_cmp(r_simple, r_recp), 0)
        && TEST_int_eq(BN_cmp(r_simple, r_mont_const), 0)) {
        printf(".");
        fflush(stdout);
    } else {
        if (BN_cmp(r_simple, r_mont) != 0)
                printf("\nsimple and mont results differ\n");
            fprintf(stderr, "simple and mont results differ\n");
        if (BN_cmp(r_simple, r_mont_const) != 0)
                printf("\nsimple and mont const time results differ\n");
            fprintf(stderr, "simple and mont const time results differ\n");
        if (BN_cmp(r_simple, r_recp) != 0)
                printf("\nsimple and recp results differ\n");

            printf("a (%3d) = ", BN_num_bits(a));
            BN_print(out, a);
            printf("\nb (%3d) = ", BN_num_bits(b));
            BN_print(out, b);
            printf("\nm (%3d) = ", BN_num_bits(m));
            BN_print(out, m);
            printf("\nsimple   =");
            BN_print(out, r_simple);
            printf("\nrecp     =");
            BN_print(out, r_recp);
            printf("\nmont     =");
            BN_print(out, r_mont);
            printf("\nmont_ct  =");
            BN_print(out, r_mont_const);
            printf("\n");
            EXIT(1);
        }
            fprintf(stderr, "simple and recp results differ\n");

        BN_print_var(a);
        BN_print_var(b);
        BN_print_var(m);
        BN_print_var(r_simple);
        BN_print_var(r_recp);
        BN_print_var(r_mont);
        BN_print_var(r_mont_const);
        goto err;
    }

    ret = 1;
 err:
    BN_free(r_mont);
    BN_free(r_mont_const);
    BN_free(r_recp);
@@ -249,20 +207,11 @@ int main(int argc, char *argv[])
    BN_free(m);
    BN_CTX_free(ctx);

    if (test_exp_mod_zero() != 0)
        goto err;

#ifndef OPENSSL_NO_CRYPTO_MDEBUG
    if (CRYPTO_mem_leaks(out) <= 0)
        goto err;
#endif
    BIO_free(out);
    printf("\n");

    printf("done\n");
    return ret;
}

    EXIT(0);
 err:
    ERR_print_errors(out);
    EXIT(1);
void register_tests(void)
{
    ADD_TEST(test_mod_exp_zero);
    ADD_ALL_TESTS(test_mod_exp, 200);
}