Commit 058e764a authored by Yang Tse's avatar Yang Tse
Browse files

all reads from stdin and writes to stdout will be retried until the

whole operation completes or an unrecoverable condition is detected
parent 0d09f342
Loading
Loading
Loading
Loading
+206 −89
Original line number Diff line number Diff line
@@ -98,6 +98,7 @@ static volatile int sigpipe; /* Why? It's not used */

const char *serverlogfile = (char *)DEFAULT_LOGFILE;

bool verbose = FALSE;
bool use_ipv6 = FALSE;
unsigned short port = DEFAULT_PORT;
unsigned short connectport = 0; /* if non-zero, we activate this mode */
@@ -109,6 +110,118 @@ enum sockmode {
  ACTIVE_DISCONNECT  /* as a client, disconnected from server */
};

/*
 * fullread is a wrapper around the read() function. This will repeat the call
 * to read() until it actually has read the complete number of bytes indicated
 * in nbytes or it fails with a condition that cannot be handled with a simple
 * retry of the read call.
 */

static ssize_t fullread(int filedes, void *buffer, size_t nbytes)
{
  int error;
  ssize_t rc;
  ssize_t nread = 0;

  do {
    rc = read(filedes, (unsigned char *)buffer + nread, nbytes - nread);

    if(rc < 0) {
      error = ERRNO;
      if((error == EINTR) || (error == EAGAIN))
        continue;
      logmsg("unrecoverable read() failure: %s", strerror(error));
      return -1;
    }

    if(rc == 0) {
      logmsg("got 0 reading from stdin");
      return 0;
    }

    nread += rc;

  } while((size_t)nread < nbytes);

  if(verbose)
    logmsg("read %ld bytes", (long)nread);

  return nread;
}

/*
 * fullwrite is a wrapper around the write() function. This will repeat the
 * call to write() until it actually has written the complete number of bytes
 * indicated in nbytes or it fails with a condition that cannot be handled
 * with a simple retry of the write call.
 */

static ssize_t fullwrite(int filedes, const void *buffer, size_t nbytes)
{
  int error;
  ssize_t wc;
  ssize_t nwrite = 0;

  do {
    wc = write(filedes, (unsigned char *)buffer + nwrite, nbytes - nwrite);

    if(wc < 0) {
      error = ERRNO;
      if((error == EINTR) || (error == EAGAIN))
        continue;
      logmsg("unrecoverable write() failure: %s", strerror(error));
      return -1;
    }

    if(wc == 0) {
      logmsg("put 0 writing to stdout");
      return 0;
    }

    nwrite += wc;

  } while((size_t)nwrite < nbytes);

  if(verbose)
    logmsg("wrote %ld bytes", (long)nwrite);

  return nwrite;
}

/*
 * read_stdin tries to read from stdin nbytes into the given buffer. This is a
 * blocking function that will only return TRUE when nbytes have actually been
 * read or FALSE when an unrecoverable error has been detected. Failure of this
 * function is an indication that the whole program should terminate.
 */

static bool read_stdin(void *buffer, size_t nbytes)
{
  ssize_t nread = fullread(fileno(stdin), buffer, nbytes);
  if(nread != (ssize_t)nbytes) {
    logmsg("exiting...");
    return FALSE;
  }
  return TRUE;
}

/*
 * write_stdout tries to write to stdio nbytes from the given buffer. This is a
 * blocking function that will only return TRUE when nbytes have actually been
 * written or FALSE when an unrecoverable error has been detected. Failure of
 * this function is an indication that the whole program should terminate.
 */

static bool write_stdout(const void *buffer, size_t nbytes)
{
  ssize_t nwrite = fullwrite(fileno(stdout), buffer, nbytes);
  if(nwrite != (ssize_t)nbytes) {
    logmsg("exiting...");
    return FALSE;
  }
  return TRUE;
}

static void lograw(unsigned char *buffer, ssize_t len)
{
  char data[120];
@@ -171,7 +284,6 @@ static bool juggle(curl_socket_t *sockfdp,
  curl_socket_t sockfd;
  curl_socket_t maxfd;
  ssize_t rc;
  ssize_t nread_stdin;
  ssize_t nread_socket;
  ssize_t bytes_written;
  ssize_t buffer_len;
@@ -184,8 +296,10 @@ static bool juggle(curl_socket_t *sockfdp,
#ifdef HAVE_GETPPID
  /* As a last resort, quit if sockfilt process becomes orphan. Just in case
     parent ftpserver process has died without killing its sockfilt children */
  if(getppid() <= 1)
  if(getppid() <= 1) {
    logmsg("process becomes orphan, exiting");
    return FALSE;
  }
#endif

  timeout.tv_sec = 120;
@@ -246,17 +360,19 @@ static bool juggle(curl_socket_t *sockfdp,

  } /* switch(*mode) */


  do {
    rc = select(maxfd + 1, &fds_read, &fds_write, &fds_err, &timeout);

    rc = select((int)maxfd + 1, &fds_read, &fds_write, &fds_err, &timeout);

  } while((rc == -1) && (SOCKERRNO == EINTR));

  switch(rc) {
  case -1:
  if(rc < 0)
    return FALSE;

  case 0: /* timeout! */
  if(rc == 0)
    /* timeout */
    return TRUE;
  }


  if(FD_ISSET(fileno(stdin), &fds_read)) {
@@ -274,15 +390,17 @@ static bool juggle(curl_socket_t *sockfdp,

       DATA - plain pass-thru data
    */
    nread_stdin = read(fileno(stdin), buffer, 5);
    if(5 == nread_stdin) {

    if(!read_stdin(buffer, 5))
      return FALSE;

    logmsg("Received %c%c%c%c (on stdin)",
           buffer[0], buffer[1], buffer[2], buffer[3] );

    if(!memcmp("PING", buffer, 4)) {
      /* send reply on stdout, just proving we are alive */
        write(fileno(stdout), "PONG\n", 5);
      if(!write_stdout("PONG\n", 5))
        return FALSE;
    }

    else if(!memcmp("PORT", buffer, 4)) {
@@ -291,8 +409,10 @@ static bool juggle(curl_socket_t *sockfdp,
      sprintf((char *)buffer, "IPv%d/%d\n", use_ipv6?6:4, (int)port);
      buffer_len = (ssize_t)strlen((char *)buffer);
      snprintf(data, sizeof(data), "PORT\n%04x\n", buffer_len);
        write(fileno(stdout), data, 10);
        write(fileno(stdout), buffer, buffer_len);
      if(!write_stdout(data, 10))
        return FALSE;
      if(!write_stdout(buffer, buffer_len))
        return FALSE;
    }
    else if(!memcmp("QUIT", buffer, 4)) {
      /* just die */
@@ -302,8 +422,9 @@ static bool juggle(curl_socket_t *sockfdp,
    else if(!memcmp("DATA", buffer, 4)) {
      /* data IN => data OUT */

        if(5 != read(fileno(stdin), buffer, 5))
      if(!read_stdin(buffer, 5))
        return FALSE;

      buffer[5] = '\0';

      buffer_len = (ssize_t)strtol((char *)buffer, NULL, 16);
@@ -314,26 +435,15 @@ static bool juggle(curl_socket_t *sockfdp,
      }
      logmsg("> %d bytes data, server => client", buffer_len);

        /*
         * To properly support huge data chunks, we need to repeat the call
         * to read() until we're done or it fails.
         */
        nread_stdin = 0;
        do {
          /* get data in the buffer at the correct position */
          rc = read(fileno(stdin), &buffer[nread_stdin],
                    buffer_len - nread_stdin);
          logmsg("read %d bytes", rc);
          if(rc <= 0)
      if(!read_stdin(buffer, buffer_len))
        return FALSE;
          nread_stdin += rc;
        } while (nread_stdin < buffer_len);

      lograw(buffer, buffer_len);

      if(*mode == PASSIVE_LISTEN) {
        logmsg("*** We are disconnected!");
          write(fileno(stdout), "DISC\n", 5);
        if(!write_stdout("DISC\n", 5))
          return FALSE;
      }
      else {
        /* send away on the socket */
@@ -346,7 +456,8 @@ static bool juggle(curl_socket_t *sockfdp,
    }
    else if(!memcmp("DISC", buffer, 4)) {
      /* disconnect! */
        write(fileno(stdout), "DISC\n", 5);
      if(!write_stdout("DISC\n", 5))
        return FALSE;
      if(sockfd != CURL_SOCKET_BAD) {
        logmsg("====> Client forcibly disconnected");
        sclose(sockfd);
@@ -361,11 +472,6 @@ static bool juggle(curl_socket_t *sockfdp,
      return TRUE;
    }
  }
    else if(-1 == nread_stdin) {
      logmsg("read %d from stdin, exiting", nread_stdin);
      return FALSE;
    }
  }


  if((sockfd != CURL_SOCKET_BAD) && (FD_ISSET(sockfd, &fds_read)) ) {
@@ -378,7 +484,8 @@ static bool juggle(curl_socket_t *sockfdp,
        logmsg("accept() failed");
      else {
        logmsg("====> Client connect");
        write(fileno(stdout), "CNCT\n", 5);
        if(!write_stdout("CNCT\n", 5))
          return FALSE;
        *sockfdp = sockfd; /* store the new socket */
        *mode = PASSIVE_CONNECT; /* we have connected */
      }
@@ -390,7 +497,8 @@ static bool juggle(curl_socket_t *sockfdp,

    if(nread_socket <= 0) {
      logmsg("====> Client disconnect");
      write(fileno(stdout), "DISC\n", 5);
      if(!write_stdout("DISC\n", 5))
        return FALSE;
      sclose(sockfd);
      *sockfdp = CURL_SOCKET_BAD;
      if(*mode == PASSIVE_CONNECT)
@@ -401,8 +509,10 @@ static bool juggle(curl_socket_t *sockfdp,
    }

    snprintf(data, sizeof(data), "DATA\n%04x\n", nread_socket);
    write(fileno(stdout), data, 10);
    write(fileno(stdout), buffer, nread_socket);
    if(!write_stdout(data, 10))
      return FALSE;
    if(!write_stdout(buffer, nread_socket))
      return FALSE;

    logmsg("< %d bytes data, client => server", nread_socket);
    lograw(buffer, nread_socket);
@@ -537,6 +647,10 @@ int main(int argc, char *argv[])
             );
      return 0;
    }
    else if(!strcmp("--verbose", argv[arg])) {
      verbose = TRUE;
      arg++;
    }
    else if(!strcmp("--pidfile", argv[arg])) {
      arg++;
      if(argc>arg)
@@ -585,6 +699,7 @@ int main(int argc, char *argv[])
    else {
      puts("Usage: sockfilt [option]\n"
           " --version\n"
           " --verbose\n"
           " --logfile [file]\n"
           " --pidfile [file]\n"
           " --ipv4\n"
@@ -689,7 +804,9 @@ int main(int argc, char *argv[])
  while(juggle(&msgsock, sock, &mode));

  sclose(sock);
  unlink(pidname);

  logmsg("sockfilt exits");
  return 0;
}