Commit fe5d9450 authored by Boris Pismenny's avatar Boris Pismenny Committed by Matt Caswell
Browse files

sslapitest: add test ktls



Add a unit-test for ktls.

Signed-off-by: default avatarBoris Pismenny <borisp@mellanox.com>

Reviewed-by: default avatarTim Hudson <tjh@openssl.org>
Reviewed-by: default avatarPaul Yang <yang.yang@baishancloud.com>
Reviewed-by: default avatarMatt Caswell <matt@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/5253)
parent 50ec7505
Loading
Loading
Loading
Loading
+193 −0
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@
#include "testutil.h"
#include "testutil/output.h"
#include "internal/nelem.h"
#include "internal/ktls.h"
#include "../ssl/ssl_locl.h"

#ifndef OPENSSL_NO_TLS1_3
@@ -656,6 +657,192 @@ static int execute_test_large_message(const SSL_METHOD *smeth,
    return testresult;
}

#if !defined(OPENSSL_NO_TLS1_2) && !defined(OPENSSL_NO_KTLS)

/* sock must be connected */
static int ktls_chk_platform(int sock)
{
    if (!ktls_enable(sock))
        return 0;
    return 1;
}

static int ping_pong_query(SSL *clientssl, SSL *serverssl, int cfd, int sfd)
{
    static char count = 1;
    unsigned char cbuf[16000] = {0};
    unsigned char sbuf[16000];
    size_t err = 0;
    char crec_wseq_before[TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE];
    char crec_wseq_after[TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE];
    char srec_wseq_before[TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE];
    char srec_wseq_after[TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE];
    char srec_rseq_before[TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE];
    char srec_rseq_after[TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE];

    cbuf[0] = count++;
    memcpy(crec_wseq_before, &clientssl->rlayer.write_sequence,
            TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
    memcpy(srec_wseq_before, &serverssl->rlayer.write_sequence,
            TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
    memcpy(srec_rseq_before, &serverssl->rlayer.read_sequence,
            TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);

    if (!TEST_true(SSL_write(clientssl, cbuf, sizeof(cbuf)) == sizeof(cbuf)))
        goto end;

    while ((err = SSL_read(serverssl, &sbuf, sizeof(sbuf))) != sizeof(sbuf)) {
        if (SSL_get_error(serverssl, err) != SSL_ERROR_WANT_READ) {
            goto end;
        }
    }

    if (!TEST_true(SSL_write(serverssl, sbuf, sizeof(sbuf)) == sizeof(sbuf)))
        goto end;

    while ((err = SSL_read(clientssl, &cbuf, sizeof(cbuf))) != sizeof(cbuf)) {
        if (SSL_get_error(clientssl, err) != SSL_ERROR_WANT_READ) {
            goto end;
        }
    }

    memcpy(crec_wseq_after, &clientssl->rlayer.write_sequence,
            TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
    memcpy(srec_wseq_after, &serverssl->rlayer.write_sequence,
            TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
    memcpy(srec_rseq_after, &serverssl->rlayer.read_sequence,
            TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);

    /* verify the payload */
    if (!TEST_mem_eq(cbuf, sizeof(cbuf), sbuf, sizeof(sbuf)))
        goto end;

    /* ktls is used then kernel sequences are used instead of OpenSSL sequences */
    if (clientssl->mode & SSL_MODE_NO_KTLS_TX) {
        if (!TEST_mem_ne(crec_wseq_before, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE,
                         crec_wseq_after, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE))
            goto end;
    } else {
        if (!TEST_mem_eq(crec_wseq_before, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE,
                         crec_wseq_after, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE))
            goto end;
    }

    if (serverssl->mode & SSL_MODE_NO_KTLS_TX) {
        if (!TEST_mem_ne(srec_wseq_before, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE,
                         srec_wseq_after, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE))
            goto end;
    } else {
        if (!TEST_mem_eq(srec_wseq_before, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE,
                         srec_wseq_after, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE))
            goto end;
    }

    if (!TEST_mem_ne(srec_rseq_before, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE,
                     srec_rseq_after, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE))
        goto end;

    return 1;
end:
    return 0;
}

static int execute_test_ktls(int cis_ktls_tx, int sis_ktls_tx)
{
    SSL_CTX *cctx = NULL, *sctx = NULL;
    SSL *clientssl = NULL, *serverssl = NULL;
    int testresult = 0;
    int cfd, sfd;

    if (!TEST_true(create_test_sockets(&cfd, &sfd)))
        goto end;

    /* Skip this test if the platform does not support ktls */
    if (!ktls_chk_platform(cfd))
        return 1;

    /* Create a session based on SHA-256 */
    if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(),
                                       TLS_client_method(),
                                       TLS1_2_VERSION, TLS1_2_VERSION,
                                       &sctx, &cctx, cert, privkey))
            || !TEST_true(SSL_CTX_set_cipher_list(cctx,
                                                  "AES128-GCM-SHA256"))
            || !TEST_true(create_ssl_objects2(sctx, cctx, &serverssl,
                                          &clientssl, sfd, cfd)))
        goto end;

    if (!cis_ktls_tx) {
        if (!TEST_true(SSL_set_mode(clientssl, SSL_MODE_NO_KTLS_TX)))
            goto end;
    }

    if (!sis_ktls_tx) {
        if (!TEST_true(SSL_set_mode(serverssl, SSL_MODE_NO_KTLS_TX)))
            goto end;
    }

    if (!TEST_true(create_ssl_connection(serverssl, clientssl,
                                                SSL_ERROR_NONE)))
        goto end;

    if (!cis_ktls_tx) {
        if (!TEST_false(BIO_get_ktls_send(clientssl->wbio)))
            goto end;
    } else {
        if (!TEST_true(BIO_get_ktls_send(clientssl->wbio)))
            goto end;
    }

    if (!sis_ktls_tx) {
        if (!TEST_false(BIO_get_ktls_send(serverssl->wbio)))
            goto end;
    } else {
        if (!TEST_true(BIO_get_ktls_send(serverssl->wbio)))
            goto end;
    }

    if (!TEST_true(ping_pong_query(clientssl, serverssl, cfd, sfd)))
        goto end;

    testresult = 1;
end:
    if (clientssl) {
        SSL_shutdown(clientssl);
        SSL_free(clientssl);
    }
    if (serverssl) {
        SSL_shutdown(serverssl);
        SSL_free(serverssl);
    }
    SSL_CTX_free(sctx);
    SSL_CTX_free(cctx);
    serverssl = clientssl = NULL;
    return testresult;
}

static int test_ktls_client_server(void)
{
    return execute_test_ktls(1, 1);
}

static int test_ktls_no_client_server(void)
{
    return execute_test_ktls(0, 1);
}

static int test_ktls_client_no_server(void)
{
    return execute_test_ktls(1, 0);
}

static int test_ktls_no_client_no_server(void)
{
    return execute_test_ktls(0, 0);
}

#endif

static int test_large_message_tls(void)
{
    return execute_test_large_message(TLS_server_method(), TLS_client_method(),
@@ -5869,6 +6056,12 @@ int setup_tests(void)
#endif
    }

#if !defined(OPENSSL_NO_TLS1_2) && !defined(OPENSSL_NO_KTLS)
    ADD_TEST(test_ktls_client_server);
    ADD_TEST(test_ktls_no_client_server);
    ADD_TEST(test_ktls_client_no_server);
    ADD_TEST(test_ktls_no_client_no_server);
#endif
    ADD_TEST(test_large_message_tls);
    ADD_TEST(test_large_message_tls_read_ahead);
#ifndef OPENSSL_NO_DTLS
+121 −0
Original line number Diff line number Diff line
@@ -16,6 +16,14 @@

#ifdef OPENSSL_SYS_UNIX
# include <unistd.h>
#ifndef OPENSSL_NO_KTLS
# include <netinet/in.h>
# include <netinet/in.h>
# include <arpa/inet.h>
# include <sys/socket.h>
# include <unistd.h>
# include <fcntl.h>
#endif

static ossl_inline void ossl_sleep(unsigned int millis) {
    usleep(millis * 1000);
@@ -655,6 +663,119 @@ int create_ssl_ctx_pair(const SSL_METHOD *sm, const SSL_METHOD *cm,

#define MAXLOOPS    1000000

#ifndef OPENSSL_NO_KTLS
static int set_nb(int fd)
{
    int flags;

    flags = fcntl(fd,F_GETFL,0);
    if (flags == -1)
        return flags;
    flags = fcntl(fd, F_SETFL, flags | O_NONBLOCK);
    return flags;
}

int create_test_sockets(int *cfd, int *sfd)
{
    struct sockaddr_in sin;
    const char *host = "127.0.0.1";
    int cfd_connected = 0, ret = 0;
    socklen_t slen = sizeof(sin);
    int afd = -1;

    *cfd = -1;
    *sfd = -1;

    memset ((char *) &sin, 0, sizeof(sin));
    sin.sin_family = AF_INET;
    sin.sin_addr.s_addr = inet_addr(host);

    afd = socket(AF_INET, SOCK_STREAM, 0);
    if (afd < 0)
        return 0;

    if (bind(afd, (struct sockaddr*)&sin, sizeof(sin)) < 0)
        goto out;

    if (getsockname(afd, (struct sockaddr*)&sin, &slen) < 0)
        goto out;

    if (listen(afd, 1) < 0)
        goto out;

    *cfd = socket(AF_INET, SOCK_STREAM, 0);
    if (*cfd < 0)
        goto out;

    if (set_nb(afd) == -1)
        goto out;

    while (*sfd == -1 || !cfd_connected ) {
        *sfd = accept(afd, NULL, 0);
        if (*sfd == -1 && errno != EAGAIN)
            goto out;

        if (!cfd_connected && connect(*cfd, (struct sockaddr*)&sin, sizeof(sin)) < 0)
            goto out;
        else
            cfd_connected = 1;
    }

    if (set_nb(*cfd) == -1 || set_nb(*sfd) == -1)
        goto out;
    ret = 1;
    goto success;

out:
        if (*cfd != -1)
            close(*cfd);
        if (*sfd != -1)
            close(*sfd);
success:
        if (afd != -1)
            close(afd);
    return ret;
}
#else
int create_test_sockets(int *cfd, int *sfd)
{
    return 0;
}
#endif

int create_ssl_objects2(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl,
                          SSL **cssl, int sfd, int cfd)
{
    SSL *serverssl = NULL, *clientssl = NULL;
    BIO *s_to_c_bio = NULL, *c_to_s_bio = NULL;

    if (*sssl != NULL)
        serverssl = *sssl;
    else if (!TEST_ptr(serverssl = SSL_new(serverctx)))
        goto error;
    if (*cssl != NULL)
        clientssl = *cssl;
    else if (!TEST_ptr(clientssl = SSL_new(clientctx)))
        goto error;

    if (!TEST_ptr(s_to_c_bio = BIO_new_socket(sfd, BIO_NOCLOSE))
            || !TEST_ptr(c_to_s_bio = BIO_new_socket(cfd, BIO_NOCLOSE)))
        goto error;

    SSL_set_bio(clientssl, c_to_s_bio, c_to_s_bio);
    SSL_set_bio(serverssl, s_to_c_bio, s_to_c_bio);
    *sssl = serverssl;
    *cssl = clientssl;
    return 1;

 error:
    SSL_free(serverssl);
    SSL_free(clientssl);
    BIO_free(s_to_c_bio);
    BIO_free(c_to_s_bio);
    return 0;
}

/*
 * NOTE: Transfers control of the BIOs - this function will free them on error
 */
+3 −0
Original line number Diff line number Diff line
@@ -19,6 +19,9 @@ int create_ssl_ctx_pair(const SSL_METHOD *sm, const SSL_METHOD *cm,
int create_ssl_objects(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl,
                       SSL **cssl, BIO *s_to_c_fbio, BIO *c_to_s_fbio);
int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want);
int create_ssl_objects2(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl,
                       SSL **cssl, int sfd, int cfd);
int create_test_sockets(int *cfd, int *sfd);
int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want);
void shutdown_ssl_connection(SSL *serverssl, SSL *clientssl);