Commit 0b3f5eab authored by Matt Caswell's avatar Matt Caswell
Browse files

Add a test for duplicated DTLS records



Reviewed-by: default avatarPaul Dale <paul.dale@oracle.com>
(Merged from https://github.com/openssl/openssl/pull/7414)

(cherry picked from commit f1358634af5b84be22cb20fff3dcb613f5f8c978)
parent 86fe421d
Loading
Loading
Loading
Loading
+35 −0
Original line number Diff line number Diff line
@@ -289,6 +289,40 @@ static int test_cookie(void)
    return testresult;
}

static int test_dtls_duplicate_records(void)
{
    SSL_CTX *sctx = NULL, *cctx = NULL;
    SSL *serverssl = NULL, *clientssl = NULL;
    int testresult = 0;

    if (!TEST_true(create_ssl_ctx_pair(DTLS_server_method(),
                                       DTLS_client_method(),
                                       DTLS1_VERSION, DTLS_MAX_VERSION,
                                       &sctx, &cctx, cert, privkey)))
        return 0;

    if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl,
                                      NULL, NULL)))
        goto end;

    DTLS_set_timer_cb(clientssl, timer_cb);
    DTLS_set_timer_cb(serverssl, timer_cb);

    BIO_ctrl(SSL_get_wbio(clientssl), MEMPACKET_CTRL_SET_DUPLICATE_REC, 1, NULL);
    BIO_ctrl(SSL_get_wbio(serverssl), MEMPACKET_CTRL_SET_DUPLICATE_REC, 1, NULL);

    if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE)))
        goto end;

    testresult = 1;
 end:
    SSL_free(serverssl);
    SSL_free(clientssl);
    SSL_CTX_free(sctx);
    SSL_CTX_free(cctx);

    return testresult;
}

int setup_tests(void)
{
@@ -299,6 +333,7 @@ int setup_tests(void)
    ADD_ALL_TESTS(test_dtls_unprocessed, NUM_TESTS);
    ADD_ALL_TESTS(test_dtls_drop_records, TOTAL_RECORDS);
    ADD_TEST(test_cookie);
    ADD_TEST(test_dtls_duplicate_records);

    return 1;
}
+57 −23
Original line number Diff line number Diff line
@@ -284,6 +284,7 @@ typedef struct mempacket_test_ctx_st {
    unsigned int noinject;
    unsigned int dropepoch;
    int droprec;
    int duprec;
} MEMPACKET_TEST_CTX;

static int mempacket_test_new(BIO *bi);
@@ -426,12 +427,25 @@ int mempacket_test_inject(BIO *bio, const char *in, int inl, int pktnum,
                          int type)
{
    MEMPACKET_TEST_CTX *ctx = BIO_get_data(bio);
    MEMPACKET *thispkt, *looppkt, *nextpkt;
    int i;
    MEMPACKET *thispkt = NULL, *looppkt, *nextpkt, *allpkts[3];
    int i, duprec = ctx->duprec > 0;
    const unsigned char *inu = (const unsigned char *)in;
    size_t len = ((inu[RECORD_LEN_HI] << 8) | inu[RECORD_LEN_LO])
                 + DTLS1_RT_HEADER_LENGTH;

    if (ctx == NULL)
        return -1;

    if ((size_t)inl < len)
        return -1;

    if ((size_t)inl == len)
        duprec = 0;

    /* We don't support arbitrary injection when duplicating records */
    if (duprec && pktnum != -1)
        return -1;

    /* We only allow injection before we've started writing any data */
    if (pktnum >= 0) {
        if (ctx->noinject)
@@ -441,25 +455,36 @@ int mempacket_test_inject(BIO *bio, const char *in, int inl, int pktnum,
        ctx->noinject = 1;
    }

    if (!TEST_ptr(thispkt = OPENSSL_malloc(sizeof(*thispkt))))
        return -1;
    if (!TEST_ptr(thispkt->data = OPENSSL_malloc(inl))) {
        mempacket_free(thispkt);
        return -1;
    }
    for (i = 0; i < (duprec ? 3 : 1); i++) {
        if (!TEST_ptr(allpkts[i] = OPENSSL_malloc(sizeof(*thispkt))))
            goto err;
        thispkt = allpkts[i];

        if (!TEST_ptr(thispkt->data = OPENSSL_malloc(inl)))
            goto err;
        /*
         * If we are duplicating the packet, we duplicate it three times. The
         * first two times we drop the first record if there are more than one.
         * In this way we know that libssl will not be able to make progress
         * until it receives the last packet, and hence will be forced to
         * buffer these records.
         */
        if (duprec && i != 2) {
            memcpy(thispkt->data, in + len, inl - len);
            thispkt->len = inl - len;
        } else {
            memcpy(thispkt->data, in, inl);
            thispkt->len = inl;
    thispkt->num = (pktnum >= 0) ? (unsigned int)pktnum : ctx->lastpkt;
        }
        thispkt->num = (pktnum >= 0) ? (unsigned int)pktnum : ctx->lastpkt + i;
        thispkt->type = type;
    }

    for(i = 0; (looppkt = sk_MEMPACKET_value(ctx->pkts, i)) != NULL; i++) {
        /* Check if we found the right place to insert this packet */
        if (looppkt->num > thispkt->num) {
            if (sk_MEMPACKET_insert(ctx->pkts, thispkt, i) == 0) {
                mempacket_free(thispkt);
                return -1;
            }
            if (sk_MEMPACKET_insert(ctx->pkts, thispkt, i) == 0)
                goto err;
            /* If we're doing up front injection then we're done */
            if (pktnum >= 0)
                return inl;
@@ -480,7 +505,7 @@ int mempacket_test_inject(BIO *bio, const char *in, int inl, int pktnum,
        } else if (looppkt->num == thispkt->num) {
            if (!ctx->noinject) {
                /* We injected two packets with the same packet number! */
                return -1;
                goto err;
            }
            ctx->lastpkt++;
            thispkt->num++;
@@ -490,15 +515,21 @@ int mempacket_test_inject(BIO *bio, const char *in, int inl, int pktnum,
     * We didn't find any packets with a packet number equal to or greater than
     * this one, so we just add it onto the end
     */
    if (!sk_MEMPACKET_push(ctx->pkts, thispkt)) {
        mempacket_free(thispkt);
        return -1;
    }
    for (i = 0; i < (duprec ? 3 : 1); i++) {
        thispkt = allpkts[i];
        if (!sk_MEMPACKET_push(ctx->pkts, thispkt))
            goto err;

        if (pktnum < 0)
            ctx->lastpkt++;
    }

    return inl;

 err:
    for (i = 0; i < (ctx->duprec > 0 ? 3 : 1); i++)
        mempacket_free(allpkts[i]);
    return -1;
}

static int mempacket_test_write(BIO *bio, const char *in, int inl)
@@ -544,6 +575,9 @@ static long mempacket_test_ctrl(BIO *bio, int cmd, long num, void *ptr)
    case MEMPACKET_CTRL_GET_DROP_REC:
        ret = ctx->droprec;
        break;
    case MEMPACKET_CTRL_SET_DUPLICATE_REC:
        ctx->duprec = (int)num;
        break;
    case BIO_CTRL_RESET:
    case BIO_CTRL_DUP:
    case BIO_CTRL_PUSH:
+4 −3
Original line number Diff line number Diff line
@@ -40,6 +40,7 @@ void bio_s_mempacket_test_free(void);
#define MEMPACKET_CTRL_SET_DROP_EPOCH       (1 << 15)
#define MEMPACKET_CTRL_SET_DROP_REC         (2 << 15)
#define MEMPACKET_CTRL_GET_DROP_REC         (3 << 15)
#define MEMPACKET_CTRL_SET_DUPLICATE_REC    (4 << 15)

int mempacket_test_inject(BIO *bio, const char *in, int inl, int pktnum,
                          int type);