easy-tls.c 31.3 KB
Newer Older
	    }
	}
	
	if (clear_to_tls.len != 0 && !in_handshake) {
	    assert(!closed);
	    
	    r = tls_write_attempt(ssl, &clear_to_tls, &tls_write_select, &tls_read_select, &closed, &progress, &err_pref_1);
	    if (r != 0)
		goto err;
	    if (closed) {
		assert(progress);
		tls_to_clear.offset = 0;
		tls_to_clear.len = 0;
	    }
	}
	
	if (tls_to_clear.len != 0) {
	    assert(!closed);

	    r = write_attempt(clear_fd, &tls_to_clear, &clear_write_select, &closed, &progress);
	    if (r != 0)
		goto err_return;
	    if (closed) {
		assert(progress);
		clear_to_tls.offset = 0;
		clear_to_tls.len = 0;
	    }
	}
	
	if (!closed) {
	    if (clear_to_tls.offset + clear_to_tls.len < sizeof clear_to_tls.buf) {
		r = read_attempt(clear_fd, &clear_to_tls, &clear_read_select, &closed, &progress);
		if (r != 0)
		    goto err_return;
		if (closed) {
		    r = SSL_shutdown(ssl);
		    DEBUG_MSG2("SSL_shutdown", r);
		}
	    }
	}
	
	if (!closed && !in_handshake) {
	    if (tls_to_clear.offset + tls_to_clear.len < sizeof tls_to_clear.buf) {
		r = tls_read_attempt(ssl, &tls_to_clear, &tls_write_select, &tls_read_select, &closed, &progress, &err_pref_1);
		if (r != 0)
		    goto err;
		if (closed) {
		    r = SSL_shutdown(ssl);
		    DEBUG_MSG2("SSL_shutdown", r);
		}
	    }
	}

	if (!progress) {
	    DEBUG_MSG("!progress?");
	    if (num_read != BIO_number_read(rbio) || num_written != BIO_number_written(wbio))
		progress = 1;

	    if (!progress) {
		DEBUG_MSG("!progress");
		assert(clear_read_select || tls_read_select || clear_write_select || tls_write_select);
		tls_sockets_select(clear_read_select ? clear_fd : -1, tls_read_select ? tls_fd : -1, clear_write_select ? clear_fd : -1, tls_write_select ? tls_fd : -1, -1);
	    }
	}
    } while (!closed);
    return;

 err:
    tls_openssl_errors(err_pref_1, err_pref_2, err_def, tls_child_apparg);
 err_return:
    return;
}


/*
 * Translate the result of an SSL I/O call into the tunnel loop's state flags.
 *
 *   ssl           connection the call was made on
 *   r             return value of that call
 *   write_select  set to 1 when OpenSSL reports SSL_ERROR_WANT_WRITE
 *   read_select   set to 1 when OpenSSL reports SSL_ERROR_WANT_READ
 *   closed        set to 1 when OpenSSL reports SSL_ERROR_ZERO_RETURN
 *   progress      set to 1 on success and on SSL_ERROR_ZERO_RETURN
 *
 * Returns 0 for every condition listed above and -1 for any other
 * SSL_get_error() code.  Flags are only ever set, never cleared.
 */
static int
tls_get_error(SSL *ssl, int r, int *write_select, int *read_select, int *closed, int *progress)
{
    const int code = SSL_get_error(ssl, r);

    /* Successful operation: the call itself must have returned > 0. */
    if (code == SSL_ERROR_NONE) {
        assert(r > 0);
        *progress = 1;
        return 0;
    }

    /* Every remaining code goes with a non-positive return value. */
    assert(r <= 0);

    if (code == SSL_ERROR_ZERO_RETURN) {
        assert(r == 0);
        *closed = 1;
        *progress = 1;
        return 0;
    }

    if (code == SSL_ERROR_WANT_WRITE) {
        *write_select = 1;
        return 0;
    }

    if (code == SSL_ERROR_WANT_READ) {
        *read_select = 1;
        return 0;
    }

    /* Anything else is reported to the caller as a failure. */
    return -1;
}

/*
 * Run one SSL_connect() step and classify its outcome via tls_get_error().
 *
 * The select/closed/progress flags are updated by tls_get_error().  When
 * that reports failure (-1), *err_pref is pointed at a string naming the
 * failing call.  Returns whatever tls_get_error() returned.
 */
static int
tls_connect_attempt(SSL *ssl, int *write_select, int *read_select, int *closed, int *progress, const char **err_pref)
{
    int ssl_ret;
    int status;

    DEBUG_MSG("tls_connect_attempt");
    ssl_ret = SSL_connect(ssl);
    DEBUG_MSG2("SSL_connect",ssl_ret);

    status = tls_get_error(ssl, ssl_ret, write_select, read_select, closed, progress);
    if (status != -1) {
        return status;
    }

    *err_pref = " during SSL_connect";
    return -1;
}

/*
 * Run one SSL_accept() step and classify its outcome via tls_get_error().
 *
 * The select/closed/progress flags are updated by tls_get_error().  When
 * that reports failure (-1), *err_pref is pointed at a string naming the
 * failing call.  Returns whatever tls_get_error() returned.
 */
static int
tls_accept_attempt(SSL *ssl, int *write_select, int *read_select, int *closed, int *progress, const char **err_pref)
{
    int ssl_ret;
    int status;

    DEBUG_MSG("tls_accept_attempt");
    ssl_ret = SSL_accept(ssl);
    DEBUG_MSG2("SSL_accept",ssl_ret);

    status = tls_get_error(ssl, ssl_ret, write_select, read_select, closed, progress);
    if (status != -1) {
        return status;
    }

    *err_pref = " during SSL_accept";
    return -1;
}

/*
 * Try to push the pending bytes of `buf` into the TLS connection.
 *
 * Passes the region starting at buf->offset, buf->len bytes long, to
 * SSL_write() and classifies the result with tls_get_error().  Bytes that
 * were accepted are removed from the front of the pending region; once the
 * region is empty the offset is reset to 0.
 *
 * On failure (-1 from tls_get_error()) *err_pref is pointed at a string
 * naming the failing call.  Returns whatever tls_get_error() returned.
 */
static int
tls_write_attempt(SSL *ssl, struct tunnelbuf *buf, int *write_select, int *read_select, int *closed, int *progress, const char **err_pref)
{
    int written;
    int status;

    DEBUG_MSG("tls_write_attempt");
    written = SSL_write(ssl, buf->buf + buf->offset, buf->len);
    DEBUG_MSG2("SSL_write",written);
    status = tls_get_error(ssl, written, write_select, read_select, closed, progress);

    if (written > 0) {
        /* Consume what was written; rewind once nothing is left. */
        buf->len -= written;
        assert(buf->len >= 0);
        buf->offset = (buf->len == 0) ? 0 : buf->offset + written;
    }

    if (status == -1) {
        *err_pref = " during SSL_write";
    }
    return status;
}

/*
 * Try to read decrypted data from the TLS connection into `buf`.
 *
 * New bytes are stored directly after the data already held (at
 * buf->offset + buf->len); at most the space remaining in buf->buf is
 * requested.  The buffer must not be full on entry.  The result of
 * SSL_read() is classified with tls_get_error().
 *
 * On failure (-1 from tls_get_error()) *err_pref is pointed at a string
 * naming the failing call.  Returns whatever tls_get_error() returned.
 */
static int
tls_read_attempt(SSL *ssl, struct tunnelbuf *buf, int *write_select, int *read_select, int *closed, int *progress, const char **err_pref)
{
    size_t used;
    int got;
    int status;

    DEBUG_MSG("tls_read_attempt");
    used = buf->offset + buf->len;
    assert(used < sizeof buf->buf);

    got = SSL_read(ssl, buf->buf + used, (sizeof buf->buf) - used);
    DEBUG_MSG2("SSL_read",got);
    status = tls_get_error(ssl, got, write_select, read_select, closed, progress);

    if (got > 0) {
        /* Account for the freshly received bytes. */
        buf->len += got;
        assert(buf->offset + buf->len <= sizeof buf->buf);
    }

    if (status == -1) {
        *err_pref = " during SSL_read";
    }
    return status;
}

/*
 * Classify the return value of a read()/write() style call.
 *
 *   r         return value of the call; when it is -1, errno must still
 *             hold the value set by that call
 *   select    set to 1 when the call should simply be retried once the
 *             descriptor is ready: it would have blocked
 *             (EAGAIN/EWOULDBLOCK) or was interrupted by a signal before
 *             transferring data (EINTR)
 *   closed    set to 1 when r == 0 or the error is EPIPE
 *   progress  set to 1 when r >= 0 or the error is EPIPE
 *
 * Returns 0 when the result was handled as described above, -1 for any
 * other error; errno is not modified.  Flags are only set, never cleared.
 */
static int
get_error(int r, int *select, int *closed, int *progress)
{
    if (r >= 0) {
        *progress = 1;
        if (r == 0)
            *closed = 1;
        return 0;
    }

    assert(r == -1);

    /*
     * EINTR is not a failure of the descriptor: the call was merely
     * interrupted.  Treat it like "would block" so the caller waits for
     * readiness and tries again instead of tearing the tunnel down.
     */
    if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
        *select = 1;
        return 0;
    }

    if (errno == EPIPE) {
        /* Peer is gone: report it as a close rather than as an error. */
        *progress = 1;
        *closed = 1;
        return 0;
    }

    return -1;
}

static int write_attempt(int fd, struct tunnelbuf *buf, int *select, int *closed, int *progress)
{
    int n, r;

    DEBUG_MSG("write_attempt");
    n = write(fd, buf->buf + buf->offset, buf->len);
    DEBUG_MSG2("write",n);
    r = get_error(n, select, closed, progress);
    if (n > 0) {
	buf->len -= n;
	assert(buf->len >= 0);
	if (buf->len == 0)
	    buf->offset = 0;
	else
	    buf->offset += n;
    }
    if (r == -1)
	tls_errprintf(1, tls_child_apparg, "write error: %s\n", strerror(errno));
    return r;
}
    
static int
read_attempt(int fd, struct tunnelbuf *buf, int *select, int *closed, int *progress)
{
    int n, r;
    size_t total;

    DEBUG_MSG("read_attempt");
    total = buf->offset + buf->len;
    assert(total < sizeof buf->buf);
    n = read(fd, buf->buf + total, (sizeof buf->buf) - total);
    DEBUG_MSG2("read",n);
    r = get_error(n, select, closed, progress);
    if (n > 0) {
	buf->len += n;
	assert(buf->offset + buf->len <= sizeof buf->buf);
    }
    if (r == -1)
	tls_errprintf(1, tls_child_apparg, "read error: %s\n", strerror(errno));
    return r;
}