Commit c8feb103 authored by Matt Caswell's avatar Matt Caswell
Browse files

Write a test for receiving a KeyUpdate (update requested) while writing



Reviewed-by: default avatarBen Kaduk <kaduk@mit.edu>
(Merged from https://github.com/openssl/openssl/pull/8773)

(cherry picked from commit a77b4dba237d001073d2d1c5d55c674a196c949f)
parent 6c2f347c
Loading
Loading
Loading
Loading
+92 −0
Original line number Diff line number Diff line
@@ -4290,7 +4290,95 @@ static int test_key_update(void)
                || !TEST_int_eq(SSL_read(serverssl, buf, sizeof(buf)),
                                         strlen(mess)))
            goto end;

        if (!TEST_int_eq(SSL_write(serverssl, mess, strlen(mess)), strlen(mess))
                || !TEST_int_eq(SSL_read(clientssl, buf, sizeof(buf)),
                                         strlen(mess)))
            goto end;
    }

    testresult = 1;

 end:
    SSL_free(serverssl);
    SSL_free(clientssl);
    SSL_CTX_free(sctx);
    SSL_CTX_free(cctx);

    return testresult;
}

/*
 * Test we can handle a KeyUpdate (update requested) message while write data
 * is pending.
 * Test 0: Client sends KeyUpdate while Server is writing
 * Test 1: Server sends KeyUpdate while Client is writing
 */
static int test_key_update_in_write(int tst)
{
    SSL_CTX *cctx = NULL, *sctx = NULL;
    SSL *clientssl = NULL, *serverssl = NULL;
    int testresult = 0;
    char buf[20];
    static char *mess = "A test message";
    BIO *bretry = BIO_new(bio_s_always_retry());
    BIO *tmp = NULL;
    SSL *peerupdate = NULL, *peerwrite = NULL;

    if (!TEST_ptr(bretry)
            || !TEST_true(create_ssl_ctx_pair(TLS_server_method(),
                                              TLS_client_method(),
                                              TLS1_3_VERSION,
                                              0,
                                              &sctx, &cctx, cert, privkey))
            || !TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
                                             NULL, NULL))
            || !TEST_true(create_ssl_connection(serverssl, clientssl,
                                                SSL_ERROR_NONE)))
        goto end;

    peerupdate = tst == 0 ? clientssl : serverssl;
    peerwrite = tst == 0 ? serverssl : clientssl;

    if (!TEST_true(SSL_key_update(peerupdate, SSL_KEY_UPDATE_REQUESTED))
            || !TEST_true(SSL_do_handshake(peerupdate)))
        goto end;

    /* Swap the writing endpoint's write BIO to force a retry */
    tmp = SSL_get_wbio(peerwrite);
    if (!TEST_ptr(tmp) || !TEST_true(BIO_up_ref(tmp))) {
        tmp = NULL;
        goto end;
    }
    SSL_set0_wbio(peerwrite, bretry);
    bretry = NULL;

    /* Write data that we know will fail with SSL_ERROR_WANT_WRITE */
    if (!TEST_int_eq(SSL_write(peerwrite, mess, strlen(mess)), -1)
            || !TEST_int_eq(SSL_get_error(peerwrite, 0), SSL_ERROR_WANT_WRITE))
        goto end;

    /* Reinstate the original writing endpoint's write BIO */
    SSL_set0_wbio(peerwrite, tmp);
    tmp = NULL;

    /* Now read some data - we will read the key update */
    if (!TEST_int_eq(SSL_read(peerwrite, buf, sizeof(buf)), -1)
            || !TEST_int_eq(SSL_get_error(peerwrite, 0), SSL_ERROR_WANT_READ))
        goto end;

    /*
     * Complete the write we started previously and read it from the other
     * endpoint
     */
    if (!TEST_int_eq(SSL_write(peerwrite, mess, strlen(mess)), strlen(mess))
            || !TEST_int_eq(SSL_read(peerupdate, buf, sizeof(buf)), strlen(mess)))
        goto end;

    /* Write more data to ensure we send the KeyUpdate message back */
    if (!TEST_int_eq(SSL_write(peerwrite, mess, strlen(mess)), strlen(mess))
            || !TEST_int_eq(SSL_read(peerupdate, buf, sizeof(buf)), strlen(mess)))
        goto end;

    testresult = 1;

@@ -4299,6 +4387,8 @@ static int test_key_update(void)
    SSL_free(clientssl);
    SSL_CTX_free(sctx);
    SSL_CTX_free(cctx);
    BIO_free(bretry);
    BIO_free(tmp);

    return testresult;
}
@@ -5982,6 +6072,7 @@ int setup_tests(void)
#ifndef OPENSSL_NO_TLS1_3
    ADD_ALL_TESTS(test_export_key_mat_early, 3);
    ADD_TEST(test_key_update);
    ADD_ALL_TESTS(test_key_update_in_write, 2);
#endif
    ADD_ALL_TESTS(test_ssl_clear, 2);
    ADD_ALL_TESTS(test_max_fragment_len_ext, OSSL_NELEM(max_fragment_len_test));
@@ -6002,4 +6093,5 @@ int setup_tests(void)
void cleanup_tests(void)
{
    bio_s_mempacket_test_free();
    bio_s_always_retry_free();
}
+96 −0
Original line number Diff line number Diff line
@@ -62,9 +62,11 @@ static int tls_dump_puts(BIO *bp, const char *str);
/* Choose a sufficiently large type likely to be unused for this custom BIO */
#define BIO_TYPE_TLS_DUMP_FILTER  (0x80 | BIO_TYPE_FILTER)
#define BIO_TYPE_MEMPACKET_TEST    0x81
#define BIO_TYPE_ALWAYS_RETRY      0x82

static BIO_METHOD *method_tls_dump = NULL;
static BIO_METHOD *meth_mem = NULL;
static BIO_METHOD *meth_always_retry = NULL;

/* Note: Not thread safe! */
const BIO_METHOD *bio_f_tls_dump_filter(void)
@@ -612,6 +614,100 @@ static int mempacket_test_puts(BIO *bio, const char *str)
    return mempacket_test_write(bio, str, strlen(str));
}

static int always_retry_new(BIO *bi);
static int always_retry_free(BIO *a);
static int always_retry_read(BIO *b, char *out, int outl);
static int always_retry_write(BIO *b, const char *in, int inl);
static long always_retry_ctrl(BIO *b, int cmd, long num, void *ptr);
static int always_retry_gets(BIO *bp, char *buf, int size);
static int always_retry_puts(BIO *bp, const char *str);

const BIO_METHOD *bio_s_always_retry(void)
{
    if (meth_always_retry == NULL) {
        if (!TEST_ptr(meth_always_retry = BIO_meth_new(BIO_TYPE_ALWAYS_RETRY,
                                                       "Always Retry"))
            || !TEST_true(BIO_meth_set_write(meth_always_retry,
                                             always_retry_write))
            || !TEST_true(BIO_meth_set_read(meth_always_retry,
                                            always_retry_read))
            || !TEST_true(BIO_meth_set_puts(meth_always_retry,
                                            always_retry_puts))
            || !TEST_true(BIO_meth_set_gets(meth_always_retry,
                                            always_retry_gets))
            || !TEST_true(BIO_meth_set_ctrl(meth_always_retry,
                                            always_retry_ctrl))
            || !TEST_true(BIO_meth_set_create(meth_always_retry,
                                              always_retry_new))
            || !TEST_true(BIO_meth_set_destroy(meth_always_retry,
                                               always_retry_free)))
            return NULL;
    }
    return meth_always_retry;
}

void bio_s_always_retry_free(void)
{
    BIO_meth_free(meth_always_retry);
}

static int always_retry_new(BIO *bio)
{
    BIO_set_init(bio, 1);
    return 1;
}

static int always_retry_free(BIO *bio)
{
    BIO_set_data(bio, NULL);
    BIO_set_init(bio, 0);
    return 1;
}

static int always_retry_read(BIO *bio, char *out, int outl)
{
    BIO_set_retry_read(bio);
    return -1;
}

static int always_retry_write(BIO *bio, const char *in, int inl)
{
    BIO_set_retry_write(bio);
    return -1;
}

static long always_retry_ctrl(BIO *bio, int cmd, long num, void *ptr)
{
    long ret = 1;

    switch (cmd) {
    case BIO_CTRL_FLUSH:
        BIO_set_retry_write(bio);
        /* fall through */
    case BIO_CTRL_EOF:
    case BIO_CTRL_RESET:
    case BIO_CTRL_DUP:
    case BIO_CTRL_PUSH:
    case BIO_CTRL_POP:
    default:
        ret = 0;
        break;
    }
    return ret;
}

static int always_retry_gets(BIO *bio, char *buf, int size)
{
    BIO_set_retry_read(bio);
    return -1;
}

static int always_retry_puts(BIO *bio, const char *str)
{
    BIO_set_retry_write(bio);
    return -1;
}

int create_ssl_ctx_pair(const SSL_METHOD *sm, const SSL_METHOD *cm,
                        int min_proto_version, int max_proto_version,
                        SSL_CTX **sctx, SSL_CTX **cctx, char *certfile,
+3 −0
Original line number Diff line number Diff line
@@ -30,6 +30,9 @@ void bio_f_tls_dump_filter_free(void);
const BIO_METHOD *bio_s_mempacket_test(void);
void bio_s_mempacket_test_free(void);

const BIO_METHOD *bio_s_always_retry(void);
void bio_s_always_retry_free(void);

/* Packet types - value 0 is reserved */
#define INJECT_PACKET                   1
#define INJECT_PACKET_IGNORE_REC_SEQ    2