Commit 4ee7d3f9 authored by Kurt Roeckx's avatar Kurt Roeckx
Browse files

Implement SSL_read_ex() and SSL_write_ex() as documented.



Reviewed-by: default avatarMatt Caswell <matt@openssl.org>
Reviewed-by: default avatarRichard Levitte <levitte@openssl.org>
GH: #1964
parent 2afaee51
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -2244,8 +2244,10 @@ int ERR_load_SSL_strings(void);
# define SSL_F_SSL_PARSE_SERVERHELLO_USE_SRTP_EXT         311
# define SSL_F_SSL_PEEK                                   270
# define SSL_F_SSL_PEEK_EX                                432
# define SSL_F_SSL_PEEK_INTERNAL                          521
# define SSL_F_SSL_READ                                   223
# define SSL_F_SSL_READ_EX                                434
# define SSL_F_SSL_READ_INTERNAL                          519
# define SSL_F_SSL_RENEGOTIATE                            516
# define SSL_F_SSL_SCAN_CLIENTHELLO_TLSEXT                320
# define SSL_F_SSL_SCAN_SERVERHELLO_TLSEXT                321
@@ -2284,6 +2286,7 @@ int ERR_load_SSL_strings(void);
# define SSL_F_SSL_VERIFY_CERT_CHAIN                      207
# define SSL_F_SSL_WRITE                                  208
# define SSL_F_SSL_WRITE_EX                               433
# define SSL_F_SSL_WRITE_INTERNAL                         520
# define SSL_F_STATE_MACHINE                              353
# define SSL_F_TLS12_CHECK_PEER_SIGALG                    333
# define SSL_F_TLS13_CHANGE_CIPHER_STATE                  440
+2 −2
Original line number Diff line number Diff line
@@ -103,7 +103,7 @@ static int ssl_read(BIO *b, char *buf, size_t size, size_t *readbytes)

    BIO_clear_retry_flags(b);

    ret = SSL_read_ex(ssl, buf, size, readbytes);
    ret = ssl_read_internal(ssl, buf, size, readbytes);

    switch (SSL_get_error(ssl, ret)) {
    case SSL_ERROR_NONE:
@@ -172,7 +172,7 @@ static int ssl_write(BIO *b, const char *buf, size_t size, size_t *written)

    BIO_clear_retry_flags(b);

    ret = SSL_write_ex(ssl, buf, size, written);
    ret = ssl_write_internal(ssl, buf, size, written);

    switch (SSL_get_error(ssl, ret)) {
    case SSL_ERROR_NONE:
+3 −0
Original line number Diff line number Diff line
@@ -203,8 +203,10 @@ static ERR_STRING_DATA SSL_str_functs[] = {
     "ssl_parse_serverhello_use_srtp_ext"},
    {ERR_FUNC(SSL_F_SSL_PEEK), "SSL_peek"},
    {ERR_FUNC(SSL_F_SSL_PEEK_EX), "SSL_peek_ex"},
    {ERR_FUNC(SSL_F_SSL_PEEK_INTERNAL), "ssl_peek_internal"},
    {ERR_FUNC(SSL_F_SSL_READ), "SSL_read"},
    {ERR_FUNC(SSL_F_SSL_READ_EX), "SSL_read_ex"},
    {ERR_FUNC(SSL_F_SSL_READ_INTERNAL), "ssl_read_internal"},
    {ERR_FUNC(SSL_F_SSL_RENEGOTIATE), "SSL_renegotiate"},
    {ERR_FUNC(SSL_F_SSL_SCAN_CLIENTHELLO_TLSEXT),
     "ssl_scan_clienthello_tlsext"},
@@ -252,6 +254,7 @@ static ERR_STRING_DATA SSL_str_functs[] = {
    {ERR_FUNC(SSL_F_SSL_VERIFY_CERT_CHAIN), "ssl_verify_cert_chain"},
    {ERR_FUNC(SSL_F_SSL_WRITE), "SSL_write"},
    {ERR_FUNC(SSL_F_SSL_WRITE_EX), "SSL_write_ex"},
    {ERR_FUNC(SSL_F_SSL_WRITE_INTERNAL), "ssl_write_internal"},
    {ERR_FUNC(SSL_F_STATE_MACHINE), "state_machine"},
    {ERR_FUNC(SSL_F_TLS12_CHECK_PEER_SIGALG), "tls12_check_peer_sigalg"},
    {ERR_FUNC(SSL_F_TLS13_CHANGE_CIPHER_STATE), "tls13_change_cipher_state"},
+68 −40
Original line number Diff line number Diff line
@@ -1532,38 +1532,16 @@ static int ssl_io_intern(void *vargs)
    return -1;
}

int SSL_read(SSL *s, void *buf, int num)
{
    int ret;
    size_t readbytes;

    if (num < 0) {
        SSLerr(SSL_F_SSL_READ, SSL_R_BAD_LENGTH);
        return -1;
    }

    ret = SSL_read_ex(s, buf, (size_t)num, &readbytes);

    /*
     * The cast is safe here because ret should be <= INT_MAX because num is
     * <= INT_MAX
     */
    if (ret > 0)
        ret = (int)readbytes;

    return ret;
}

int SSL_read_ex(SSL *s, void *buf, size_t num, size_t *readbytes)
int ssl_read_internal(SSL *s, void *buf, size_t num, size_t *readbytes)
{
    if (s->handshake_func == NULL) {
        SSLerr(SSL_F_SSL_READ_EX, SSL_R_UNINITIALIZED);
        SSLerr(SSL_F_SSL_READ_INTERNAL, SSL_R_UNINITIALIZED);
        return -1;
    }

    if (s->shutdown & SSL_RECEIVED_SHUTDOWN) {
        s->rwstate = SSL_NOTHING;
        return (0);
        return 0;
    }

    if ((s->mode & SSL_MODE_ASYNC) && ASYNC_get_current_job() == NULL) {
@@ -1584,17 +1562,17 @@ int SSL_read_ex(SSL *s, void *buf, size_t num, size_t *readbytes)
    }
}

int SSL_peek(SSL *s, void *buf, int num)
int SSL_read(SSL *s, void *buf, int num)
{
    int ret;
    size_t readbytes;

    if (num < 0) {
        SSLerr(SSL_F_SSL_PEEK, SSL_R_BAD_LENGTH);
        SSLerr(SSL_F_SSL_READ, SSL_R_BAD_LENGTH);
        return -1;
    }

    ret = SSL_peek_ex(s, buf, (size_t)num, &readbytes);
    ret = ssl_read_internal(s, buf, (size_t)num, &readbytes);

    /*
     * The cast is safe here because ret should be <= INT_MAX because num is
@@ -1606,15 +1584,24 @@ int SSL_peek(SSL *s, void *buf, int num)
    return ret;
}

int SSL_peek_ex(SSL *s, void *buf, size_t num, size_t *readbytes)
int SSL_read_ex(SSL *s, void *buf, size_t num, size_t *readbytes)
{
    int ret = ssl_read_internal(s, buf, num, readbytes);

    if (ret < 0)
        ret = 0;
    return ret;
}

static int ssl_peek_internal(SSL *s, void *buf, size_t num, size_t *readbytes)
{
    if (s->handshake_func == NULL) {
        SSLerr(SSL_F_SSL_PEEK_EX, SSL_R_UNINITIALIZED);
        SSLerr(SSL_F_SSL_PEEK_INTERNAL, SSL_R_UNINITIALIZED);
        return -1;
    }

    if (s->shutdown & SSL_RECEIVED_SHUTDOWN) {
        return (0);
        return 0;
    }
    if ((s->mode & SSL_MODE_ASYNC) && ASYNC_get_current_job() == NULL) {
        struct ssl_async_args args;
@@ -1634,39 +1621,49 @@ int SSL_peek_ex(SSL *s, void *buf, size_t num, size_t *readbytes)
    }
}

int SSL_write(SSL *s, const void *buf, int num)
int SSL_peek(SSL *s, void *buf, int num)
{
    int ret;
    size_t written;
    size_t readbytes;

    if (num < 0) {
        SSLerr(SSL_F_SSL_WRITE, SSL_R_BAD_LENGTH);
        SSLerr(SSL_F_SSL_PEEK, SSL_R_BAD_LENGTH);
        return -1;
    }

    ret = SSL_write_ex(s, buf, (size_t)num, &written);
    ret = ssl_peek_internal(s, buf, (size_t)num, &readbytes);

    /*
     * The cast is safe here because ret should be <= INT_MAX because num is
     * <= INT_MAX
     */
    if (ret > 0)
        ret = (int)written;
        ret = (int)readbytes;

    return ret;
}

int SSL_write_ex(SSL *s, const void *buf, size_t num, size_t *written)

int SSL_peek_ex(SSL *s, void *buf, size_t num, size_t *readbytes)
{
    int ret = ssl_peek_internal(s, buf, num, readbytes);

    if (ret < 0)
        ret = 0;
    return ret;
}

int ssl_write_internal(SSL *s, const void *buf, size_t num, size_t *written)
{
    if (s->handshake_func == NULL) {
        SSLerr(SSL_F_SSL_WRITE_EX, SSL_R_UNINITIALIZED);
        SSLerr(SSL_F_SSL_WRITE_INTERNAL, SSL_R_UNINITIALIZED);
        return -1;
    }

    if (s->shutdown & SSL_SENT_SHUTDOWN) {
        s->rwstate = SSL_NOTHING;
        SSLerr(SSL_F_SSL_WRITE_EX, SSL_R_PROTOCOL_IS_SHUTDOWN);
        return (-1);
        SSLerr(SSL_F_SSL_WRITE_INTERNAL, SSL_R_PROTOCOL_IS_SHUTDOWN);
        return -1;
    }

    if ((s->mode & SSL_MODE_ASYNC) && ASYNC_get_current_job() == NULL) {
@@ -1687,6 +1684,37 @@ int SSL_write_ex(SSL *s, const void *buf, size_t num, size_t *written)
    }
}

int SSL_write(SSL *s, const void *buf, int num)
{
    int ret;
    size_t written;

    if (num < 0) {
        SSLerr(SSL_F_SSL_WRITE, SSL_R_BAD_LENGTH);
        return -1;
    }

    ret = ssl_write_internal(s, buf, (size_t)num, &written);

    /*
     * The cast is safe here because ret should be <= INT_MAX because num is
     * <= INT_MAX
     */
    if (ret > 0)
        ret = (int)written;

    return ret;
}

int SSL_write_ex(SSL *s, const void *buf, size_t num, size_t *written)
{
    int ret = ssl_write_internal(s, buf, num, written);

    if (ret < 0)
        ret = 0;
    return ret;
}

int SSL_shutdown(SSL *s)
{
    /*
+3 −1
Original line number Diff line number Diff line
@@ -1979,6 +1979,8 @@ static ossl_inline int ssl_has_cert(const SSL *s, int idx)

# ifndef OPENSSL_UNIT_TEST

__owur int ssl_read_internal(SSL *s, void *buf, size_t num, size_t *readbytes);
__owur int ssl_write_internal(SSL *s, const void *buf, size_t num, size_t *written);
void ssl_clear_cipher_ctx(SSL *s);
int ssl_clear_bad_session(SSL *s);
__owur CERT *ssl_cert_new(void);
@@ -2384,7 +2386,7 @@ void custom_exts_free(custom_ext_methods *exts);

void ssl_comp_free_compression_methods_int(void);

# else
# else /* OPENSSL_UNIT_TEST */

#  define ssl_init_wbio_buffer SSL_test_functions()->p_ssl_init_wbio_buffer
#  define ssl3_setup_buffers SSL_test_functions()->p_ssl3_setup_buffers