Commit aff58ee3 authored by Matt Caswell's avatar Matt Caswell
Browse files

Add a test for the certificate callback



Reviewed-by: default avatarBen Kaduk <kaduk@mit.edu>
(Merged from https://github.com/openssl/openssl/pull/7257)

(cherry picked from commit cd6fe29f5bad1a350a039673e06f83ec7a7ef619)
parent ec6788fb
Loading
Loading
Loading
Loading
+95 −0
Original line number Diff line number Diff line
@@ -5497,6 +5497,100 @@ static int test_shutdown(int tst)
    return testresult;
}

static int cert_cb_cnt;

static int cert_cb(SSL *s, void *arg)
{
    SSL_CTX *ctx = (SSL_CTX *)arg;

    if (cert_cb_cnt == 0) {
        /* Suspend the handshake */
        cert_cb_cnt++;
        return -1;
    } else if (cert_cb_cnt == 1) {
        /*
         * Update the SSL_CTX, set the certificate and private key and then
         * continue the handshake normally.
         */
        if (ctx != NULL && !TEST_ptr(SSL_set_SSL_CTX(s, ctx)))
            return 0;

        if (!TEST_true(SSL_use_certificate_file(s, cert, SSL_FILETYPE_PEM))
                || !TEST_true(SSL_use_PrivateKey_file(s, privkey,
                                                      SSL_FILETYPE_PEM))
                || !TEST_true(SSL_check_private_key(s)))
            return 0;
        cert_cb_cnt++;
        return 1;
    }

    /* Abort the handshake */
    return 0;
}

/*
 * Test the certificate callback.
 * Test 0: Callback fails
 * Test 1: Success - no SSL_set_SSL_CTX() in the callback
 * Test 2: Success - SSL_set_SSL_CTX() in the callback
 */
static int test_cert_cb_int(int prot, int tst)
{
    SSL_CTX *cctx = NULL, *sctx = NULL, *snictx = NULL;
    SSL *clientssl = NULL, *serverssl = NULL;
    int testresult = 0, ret;

    if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(),
                                       TLS_client_method(),
                                       TLS1_VERSION,
                                       prot,
                                       &sctx, &cctx, NULL, NULL)))
        goto end;

    if (tst == 0)
        cert_cb_cnt = -1;
    else
        cert_cb_cnt = 0;
    if (tst == 2)
        snictx = SSL_CTX_new(TLS_server_method());
    SSL_CTX_set_cert_cb(sctx, cert_cb, snictx);

    if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
                                      NULL, NULL)))
        goto end;

    ret = create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE);
    if (!TEST_true(tst == 0 ? !ret : ret)
            || (tst > 0 && !TEST_int_eq(cert_cb_cnt, 2))) {
        goto end;
    }

    testresult = 1;

 end:
    SSL_free(serverssl);
    SSL_free(clientssl);
    SSL_CTX_free(sctx);
    SSL_CTX_free(cctx);
    SSL_CTX_free(snictx);

    return testresult;
}

static int test_cert_cb(int tst)
{
    int testresult = 1;

#ifndef OPENSSL_NO_TLS1_2
    testresult &= test_cert_cb_int(TLS1_2_VERSION, tst);
#endif
#ifdef OPENSSL_NO_TLS1_3
    testresult &= test_cert_cb_int(TLS1_3_VERSION, tst);
#endif

    return testresult;
}

int setup_tests(void)
{
    if (!TEST_ptr(cert = test_get_argument(0))
@@ -5599,6 +5693,7 @@ int setup_tests(void)
    ADD_ALL_TESTS(test_ssl_get_shared_ciphers, OSSL_NELEM(shared_ciphers_data));
    ADD_ALL_TESTS(test_ticket_callbacks, 12);
    ADD_ALL_TESTS(test_shutdown, 7);
    ADD_ALL_TESTS(test_cert_cb, 3);
    return 1;
}

+3 −1
Original line number Diff line number Diff line
@@ -712,7 +712,9 @@ int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
                err = SSL_get_error(serverssl, rets);
        }

        if (!servererr && rets <= 0 && err != SSL_ERROR_WANT_READ) {
        if (!servererr && rets <= 0
                && err != SSL_ERROR_WANT_READ
                && err != SSL_ERROR_WANT_X509_LOOKUP) {
            TEST_info("SSL_accept() failed %d, %d", rets, err);
            servererr = 1;
        }