On Thu, Jun 26, 2014 at 7:43 AM, Arne Becker <arne_bec...@genua.de> wrote:
> Hi.
>
>> Now soliciting diffs to change readwrite to a loop with two buffers
>> that poll()s in all four directions. :)
>
> Good thing you made me remember I wrote just this a while ago.
> This is my first OpenBSD diff, so tell me if I missed anything obvious.
> Tested quite extensively originally; for this diff I only checked a
> simple nc to nc "hello".
>
> Index: netcat.c
> ===================================================================
> RCS file: /mount/cvsdev/cvs/openbsd/src/usr.bin/nc/netcat.c,v
> retrieving revision 1.121
> diff -u -p -r1.121 netcat.c
> --- netcat.c    10 Jun 2014 16:35:42 -0000      1.121
> +++ netcat.c    26 Jun 2014 11:29:45 -0000
> @@ -65,6 +65,12 @@
>  #define PORT_MAX_LEN   6
>  #define UNIX_DG_TMP_SOCKET_SIZE        19
>
> +#define POLL_STDIN 0
> +#define POLL_NETOUT 1
> +#define POLL_NETIN 2
> +#define POLL_STDOUT 3
> +#define BUFSIZE 16384
> +
>  /* Command Line Options */
>  int    dflag;                                  /* detached, no stdin */
>  int    Fflag;                                  /* fdpass sock to stdout */
> @@ -112,6 +118,8 @@ void        set_common_sockopts(int);
>  int    map_tos(char *, int *);
>  void   report_connect(const struct sockaddr *, socklen_t);
>  void   usage(int);
> +ssize_t drainbuf(int, unsigned char *, size_t *);
> +ssize_t fillbuf(int, unsigned char *, size_t *);
>
>  int
>  main(int argc, char *argv[])
> @@ -608,7 +616,7 @@ remote_connect(const char *host, const c
>
>                         if (bind(s, (struct sockaddr *)ares->ai_addr,
>                             ares->ai_addrlen) < 0)
> -                               err(1, "bind failed");
> +                               errx(1, "bind failed: %s", strerror(errno));
>                         freeaddrinfo(ares);
>                 }
>
> @@ -640,7 +648,7 @@ timeout_connect(int s, const struct sock
>         if (timeout != -1) {
>                 flags = fcntl(s, F_GETFL, 0);
>                 if (fcntl(s, F_SETFL, flags | O_NONBLOCK) == -1)
> -                       err(1, "set non-blocking mode");
> +                       warn("unable to set non-blocking mode");
>         }
>
>         if ((ret = connect(s, name, namelen)) != 0 && errno == EINPROGRESS) {
> @@ -730,67 +738,229 @@ local_listen(char *host, char *port, str
>   * Loop that polls on the network file descriptor and stdin.
>   */
>  void
> -readwrite(int nfd)
> +readwrite(int net_fd)
>  {
> -       struct pollfd pfd[2];
> -       unsigned char buf[16 * 1024];
> -       int n, wfd = fileno(stdin);
> -       int lfd = fileno(stdout);
> -       int plen;
> -
> -       plen = sizeof(buf);
> -
> -       /* Setup Network FD */
> -       pfd[0].fd = nfd;
> -       pfd[0].events = POLLIN;
> +       struct pollfd pfd[4];
> +       int stdin_fd = STDIN_FILENO;
> +       int stdout_fd = STDOUT_FILENO;
> +       unsigned char netinbuf[BUFSIZE];
> +       size_t netinbufpos = 0;
> +       unsigned char stdinbuf[BUFSIZE];
> +       size_t stdinbufpos = 0;
> +       int n, num_fds, flags;
> +       ssize_t ret;
> +
> +       /* don't read from stdin if requested */
> +       if (dflag)
> +               stdin_fd = -1;
> +
> +       /* stdin */
> +       pfd[POLL_STDIN].fd = stdin_fd;
> +       pfd[POLL_STDIN].events = POLLIN;
> +
> +       /* network out */
> +       pfd[POLL_NETOUT].fd = net_fd;
> +       pfd[POLL_NETOUT].events = 0;
> +
> +       /* network in */
> +       pfd[POLL_NETIN].fd = net_fd;
> +       pfd[POLL_NETIN].events = POLLIN;
> +
> +       /* stdout */
> +       pfd[POLL_STDOUT].fd = stdout_fd;
> +       pfd[POLL_STDOUT].events = 0;
> +
> +
> +       /* make all fds non-blocking */
> +       for (n = 0; n < 4; n++) {
> +               if (pfd[n].fd == -1)
> +                       continue;
> +               flags = fcntl(pfd[n].fd, F_GETFL, 0);
> +               /*
> +                * For sockets and pipes, we want non-block, but setting it
> +                * might fail for files or devices, so we ignore the return
> +                * code.
> +                */
> +               fcntl(pfd[n].fd, F_SETFL, flags | O_NONBLOCK);
> +       }
>
> -       /* Set up STDIN FD. */
> -       pfd[1].fd = wfd;
> -       pfd[1].events = POLLIN;
> +       while (1) {
> +               /* both inputs are gone, buffers are empty, we are done */
> +               if (pfd[POLL_STDIN].fd == -1 && pfd[POLL_NETIN].fd == -1
> +                   && stdinbufpos == 0 && netinbufpos == 0) {
> +                       close(net_fd);
> +                       return;
> +               }
> +               /* both outputs are gone, we can't continue */
> +               if (pfd[POLL_NETOUT].fd == -1 && pfd[POLL_STDOUT].fd == -1) {
> +                       close(net_fd);
> +                       return;
> +               }
> +               /* listen and net in gone, queues empty, done */
> +               if (lflag && pfd[POLL_NETIN].fd == -1

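lflag???
Warning: this is the only reference to lflag in the diff. Why should
listen mode return as soon as the network input is gone and both
buffers are empty, even while stdin is still open?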

> +                   && stdinbufpos == 0 && netinbufpos == 0) {
> +                       close(net_fd);
> +                       return;
> +               }
>
> -       while (pfd[0].fd != -1) {
> +               /* help says -i is for "wait between lines sent". We read and
> +                * write arbitrary amounts of data, and we don't want to start
> +                * scanning for newlines, so this is as good as it gets */
>                 if (iflag)
>                         sleep(iflag);
>
> -               if ((n = poll(pfd, 2 - dflag, timeout)) < 0) {
> -                       close(nfd);
> -                       err(1, "Polling Error");
> +               /* poll */
> +               num_fds = poll(pfd, 4, timeout);
> +
> +               /* treat poll errors */
> +               if (num_fds == -1) {
> +                       close(net_fd);
> +                       err(1, "polling error");
>                 }
>
> -               if (n == 0)
> +               /* timeout happened */
> +               if (num_fds == 0)
>                         return;
>
> -               if (pfd[0].revents & POLLIN) {
> -                       if ((n = read(nfd, buf, plen)) < 0)
> -                               return;
> -                       else if (n == 0) {
> -                               shutdown(nfd, SHUT_RD);
> -                               pfd[0].fd = -1;
> -                               pfd[0].events = 0;
> -                       } else {
> -                               if (tflag)
> -                                       atelnet(nfd, buf, n);
> -                               if (atomicio(vwrite, lfd, buf, n) != n)
> -                                       return;
> +               /* treat socket error conditions */
> +               for (n = 0; n < 4; n++) {
> +                       if (pfd[n].revents & (POLLERR|POLLNVAL)) {
> +                               pfd[n].fd = -1;
> +                               continue;
>                         }
>                 }
> +               /* only treat the fds used for writing, because
> +                * reading is always possible even after HUP */
> +               if (pfd[POLL_NETOUT].revents & POLLHUP) {
> +                       if (Nflag)
> +                               shutdown(pfd[POLL_NETOUT].fd, SHUT_WR);
> +                       pfd[POLL_NETOUT].fd = -1;
> +               }
> +               /* if HUP, stop watching stdout */
> +               if (pfd[POLL_STDOUT].revents & POLLHUP)
> +                       pfd[POLL_STDOUT].fd = -1;
> +               /* if no net out, stop watching stdin */
> +               if (pfd[POLL_NETOUT].fd == -1)
> +                       pfd[POLL_STDIN].fd = -1;
> +               /* if no stdout, stop watching net in */
> +               if (pfd[POLL_STDOUT].fd == -1) {
> +                       if (pfd[POLL_NETIN].fd != -1)
> +                               shutdown(pfd[POLL_NETIN].fd, SHUT_RD);
> +                       pfd[POLL_NETIN].fd = -1;
> +               }
>
> -               if (!dflag && pfd[1].revents & POLLIN) {
> -                       if ((n = read(wfd, buf, plen)) < 0)
> -                               return;
> -                       else if (n == 0) {
> -                               if (Nflag)
> -                                       shutdown(nfd, SHUT_WR);
> -                               pfd[1].fd = -1;
> -                               pfd[1].events = 0;
> -                       } else {
> -                               if (atomicio(vwrite, nfd, buf, n) != n)
> -                                       return;
> +               /* try to read from stdin */
> +               if (pfd[POLL_STDIN].revents & POLLIN && stdinbufpos < 
> BUFSIZE) {
> +                       ret = fillbuf(pfd[POLL_STDIN].fd, stdinbuf,
> +                           &stdinbufpos);
> +                       /* error or eof on stdin - remove from pfd */
> +                       if (ret == 0 || ret == -1)
> +                               pfd[POLL_STDIN].fd = -1;
> +                       /* read something - poll net out */
> +                       if (stdinbufpos > 0)
> +                               pfd[POLL_NETOUT].events = POLLOUT;
> +                       /* filled buffer - remove self from polling */
> +                       if (stdinbufpos == BUFSIZE)
> +                               pfd[POLL_STDIN].events = 0;
> +               }
> +               /* try to write to network */
> +               if (pfd[POLL_NETOUT].revents & POLLOUT && stdinbufpos > 0) {
> +                       ret = drainbuf(pfd[POLL_NETOUT].fd, stdinbuf,
> +                           &stdinbufpos);
> +                       if (ret == -1)
> +                               pfd[POLL_NETOUT].fd = -1;
> +                       /* buffer empty - remove self from polling */
> +                       if (stdinbufpos == 0)
> +                               pfd[POLL_NETOUT].events = 0;
> +                       /* buffer no longer full - poll stdin again */
> +                       if (stdinbufpos < BUFSIZE)
> +                               pfd[POLL_STDIN].events = POLLIN;
> +               }
> +               /* try to read from network */
> +               if (pfd[POLL_NETIN].revents & POLLIN && netinbufpos < 
> BUFSIZE) {
> +                       ret = fillbuf(pfd[POLL_NETIN].fd, netinbuf,
> +                           &netinbufpos);
> +                       if (ret == -1)
> +                               pfd[POLL_NETIN].fd = -1;
> +                       /* eof on net in - remove from pfd */
> +                       if (ret == 0) {
> +                               shutdown(pfd[POLL_NETIN].fd, SHUT_RD);
> +                               pfd[POLL_NETIN].fd = -1;
>                         }
> +                       /* read something - poll stdout */
> +                       if (netinbufpos > 0)
> +                               pfd[POLL_STDOUT].events = POLLOUT;
> +                       /* filled buffer - remove self from polling */
> +                       if (netinbufpos == BUFSIZE)
> +                               pfd[POLL_NETIN].events = 0;
> +                       /* handle telnet */
> +                       if (tflag)
> +                               atelnet(pfd[POLL_NETIN].fd, netinbuf,
> +                                   netinbufpos);
> +               }
> +               /* try to write to stdout */
> +               if (pfd[POLL_STDOUT].revents & POLLOUT && netinbufpos > 0) {
> +                       ret = drainbuf(pfd[POLL_STDOUT].fd, netinbuf,
> +                           &netinbufpos);
> +                       if (ret == -1)
> +                               pfd[POLL_STDOUT].fd = -1;
> +                       /* buffer empty - remove self from polling */
> +                       if (netinbufpos == 0)
> +                               pfd[POLL_STDOUT].events = 0;
> +                       /* buffer no longer full - poll net in again */
> +                       if (netinbufpos < BUFSIZE)
> +                               pfd[POLL_NETIN].events = POLLIN;
> +               }
> +
> +               /* stdin gone and queue empty? */
> +               if (pfd[POLL_STDIN].fd == -1 && stdinbufpos == 0) {
> +                       if (pfd[POLL_NETOUT].fd != -1 && Nflag)
> +                               shutdown(pfd[POLL_NETOUT].fd, SHUT_WR);
> +                       pfd[POLL_NETOUT].fd = -1;
>                 }
> +               /* net in gone and queue empty? */
> +               if (pfd[POLL_NETIN].fd == -1 && netinbufpos == 0)
> +                       pfd[POLL_STDOUT].fd = -1;
>         }
>  }
>
> +/*
> + * drainbuf()
> + * Write the *bufpos bytes queued at the start of buf to fd using a
> + * single write(2).  Bytes that were written are dropped from the front
> + * of buf (the unwritten remainder is moved down) and *bufpos is reduced
> + * by the amount written.
> + * Returns the number of bytes written; -2 if write(2) failed with
> + * EAGAIN or EINTR (not treated as an error); otherwise the raw write(2)
> + * result (-1 with errno set, or 0).  buf and *bufpos are only modified
> + * when at least one byte was written.
> + */
> +ssize_t
> +drainbuf(int fd, unsigned char *buf, size_t *bufpos)
> +{
> +       ssize_t n;
> +       ssize_t adjust;
> +
> +       n = write(fd, buf, *bufpos);
> +       /* don't treat EAGAIN, EINTR as error */
> +       if (n == -1 && (errno == EAGAIN || errno == EINTR))
> +               n = -2;
> +       if (n <= 0)
> +               return n;
> +       /* adjust buffer: move the unwritten tail to the start of buf */
> +       adjust = *bufpos - n;
> +       if (adjust > 0)
> +               memmove(buf, buf + n, adjust);
> +       *bufpos -= n;
> +       return n;
> +}
> +
> +
> +/*
> + * fillbuf()
> + * Read from fd into the free space of buf (the BUFSIZE - *bufpos bytes
> + * following the *bufpos bytes already queued) using a single read(2),
> + * and advance *bufpos by the number of bytes read.  buf is expected to
> + * be BUFSIZE bytes long.
> + * Returns the number of bytes read, 0 on end of file, -1 on error
> + * (errno set), or -2 if read(2) failed with EAGAIN or EINTR (not
> + * treated as an error).
> + * Note: with a full buffer, read(2) would be asked for 0 bytes and its
> + * 0 result would look like EOF, so only call this while
> + * *bufpos < BUFSIZE (as the callers in readwrite() do).
> + */
> +ssize_t
> +fillbuf(int fd, unsigned char *buf, size_t *bufpos)
> +{
> +       size_t num = BUFSIZE - *bufpos;
> +       ssize_t n;
> +
> +       /* append after the bytes already queued in buf */
> +       n = read(fd, buf + *bufpos, num);
> +       /* don't treat EAGAIN, EINTR as error */
> +       if (n == -1 && (errno == EAGAIN || errno == EINTR))
> +               n = -2;
> +       if (n <= 0)
> +               return n;
> +       *bufpos += n;
> +       return n;
> +}
> +
>  /*
>   * fdpass()
>   * Pass the connected file descriptor to stdout and exit.
> @@ -857,6 +1027,8 @@ atelnet(int nfd, unsigned char *buf, uns
>  {
>         unsigned char *p, *end;
>         unsigned char obuf[4];
> +       int flags;
> +       int blocking = 0;
>
>         if (size < 3)
>                 return;
> @@ -877,8 +1049,20 @@ atelnet(int nfd, unsigned char *buf, uns
>
>                 p++;
>                 obuf[2] = *p;
> +
> +               if (!blocking) {
> +                       flags = fcntl(nfd, F_GETFL, 0);
> +                       if (fcntl(nfd, F_SETFL, flags & ~O_NONBLOCK) == -1)
> +                               warn("unable to set blocking mode");
> +                       blocking = 1;
> +               }
>                 if (atomicio(vwrite, nfd, obuf, 3) != 3)
>                         warn("Write Error!");
> +       }
> +       if (blocking) {
> +               flags = fcntl(nfd, F_GETFL, 0);
> +               if (fcntl(nfd, F_SETFL, flags | O_NONBLOCK) == -1)
> +                       warn("unable to set non-blocking mode");
>         }
>  }
>



-- 
---------------------------------------------------------------------------------------------------------------------
() ascii ribbon campaign - against html e-mail
/\

Reply via email to