Commit bb01ef3f authored by Matt Caswell's avatar Matt Caswell
Browse files

Add a test for SNI in conjunction with custom extensions



Test that custom extensions still work even after a change in SSL_CTX due
to SNI. See #2180.

Reviewed-by: default avatarRich Salz <rsalz@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/3425)
parent 21181889
Loading
Loading
Loading
Loading
+56 −10
Original line number Diff line number Diff line
@@ -1793,6 +1793,7 @@ static int clntaddnewcb = 0;
static int clntparsenewcb = 0;
static int srvaddnewcb = 0;
static int srvparsenewcb = 0;
static int snicb = 0;

#define TEST_EXT_TYPE1  0xff00

@@ -1886,16 +1887,30 @@ static int new_parse_cb(SSL *s, unsigned int ext_type, unsigned int context,

    return 1;
}

static int sni_cb(SSL *s, int *al, void *arg)
{
    SSL_CTX *ctx = (SSL_CTX *)arg;

    if (SSL_set_SSL_CTX(s, ctx) == NULL) {
        *al = SSL_AD_INTERNAL_ERROR;
        return SSL_TLSEXT_ERR_ALERT_FATAL;
    }
    snicb++;
    return SSL_TLSEXT_ERR_OK;
}

/*
 * Custom call back tests.
 * Test 0: Old style callbacks in TLSv1.2
 * Test 1: New style callbacks in TLSv1.2
 * Test 2: New style callbacks in TLSv1.3. Extensions in CH and EE
 * Test 3: New style callbacks in TLSv1.3. Extensions in CH, SH, EE, Cert + NST
 * Test 2: New style callbacks in TLSv1.2 with SNI
 * Test 3: New style callbacks in TLSv1.3. Extensions in CH and EE
 * Test 4: New style callbacks in TLSv1.3. Extensions in CH, SH, EE, Cert + NST
 */
static int test_custom_exts(int tst)
{
    SSL_CTX *cctx = NULL, *sctx = NULL;
    SSL_CTX *cctx = NULL, *sctx = NULL, *sctx2 = NULL;
    SSL *clientssl = NULL, *serverssl = NULL;
    int testresult = 0;
    static int server = 1;
@@ -1906,18 +1921,27 @@ static int test_custom_exts(int tst)
    /* Reset callback counters */
    clntaddoldcb = clntparseoldcb = srvaddoldcb = srvparseoldcb = 0;
    clntaddnewcb = clntparsenewcb = srvaddnewcb = srvparsenewcb = 0;
    snicb = 0;

    if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(),
                                       TLS_client_method(), &sctx,
                                       &cctx, cert, privkey)))
        goto end;

    if (tst < 2) {
    if (tst == 2
            && !TEST_true(create_ssl_ctx_pair(TLS_server_method(), NULL, &sctx2,
                                              NULL, cert, privkey)))
        goto end;


    if (tst < 3) {
        SSL_CTX_set_options(cctx, SSL_OP_NO_TLSv1_3);
        SSL_CTX_set_options(sctx, SSL_OP_NO_TLSv1_3);
        if (sctx2 != NULL)
            SSL_CTX_set_options(sctx2, SSL_OP_NO_TLSv1_3);
    }

    if (tst == 3) {
    if (tst == 4) {
        context = SSL_EXT_CLIENT_HELLO
                  | SSL_EXT_TLS1_2_SERVER_HELLO
                  | SSL_EXT_TLS1_3_SERVER_HELLO
@@ -1967,6 +1991,12 @@ static int test_custom_exts(int tst)
                                              new_add_cb, new_free_cb,
                                              &server, new_parse_cb, &server)))
            goto end;
        if (sctx2 != NULL
                && !TEST_true(SSL_CTX_add_custom_ext(sctx2, TEST_EXT_TYPE1,
                                                     context, new_add_cb,
                                                     new_free_cb, &server,
                                                     new_parse_cb, &server)))
            goto end;
    }

    /* Should not be able to add duplicates */
@@ -1980,6 +2010,13 @@ static int test_custom_exts(int tst)
                                                  new_parse_cb, &server)))
        goto end;

    if (tst == 2) {
        /* Set up SNI */
        if (!TEST_true(SSL_CTX_set_tlsext_servername_callback(sctx, sni_cb))
                || !TEST_true(SSL_CTX_set_tlsext_servername_arg(sctx, sctx2)))
            goto end;
    }

    if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl,
                                      &clientssl, NULL, NULL))
            || !TEST_true(create_ssl_connection(serverssl, clientssl,
@@ -1992,11 +2029,13 @@ static int test_custom_exts(int tst)
                || srvaddoldcb != 1
                || srvparseoldcb != 1)
            goto end;
    } else if (tst == 1 || tst == 2) {
    } else if (tst == 1 || tst == 2 || tst == 3) {
        if (clntaddnewcb != 1
                || clntparsenewcb != 1
                || srvaddnewcb != 1
                || srvparsenewcb != 1)
                || srvparsenewcb != 1
                || (tst != 2 && snicb != 0)
                || (tst == 2 && snicb != 1))
            goto end;
    } else {
        if (clntaddnewcb != 1
@@ -2013,6 +2052,12 @@ static int test_custom_exts(int tst)
    SSL_free(clientssl);
    serverssl = clientssl = NULL;

    if (tst == 3) {
        /* We don't bother with the resumption aspects for this test */
        testresult = 1;
        goto end;
    }

    if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
                                      NULL, NULL))
            || !TEST_true(SSL_set_session(clientssl, sess))
@@ -2032,7 +2077,7 @@ static int test_custom_exts(int tst)
                || srvaddoldcb != 1
                || srvparseoldcb != 1)
            goto end;
    } else if (tst == 1 || tst == 2) {
    } else if (tst == 1 || tst == 2 || tst == 3) {
        if (clntaddnewcb != 2
                || clntparsenewcb != 2
                || srvaddnewcb != 2
@@ -2053,6 +2098,7 @@ end:
    SSL_SESSION_free(sess);
    SSL_free(serverssl);
    SSL_free(clientssl);
    SSL_CTX_free(sctx2);
    SSL_CTX_free(sctx);
    SSL_CTX_free(cctx);
    return testresult;
@@ -2161,9 +2207,9 @@ int test_main(int argc, char *argv[])
# endif
#endif
#ifndef OPENSSL_NO_TLS1_3
    ADD_ALL_TESTS(test_custom_exts, 4);
    ADD_ALL_TESTS(test_custom_exts, 5);
#else
    ADD_ALL_TESTS(test_custom_exts, 2);
    ADD_ALL_TESTS(test_custom_exts, 3);
#endif
    ADD_ALL_TESTS(test_serverinfo, 8);

+3 −2
Original line number Diff line number Diff line
@@ -518,7 +518,7 @@ int create_ssl_ctx_pair(const SSL_METHOD *sm, const SSL_METHOD *cm,
    SSL_CTX *clientctx = NULL;

    if (!TEST_ptr(serverctx = SSL_CTX_new(sm))
            || !TEST_ptr(clientctx = SSL_CTX_new(cm)))
            || (cctx != NULL && !TEST_ptr(clientctx = SSL_CTX_new(cm))))
        goto err;

    if (!TEST_int_eq(SSL_CTX_use_certificate_file(serverctx, certfile,
@@ -533,6 +533,7 @@ int create_ssl_ctx_pair(const SSL_METHOD *sm, const SSL_METHOD *cm,
#endif

    *sctx = serverctx;
    if (cctx != NULL)
        *cctx = clientctx;
    return 1;