Commit 9c941e92 authored by Kamil Dudka's avatar Kamil Dudka
Browse files

nss: propagate blocking direction from NSPR I/O

... during the non-blocking SSL handshake
parent 2e57c7e0
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -1389,7 +1389,7 @@ static CURLcode https_connecting(struct connectdata *conn, bool *done)
#endif

#if defined(USE_SSLEAY) || defined(USE_GNUTLS) || defined(USE_SCHANNEL) || \
    defined(USE_DARWINSSL) || defined(USE_POLARSSL)
    defined(USE_DARWINSSL) || defined(USE_POLARSSL) || defined(USE_NSS)
/* This function is for OpenSSL, GnuTLS, darwinssl, schannel and polarssl only.
   It should be made to query the generic SSL layer instead. */
static int https_getsock(struct connectdata *conn,
+103 −5
Original line number Diff line number Diff line
@@ -180,6 +180,10 @@ static const cipher_s cipherlist[] = {
static const char* pem_library = "libnsspem.so";
SECMODModule* mod = NULL;

/* NSPR I/O layer we use to detect blocking direction during SSL handshake */
static PRDescIdentity nspr_io_identity = PR_INVALID_IO_LAYER;
static PRIOMethods nspr_io_methods;

static const char* nss_error_to_name(PRErrorCode code)
{
  const char *name = PR_ErrorToName(code);
@@ -940,6 +944,60 @@ isTLSIntoleranceError(PRInt32 err)
  }
}

/* update blocking direction in case of PR_WOULD_BLOCK_ERROR */
static void nss_update_connecting_state(ssl_connect_state state, void *secret)
{
  struct ssl_connect_data *connssl = (struct ssl_connect_data *)secret;
  if(PR_GetError() != PR_WOULD_BLOCK_ERROR)
    /* an unrelated error is passing by */
    return;

  switch(connssl->connecting_state) {
  case ssl_connect_2:
  case ssl_connect_2_reading:
  case ssl_connect_2_writing:
    break;
  default:
    /* we are not called from an SSL handshake */
    return;
  }

  /* update the state accordingly */
  connssl->connecting_state = state;
}

/* recv() wrapper we use to detect blocking direction during SSL handshake */
static PRInt32 nspr_io_recv(PRFileDesc *fd, void *buf, PRInt32 amount,
                            PRIntn flags, PRIntervalTime timeout)
{
  const PRRecvFN recv_fn = fd->lower->methods->recv;
  const PRInt32 rv = recv_fn(fd->lower, buf, amount, flags, timeout);
  if(rv < 0)
    /* check for PR_WOULD_BLOCK_ERROR and update blocking direction */
    nss_update_connecting_state(ssl_connect_2_reading, fd->secret);
  return rv;
}

/* send() wrapper we use to detect blocking direction during SSL handshake */
static PRInt32 nspr_io_send(PRFileDesc *fd, const void *buf, PRInt32 amount,
                            PRIntn flags, PRIntervalTime timeout)
{
  const PRSendFN send_fn = fd->lower->methods->send;
  const PRInt32 rv = send_fn(fd->lower, buf, amount, flags, timeout);
  if(rv < 0)
    /* check for PR_WOULD_BLOCK_ERROR and update blocking direction */
    nss_update_connecting_state(ssl_connect_2_writing, fd->secret);
  return rv;
}

/* close() wrapper to avoid assertion failure due to fd->secret != NULL */
static PRStatus nspr_io_close(PRFileDesc *fd)
{
  const PRCloseFN close_fn = PR_GetDefaultIOMethods()->close;
  fd->secret = NULL;
  return close_fn(fd);
}

static CURLcode nss_init_core(struct SessionHandle *data, const char *cert_dir)
{
  NSSInitParameters initparams;
@@ -1004,6 +1062,21 @@ static CURLcode nss_init(struct SessionHandle *data)
    }
  }

  if(nspr_io_identity == PR_INVALID_IO_LAYER) {
    /* allocate an identity for our own NSPR I/O layer */
    nspr_io_identity = PR_GetUniqueIdentity("libcurl");
    if(nspr_io_identity == PR_INVALID_IO_LAYER)
      return CURLE_OUT_OF_MEMORY;

    /* the default methods just call down to the lower I/O layer */
    memcpy(&nspr_io_methods, PR_GetDefaultIOMethods(), sizeof nspr_io_methods);

    /* override certain methods in the table by our wrappers */
    nspr_io_methods.recv  = nspr_io_recv;
    nspr_io_methods.send  = nspr_io_send;
    nspr_io_methods.close = nspr_io_close;
  }

  rv = nss_init_core(data, cert_dir);
  if(rv)
    return rv;
@@ -1353,6 +1426,8 @@ static CURLcode nss_set_nonblock(struct ssl_connect_data *connssl,
static CURLcode nss_setup_connect(struct connectdata *conn, int sockindex)
{
  PRFileDesc *model = NULL;
  PRFileDesc *nspr_io = NULL;
  PRFileDesc *nspr_io_stub = NULL;
  PRBool ssl_no_cache;
  PRBool ssl_cbc_random_iv;
  struct SessionHandle *data = conn->data;
@@ -1525,11 +1600,34 @@ static CURLcode nss_setup_connect(struct connectdata *conn, int sockindex)
    goto error;
  }

  /* Import our model socket  onto the existing file descriptor */
  connssl->handle = PR_ImportTCPSocket(sockfd);
  connssl->handle = SSL_ImportFD(model, connssl->handle);
  if(!connssl->handle)
  /* wrap OS file descriptor by NSPR's file descriptor abstraction */
  nspr_io = PR_ImportTCPSocket(sockfd);
  if(!nspr_io)
    goto error;

  /* create our own NSPR I/O layer */
  nspr_io_stub = PR_CreateIOLayerStub(nspr_io_identity, &nspr_io_methods);
  if(!nspr_io_stub) {
    PR_Close(nspr_io);
    goto error;
  }

  /* make the per-connection data accessible from NSPR I/O callbacks */
  nspr_io_stub->secret = (void *)connssl;

  /* push our new layer to the NSPR I/O stack */
  if(PR_PushIOLayer(nspr_io, PR_TOP_IO_LAYER, nspr_io_stub) != PR_SUCCESS) {
    PR_Close(nspr_io);
    PR_Close(nspr_io_stub);
    goto error;
  }

  /* import our model socket onto the current I/O stack */
  connssl->handle = SSL_ImportFD(model, nspr_io);
  if(!connssl->handle) {
    PR_Close(nspr_io);
    goto error;
  }

  PR_Close(model); /* We don't need this any more */
  model = NULL;
@@ -1612,7 +1710,7 @@ static CURLcode nss_do_connect(struct connectdata *conn, int sockindex)
  timeout = PR_MillisecondsToInterval((PRUint32) time_left);
  if(SSL_ForceHandshakeWithTimeout(connssl->handle, timeout) != SECSuccess) {
    if(PR_GetError() == PR_WOULD_BLOCK_ERROR)
      /* TODO: propagate the blocking direction from the NSPR layer */
      /* blocking direction is updated by nss_update_connecting_state() */
      return CURLE_AGAIN;
    else if(conn->data->set.ssl.certverifyresult == SSL_ERROR_BAD_CERT_DOMAIN)
      curlerr = CURLE_PEER_FAILED_VERIFICATION;