Hi.

> Now soliciting diffs to change readwrite to a loop with two buffers
> that poll()s in all four directions. :)

Good thing you reminded me: I wrote exactly this a while ago.
This is my first OpenBSD diff, so please tell me if I missed anything obvious.
I tested it quite extensively when I originally wrote it; for this version
of the diff I only checked a simple nc-to-nc "hello".

Index: netcat.c
===================================================================
RCS file: /mount/cvsdev/cvs/openbsd/src/usr.bin/nc/netcat.c,v
retrieving revision 1.121
diff -u -p -r1.121 netcat.c
--- netcat.c    10 Jun 2014 16:35:42 -0000      1.121
+++ netcat.c    26 Jun 2014 11:29:45 -0000
@@ -65,6 +65,12 @@
 #define PORT_MAX_LEN   6
 #define UNIX_DG_TMP_SOCKET_SIZE        19

+#define POLL_STDIN 0
+#define POLL_NETOUT 1
+#define POLL_NETIN 2
+#define POLL_STDOUT 3
+#define BUFSIZE 16384
+
 /* Command Line Options */
 int    dflag;                                  /* detached, no stdin */
 int    Fflag;                                  /* fdpass sock to stdout */
@@ -112,6 +118,8 @@ void        set_common_sockopts(int);
 int    map_tos(char *, int *);
 void   report_connect(const struct sockaddr *, socklen_t);
 void   usage(int);
+ssize_t drainbuf(int, unsigned char *, size_t *);
+ssize_t fillbuf(int, unsigned char *, size_t *);

 int
 main(int argc, char *argv[])
@@ -608,7 +616,7 @@ remote_connect(const char *host, const c

                        if (bind(s, (struct sockaddr *)ares->ai_addr,
                            ares->ai_addrlen) < 0)
-                               err(1, "bind failed");
+                               errx(1, "bind failed: %s", strerror(errno));
                        freeaddrinfo(ares);
                }

@@ -640,7 +648,7 @@ timeout_connect(int s, const struct sock
        if (timeout != -1) {
                flags = fcntl(s, F_GETFL, 0);
                if (fcntl(s, F_SETFL, flags | O_NONBLOCK) == -1)
-                       err(1, "set non-blocking mode");
+                       warn("unable to set non-blocking mode");
        }

        if ((ret = connect(s, name, namelen)) != 0 && errno == EINPROGRESS) {
@@ -730,67 +738,271 @@ local_listen(char *host, char *port, str
  * Loop that polls on the network file descriptor and stdin.
  */
 void
-readwrite(int nfd)
+readwrite(int net_fd)
 {
-       struct pollfd pfd[2];
-       unsigned char buf[16 * 1024];
-       int n, wfd = fileno(stdin);
-       int lfd = fileno(stdout);
-       int plen;
-
-       plen = sizeof(buf);
-
-       /* Setup Network FD */
-       pfd[0].fd = nfd;
-       pfd[0].events = POLLIN;
+       /*
+        * Two buffers: stdinbuf carries stdin -> network and netinbuf carries
+        * network -> stdout.  Inputs are polled for POLLIN while their buffer
+        * has room, outputs for POLLOUT while their buffer holds data.
+        */
+       struct pollfd pfd[4];
+       int stdin_fd = STDIN_FILENO;
+       int stdout_fd = STDOUT_FILENO;
+       unsigned char netinbuf[BUFSIZE];
+       size_t netinbufpos = 0;
+       unsigned char stdinbuf[BUFSIZE];
+       size_t stdinbufpos = 0;
+       int n, num_fds, flags;
+       ssize_t ret;
+
+       /* don't read from stdin if requested */
+       if (dflag)
+               stdin_fd = -1;
+
+       /* stdin */
+       pfd[POLL_STDIN].fd = stdin_fd;
+       pfd[POLL_STDIN].events = POLLIN;
+
+       /* network out */
+       pfd[POLL_NETOUT].fd = net_fd;
+       pfd[POLL_NETOUT].events = 0;
+
+       /* network in */
+       pfd[POLL_NETIN].fd = net_fd;
+       pfd[POLL_NETIN].events = POLLIN;
+
+       /* stdout */
+       pfd[POLL_STDOUT].fd = stdout_fd;
+       pfd[POLL_STDOUT].events = 0;
+
+
+       /* make all fds non-blocking */
+       for (n = 0; n < 4; n++) {
+               if (pfd[n].fd == -1)
+                       continue;
+               flags = fcntl(pfd[n].fd, F_GETFL, 0);
+               /*
+                * For sockets and pipes, we want non-block, but setting it
+                * might fail for files or devices, so we ignore the return
+                * code.
+                * NOTE(review): O_NONBLOCK is a property of the open file
+                * description, which stdin/stdout may share with other
+                * processes (e.g. the invoking shell); it is not restored
+                * before readwrite() returns.
+                */
+               fcntl(pfd[n].fd, F_SETFL, flags | O_NONBLOCK);
+       }

-       /* Set up STDIN FD. */
-       pfd[1].fd = wfd;
-       pfd[1].events = POLLIN;
+       while (1) {
+               /* both inputs are gone, buffers are empty, we are done */
+               if (pfd[POLL_STDIN].fd == -1 && pfd[POLL_NETIN].fd == -1
+                   && stdinbufpos == 0 && netinbufpos == 0) {
+                       close(net_fd);
+                       return;
+               }
+               /* both outputs are gone, we can't continue */
+               if (pfd[POLL_NETOUT].fd == -1 && pfd[POLL_STDOUT].fd == -1) {
+                       close(net_fd);
+                       return;
+               }
+               /* listen and net in gone, queues empty, done */
+               if (lflag && pfd[POLL_NETIN].fd == -1
+                   && stdinbufpos == 0 && netinbufpos == 0) {
+                       close(net_fd);
+                       return;
+               }

-       while (pfd[0].fd != -1) {
+               /*
+                * help says -i is for "wait between lines sent". We read and
+                * write arbitrary amounts of data, and we don't want to start
+                * scanning for newlines, so this is as good as it gets.
+                */
                if (iflag)
                        sleep(iflag);

-               if ((n = poll(pfd, 2 - dflag, timeout)) < 0) {
-                       close(nfd);
-                       err(1, "Polling Error");
+               /* poll */
+               num_fds = poll(pfd, 4, timeout);
+
+               /* treat poll errors */
+               if (num_fds == -1) {
+                       close(net_fd);
+                       err(1, "polling error");
                }

-               if (n == 0)
+               /* timeout happened */
+               if (num_fds == 0)
                        return;

-               if (pfd[0].revents & POLLIN) {
-                       if ((n = read(nfd, buf, plen)) < 0)
-                               return;
-                       else if (n == 0) {
-                               shutdown(nfd, SHUT_RD);
-                               pfd[0].fd = -1;
-                               pfd[0].events = 0;
-                       } else {
-                               if (tflag)
-                                       atelnet(nfd, buf, n);
-                               if (atomicio(vwrite, lfd, buf, n) != n)
-                                       return;
+               /* treat socket error conditions */
+               for (n = 0; n < 4; n++) {
+                       if (pfd[n].revents & (POLLERR|POLLNVAL)) {
+                               pfd[n].fd = -1;
                        }
                }
+               /*
+                * Reading is still possible after HUP, so an input is only
+                * dropped on HUP once there is nothing left to read: some
+                * systems report POLLHUP without POLLIN at end-of-file (e.g.
+                * on a pipe whose writer went away), and without this check
+                * poll() would keep returning immediately forever.  Only do
+                * this while POLLIN is requested, so a full buffer does not
+                * cause pending input to be discarded.
+                */
+               if (pfd[POLL_STDIN].events & POLLIN &&
+                   pfd[POLL_STDIN].revents & POLLHUP &&
+                   !(pfd[POLL_STDIN].revents & POLLIN))
+                       pfd[POLL_STDIN].fd = -1;
+               if (pfd[POLL_NETIN].events & POLLIN &&
+                   pfd[POLL_NETIN].revents & POLLHUP &&
+                   !(pfd[POLL_NETIN].revents & POLLIN))
+                       pfd[POLL_NETIN].fd = -1;
+               /* outputs cannot be written to after HUP */
+               if (pfd[POLL_NETOUT].revents & POLLHUP) {
+                       if (Nflag)
+                               shutdown(pfd[POLL_NETOUT].fd, SHUT_WR);
+                       pfd[POLL_NETOUT].fd = -1;
+               }
+               /* if HUP, stop watching stdout */
+               if (pfd[POLL_STDOUT].revents & POLLHUP)
+                       pfd[POLL_STDOUT].fd = -1;
+               /* if no net out, stop watching stdin */
+               if (pfd[POLL_NETOUT].fd == -1)
+                       pfd[POLL_STDIN].fd = -1;
+               /* if no stdout, stop watching net in */
+               if (pfd[POLL_STDOUT].fd == -1) {
+                       if (pfd[POLL_NETIN].fd != -1)
+                               shutdown(pfd[POLL_NETIN].fd, SHUT_RD);
+                       pfd[POLL_NETIN].fd = -1;
+               }

-               if (!dflag && pfd[1].revents & POLLIN) {
-                       if ((n = read(wfd, buf, plen)) < 0)
-                               return;
-                       else if (n == 0) {
-                               if (Nflag)
-                                       shutdown(nfd, SHUT_WR);
-                               pfd[1].fd = -1;
-                               pfd[1].events = 0;
-                       } else {
-                               if (atomicio(vwrite, nfd, buf, n) != n)
-                                       return;
+               /* try to read from stdin */
+               if (pfd[POLL_STDIN].revents & POLLIN && stdinbufpos < BUFSIZE) {
+                       ret = fillbuf(pfd[POLL_STDIN].fd, stdinbuf,
+                           &stdinbufpos);
+                       /* error or eof on stdin - remove from pfd */
+                       if (ret == 0 || ret == -1)
+                               pfd[POLL_STDIN].fd = -1;
+                       /* read something - poll net out */
+                       if (stdinbufpos > 0)
+                               pfd[POLL_NETOUT].events = POLLOUT;
+                       /* filled buffer - remove self from polling */
+                       if (stdinbufpos == BUFSIZE)
+                               pfd[POLL_STDIN].events = 0;
+               }
+               /* try to write to network */
+               if (pfd[POLL_NETOUT].revents & POLLOUT && stdinbufpos > 0) {
+                       ret = drainbuf(pfd[POLL_NETOUT].fd, stdinbuf,
+                           &stdinbufpos);
+                       if (ret == -1)
+                               pfd[POLL_NETOUT].fd = -1;
+                       /* buffer empty - remove self from polling */
+                       if (stdinbufpos == 0)
+                               pfd[POLL_NETOUT].events = 0;
+                       /* buffer no longer full - poll stdin again */
+                       if (stdinbufpos < BUFSIZE)
+                               pfd[POLL_STDIN].events = POLLIN;
+               }
+               /* try to read from network */
+               if (pfd[POLL_NETIN].revents & POLLIN && netinbufpos < BUFSIZE) {
+                       ret = fillbuf(pfd[POLL_NETIN].fd, netinbuf,
+                           &netinbufpos);
+                       if (ret == -1)
+                               pfd[POLL_NETIN].fd = -1;
+                       /* eof on net in - remove from pfd */
+                       if (ret == 0) {
+                               shutdown(pfd[POLL_NETIN].fd, SHUT_RD);
+                               pfd[POLL_NETIN].fd = -1;
                        }
+                       /* read something - poll stdout */
+                       if (netinbufpos > 0)
+                               pfd[POLL_STDOUT].events = POLLOUT;
+                       /* filled buffer - remove self from polling */
+                       if (netinbufpos == BUFSIZE)
+                               pfd[POLL_NETIN].events = 0;
+                       /*
+                        * handle telnet, only on the bytes just read: older
+                        * data in netinbuf was already processed, and on
+                        * EOF/error pfd[POLL_NETIN].fd has been set to -1
+                        */
+                       if (tflag && ret > 0)
+                               atelnet(net_fd, netinbuf + netinbufpos - ret,
+                                   ret);
+               }
+               /* try to write to stdout */
+               if (pfd[POLL_STDOUT].revents & POLLOUT && netinbufpos > 0) {
+                       ret = drainbuf(pfd[POLL_STDOUT].fd, netinbuf,
+                           &netinbufpos);
+                       if (ret == -1)
+                               pfd[POLL_STDOUT].fd = -1;
+                       /* buffer empty - remove self from polling */
+                       if (netinbufpos == 0)
+                               pfd[POLL_STDOUT].events = 0;
+                       /* buffer no longer full - poll net in again */
+                       if (netinbufpos < BUFSIZE)
+                               pfd[POLL_NETIN].events = POLLIN;
+               }
+
+               /* stdin gone and queue empty? */
+               if (pfd[POLL_STDIN].fd == -1 && stdinbufpos == 0) {
+                       if (pfd[POLL_NETOUT].fd != -1 && Nflag)
+                               shutdown(pfd[POLL_NETOUT].fd, SHUT_WR);
+                       pfd[POLL_NETOUT].fd = -1;
                }
+               /* net in gone and queue empty? */
+               if (pfd[POLL_NETIN].fd == -1 && netinbufpos == 0)
+                       pfd[POLL_STDOUT].fd = -1;
        }
 }

+/*
+ * Try to write the first *bufpos bytes of buf to fd.  The unwritten
+ * remainder, if any, is moved to the front of buf and *bufpos is reduced
+ * accordingly.  Returns the number of bytes written (possibly 0), -1 on
+ * error, or -2 if write(2) failed with EAGAIN or EINTR (try again later).
+ */
+ssize_t
+drainbuf(int fd, unsigned char *buf, size_t *bufpos)
+{
+       ssize_t n;
+       ssize_t adjust;
+
+       n = write(fd, buf, *bufpos);
+       /* don't treat EAGAIN, EINTR as error */
+       if (n == -1 && (errno == EAGAIN || errno == EINTR))
+               n = -2;
+       if (n <= 0)
+               return n;
+       /* adjust buffer */
+       adjust = *bufpos - n;
+       if (adjust > 0)
+               memmove(buf, buf + n, adjust);
+       *bufpos -= n;
+       return n;
+}
+
+/*
+ * Read up to BUFSIZE - *bufpos bytes from fd into buf at offset *bufpos and
+ * advance *bufpos by the amount read.  Returns the number of bytes read,
+ * 0 on end-of-file, -1 on error, or -2 if read(2) failed with EAGAIN or
+ * EINTR (try again later).  Callers must ensure *bufpos < BUFSIZE: a
+ * zero-length read also returns 0 and would be mistaken for end-of-file.
+ */
+ssize_t
+fillbuf(int fd, unsigned char *buf, size_t *bufpos)
+{
+       size_t num = BUFSIZE - *bufpos;
+       ssize_t n;
+
+       n = read(fd, buf + *bufpos, num);
+       /* don't treat EAGAIN, EINTR as error */
+       if (n == -1 && (errno == EAGAIN || errno == EINTR))
+               n = -2;
+       if (n <= 0)
+               return n;
+       *bufpos += n;
+       return n;
+}
+
 /*
  * fdpass()
  * Pass the connected file descriptor to stdout and exit.
@@ -857,6 +1027,8 @@ atelnet(int nfd, unsigned char *buf, uns
 {
        unsigned char *p, *end;
        unsigned char obuf[4];
+       int flags;
+       int blocking = 0;

        if (size < 3)
                return;
@@ -877,8 +1049,20 @@ atelnet(int nfd, unsigned char *buf, uns

                p++;
                obuf[2] = *p;
+
+               if (!blocking) {
+                       flags = fcntl(nfd, F_GETFL, 0);
+                       if (fcntl(nfd, F_SETFL, flags & ~O_NONBLOCK) == -1)
+                               warn("unable to set blocking mode");
+                       blocking = 1;
+               }
                if (atomicio(vwrite, nfd, obuf, 3) != 3)
                        warn("Write Error!");
+       }
+       if (blocking) {
+               flags = fcntl(nfd, F_GETFL, 0);
+               if (fcntl(nfd, F_SETFL, flags | O_NONBLOCK) == -1)
+                       warn("unable to set non-blocking mode");
        }
 }

Reply via email to