The following pull request was submitted through GitHub. It can be accessed and reviewed at: https://github.com/lxc/lxd/pull/4663
This e-mail was sent by the LXC bot; direct replies will not reach the author unless they happen to be subscribed to this list. === Description (from pull request) === The old implementation called connect() every time a new client was accepted. IIUC, this is not what we want: ideally, all clients should send their traffic over the same connection, established by a single connect() call. This is especially true when forwarding multiple ports. Unfortunately, this makes the implementation more complex. That said, I may be mistaken, and what we actually want is for each newly accepted client on the forwarded port to trigger its own connect() call. Closes #4601. Signed-off-by: Christian Brauner <[email protected]>
From 4fcd67fccbd91d81df3ae1bc21bbf887c8ed265d Mon Sep 17 00:00:00 2001 From: Christian Brauner <[email protected]> Date: Mon, 18 Jun 2018 14:32:28 +0200 Subject: [PATCH 1/4] reader: handle EINTR Signed-off-by: Christian Brauner <[email protected]> --- shared/eagain/file.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/shared/eagain/file.go b/shared/eagain/file.go index 9e3eac9c0..2739969df 100644 --- a/shared/eagain/file.go +++ b/shared/eagain/file.go @@ -22,7 +22,7 @@ again: // keep retrying on EAGAIN errno, ok := shared.GetErrno(err) - if ok && errno == syscall.EAGAIN { + if ok && (errno == syscall.EAGAIN || errno == syscall.EINTR) { goto again } @@ -44,7 +44,7 @@ again: // keep retrying on EAGAIN errno, ok := shared.GetErrno(err) - if ok && errno == syscall.EAGAIN { + if ok && (errno == syscall.EAGAIN || errno == syscall.EINTR) { goto again } From 9302ce50bce392129ccd722f4cf2306541591621 Mon Sep 17 00:00:00 2001 From: Christian Brauner <[email protected]> Date: Sat, 16 Jun 2018 13:09:18 +0200 Subject: [PATCH 2/4] proxy: genericize to handle multiple ports Closes #4601. 
Signed-off-by: Christian Brauner <[email protected]> --- lxd/main_forkproxy.go | 171 ++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 131 insertions(+), 40 deletions(-) diff --git a/lxd/main_forkproxy.go b/lxd/main_forkproxy.go index 074bb8c72..93a367718 100644 --- a/lxd/main_forkproxy.go +++ b/lxd/main_forkproxy.go @@ -6,6 +6,7 @@ import ( "net" "os" "os/signal" + "strconv" "strings" "syscall" "time" @@ -260,7 +261,7 @@ type cmdForkproxy struct { type proxyAddress struct { connType string - addr string + addr []string abstract bool } @@ -307,45 +308,71 @@ func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error { return fmt.Errorf("Failed to call forkproxy constructor") } - lAddr := parseAddr(listenAddr) + lAddr, err := parseAddr(listenAddr) + if err != nil { + return err + } if C.whoami == C.FORKPROXY_CHILD { - err := os.Remove(lAddr.addr) - if err != nil && !os.IsNotExist(err) { - return err + if lAddr.connType == "unix" && !lAddr.abstract { + err := os.Remove(lAddr.addr[0]) + if err != nil && !os.IsNotExist(err) { + return err + } } - file, err := getListenerFile(listenAddr) - if err != nil { - return err + for _, port := range lAddr.addr { + fmt.Println(port) + } + + for _, addr := range lAddr.addr { + file, err := getListenerFile(lAddr.connType, addr) + if err != nil { + return err + } + + err = shared.AbstractUnixSendFd(forkproxyUDSSockFDNum, int(file.Fd())) + file.Close() + if err != nil { + break + } } - err = shared.AbstractUnixSendFd(forkproxyUDSSockFDNum, int(file.Fd())) syscall.Close(forkproxyUDSSockFDNum) - file.Close() return err } - file, err := shared.AbstractUnixReceiveFd(forkproxyUDSSockFDNum) - syscall.Close(forkproxyUDSSockFDNum) - if err != nil { - fmt.Printf("Failed to receive fd from listener process: %v\n", err) - return err + files := []*os.File{} + for range lAddr.addr { + f, err := shared.AbstractUnixReceiveFd(forkproxyUDSSockFDNum) + if err != nil { + fmt.Printf("Failed to receive fd from listener process: 
%v\n", err) + return err + } + files = append(files, f) } + syscall.Close(forkproxyUDSSockFDNum) var srcConn net.Conn - var listener net.Listener + var listeners []*net.Listener udpFD := -1 if lAddr.connType == "udp" { - udpFD = int(file.Fd()) - srcConn, err = net.FileConn(file) + udpFD = int(files[0].Fd()) + srcConn, err = net.FileConn(files[0]) + if err != nil { + fmt.Printf("Failed to re-assemble listener: %v", err) + return err + } } else { - listener, err = net.FileListener(file) - } - if err != nil { - fmt.Printf("Failed to re-assemble listener: %v", err) - return err + for _, f := range files { + listener, err := net.FileListener(f) + if err != nil { + fmt.Printf("Failed to re-assemble listener: %v", err) + return err + } + listeners = append(listeners, &listener) + } } // Handle SIGTERM which is sent when the proxy is to be removed @@ -358,33 +385,44 @@ func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error { go func() { <-sigs terminate = true - file.Close() + + for _, f := range files { + f.Close() + } + if lAddr.connType == "udp" { srcConn.Close() // Kill ourselves since we will otherwise block on UDP // connect() or poll(). 
syscall.Kill(killOnUDP, syscall.SIGKILL) } else { - listener.Close() + for _, listener := range listeners { + (*listener).Close() + } } }() connectAddr := args[3] - cAddr := parseAddr(connectAddr) + cAddr, err := parseAddr(connectAddr) + if err != nil { + return err + } if cAddr.connType == "unix" && !cAddr.abstract { // Create socket - file, err := getListenerFile(fmt.Sprintf("unix:%s", cAddr.addr)) + file, err := getListenerFile("unix", cAddr.addr[0]) if err != nil { return err } file.Close() - defer os.Remove(cAddr.addr) + if cAddr.connType == "unix" && !cAddr.abstract { + defer os.Remove(cAddr.addr[0]) + } } if lAddr.connType == "unix" && !lAddr.abstract { - defer os.Remove(lAddr.addr) + defer os.Remove(lAddr.addr[0]) } fmt.Printf("Starting %s <-> %s proxy\n", lAddr.connType, cAddr.connType) @@ -418,7 +456,7 @@ func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error { // begin proxying for { // Accept a new client - srcConn, err = listener.Accept() + srcConn, err = (*listeners[0]).Accept() if err != nil { if terminate { break @@ -619,15 +657,12 @@ func tryListenUDP(protocol string, addr string) (*os.File, error) { return file, err } -func getListenerFile(listenAddr string) (*os.File, error) { - fields := strings.SplitN(listenAddr, ":", 2) - addr := strings.Join(fields[1:], "") - - if fields[0] == "udp" { - return tryListenUDP(fields[0], addr) +func getListenerFile(protocol string, addr string) (*os.File, error) { + if protocol == "udp" { + return tryListenUDP("udp", addr) } - listener, err := tryListen(fields[0], addr) + listener, err := tryListen(protocol, addr) if err != nil { return nil, fmt.Errorf("Failed to listen on %s: %v", addr, err) } @@ -654,11 +689,67 @@ func getDestConn(connectAddr string) (net.Conn, error) { return net.Dial(fields[0], addr) } -func parseAddr(addr string) *proxyAddress { +func parsePortRange(r string) (int64, int64, error) { + entries := strings.Split(r, "-") + if len(entries) > 2 { + return -1, -1, fmt.Errorf("Invalid 
port range %s", r) + } + + base, err := strconv.ParseInt(entries[0], 10, 64) + if err != nil { + return -1, -1, err + } + + size := int64(1) + if len(entries) > 1 { + size, err = strconv.ParseInt(entries[1], 10, 64) + if err != nil { + return -1, -1, err + } + + size -= base + size += 1 + } + + return base, size, nil +} + +func parseAddr(addr string) (*proxyAddress, error) { + // Split into <protocol> and <address> fields := strings.SplitN(addr, ":", 2) - return &proxyAddress{ + + newProxyAddr := &proxyAddress{ connType: fields[0], - addr: fields[1], abstract: strings.HasPrefix(fields[1], "@"), } + + // unix addresses cannot have ports + if newProxyAddr.connType == "unix" { + newProxyAddr.addr = []string{fields[1]} + return newProxyAddr, nil + } + + // Split <address> into <address> and <ports> + addrParts := strings.SplitN(fields[1], ":", 2) + // no ports + if len(addrParts) == 1 { + newProxyAddr.addr = []string{fields[1]} + return newProxyAddr, nil + } + + // Split <ports> into individual ports and port ranges + ports := strings.SplitN(addrParts[1], ",", -1) + for _, port := range ports { + portFirst, portRange, err := parsePortRange(port) + if err != nil { + return nil, err + } + + for i := int64(0); i < portRange; i++ { + newAddr := fmt.Sprintf("%s:%d", addrParts[0], portFirst + i) + newProxyAddr.addr = append(newProxyAddr.addr, newAddr) + } + } + + return newProxyAddr, nil } From fd58089ce0ef92f79b0116eff5b2dfb616a219cc Mon Sep 17 00:00:00 2001 From: Christian Brauner <[email protected]> Date: Sat, 16 Jun 2018 14:28:51 +0200 Subject: [PATCH 3/4] proxy: handle UDP and TCP port ranges Closes #4601. 
Signed-off-by: Christian Brauner <[email protected]> --- lxd/main_forkproxy.go | 214 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 128 insertions(+), 86 deletions(-) diff --git a/lxd/main_forkproxy.go b/lxd/main_forkproxy.go index 93a367718..946a87450 100644 --- a/lxd/main_forkproxy.go +++ b/lxd/main_forkproxy.go @@ -10,6 +10,7 @@ import ( "strings" "syscall" "time" + "unsafe" "github.com/spf13/cobra" @@ -25,6 +26,7 @@ import ( #include <stdio.h> #include <stdlib.h> #include <string.h> +#include <sys/epoll.h> #include <sys/socket.h> #include <sys/stat.h> #include <sys/types.h> @@ -283,6 +285,55 @@ func (c *cmdForkproxy) Command() *cobra.Command { return cmd } +func listenerInstance(lAddr *proxyAddress, cAddr *proxyAddress, connectAddr string, udpSrcConn *net.Conn, listener *net.Listener) error { + fmt.Printf("Starting %s <-> %s proxy\n", lAddr.connType, cAddr.connType) + if lAddr.connType == "udp" { + go func() error { + // Connect to the target + dstConn, err := getDestConn(connectAddr) + if err != nil { + fmt.Printf("Error: Failed to connect to target: %v\n", err) + (*udpSrcConn).Close() + return err + } + + genericRelay((*udpSrcConn), dstConn, false) + + return nil + }() + + return nil + } + + // Accept a new client + srcConn, err := (*listener).Accept() + if err != nil { + fmt.Printf("Error: Failed to accept new connection: %v\n", err) + return err + } + fmt.Printf("Accepted a new connection\n") + + // Connect to the target + dstConn, err := getDestConn(connectAddr) + if err != nil { + fmt.Printf("Error: Failed to connect to target: %v\n", err) + if lAddr.connType != "udp" { + srcConn.Close() + } + + return err + } + + if cAddr.connType == "unix" && lAddr.connType == "unix" { + // Handle OOB if both src and dst are using unix sockets + go unixRelay(srcConn, dstConn) + } else { + go genericRelay(srcConn, dstConn, true) + } + + return nil +} + func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error { // Only root should run this 
if os.Geteuid() != 0 { @@ -314,17 +365,13 @@ func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error { } if C.whoami == C.FORKPROXY_CHILD { - if lAddr.connType == "unix" && !lAddr.abstract { + if lAddr.connType == "unix" && !lAddr.abstract { err := os.Remove(lAddr.addr[0]) if err != nil && !os.IsNotExist(err) { return err } } - for _, port := range lAddr.addr { - fmt.Println(port) - } - for _, addr := range lAddr.addr { file, err := getListenerFile(lAddr.connType, addr) if err != nil { @@ -347,6 +394,7 @@ func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error { f, err := shared.AbstractUnixReceiveFd(forkproxyUDSSockFDNum) if err != nil { fmt.Printf("Failed to receive fd from listener process: %v\n", err) + syscall.Close(forkproxyUDSSockFDNum) return err } files = append(files, f) @@ -354,54 +402,36 @@ func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error { syscall.Close(forkproxyUDSSockFDNum) var srcConn net.Conn - var listeners []*net.Listener + var listenerMap map[int]*net.Listener + var udpConnMap map[int]*net.Conn - udpFD := -1 - if lAddr.connType == "udp" { - udpFD = int(files[0].Fd()) - srcConn, err = net.FileConn(files[0]) - if err != nil { - fmt.Printf("Failed to re-assemble listener: %v", err) - return err + isUDPListener := lAddr.connType == "udp" + if isUDPListener { + udpConnMap = make(map[int]*net.Conn, len(lAddr.addr)) + for _, f := range files { + srcConn, err = net.FileConn(files[0]) + if err != nil { + fmt.Printf("Failed to re-assemble listener: %v", err) + return err + } + udpConnMap[int(f.Fd())] = &srcConn } } else { + listenerMap = make(map[int]*net.Listener, len(lAddr.addr)) for _, f := range files { listener, err := net.FileListener(f) if err != nil { fmt.Printf("Failed to re-assemble listener: %v", err) return err } - listeners = append(listeners, &listener) + listenerMap[int(f.Fd())] = &listener } } // Handle SIGTERM which is sent when the proxy is to be removed - terminate := false sigs := make(chan 
os.Signal, 1) signal.Notify(sigs, syscall.SIGTERM) - // Wait for SIGTERM and close the listener in order to exit the loop below - killOnUDP := syscall.Getpid() - go func() { - <-sigs - terminate = true - - for _, f := range files { - f.Close() - } - - if lAddr.connType == "udp" { - srcConn.Close() - // Kill ourselves since we will otherwise block on UDP - // connect() or poll(). - syscall.Kill(killOnUDP, syscall.SIGKILL) - } else { - for _, listener := range listeners { - (*listener).Close() - } - } - }() - connectAddr := args[3] cAddr, err := parseAddr(connectAddr) if err != nil { @@ -425,70 +455,82 @@ func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error { defer os.Remove(lAddr.addr[0]) } - fmt.Printf("Starting %s <-> %s proxy\n", lAddr.connType, cAddr.connType) - if lAddr.connType == "udp" { - for { - ret, revents, err := shared.GetPollRevents(udpFD, -1, (shared.POLLIN | shared.POLLPRI | shared.POLLERR | shared.POLLHUP | shared.POLLRDHUP | shared.POLLNVAL)) - if ret < 0 { - fmt.Printf("Failed to poll on file descriptor: %s\n", err) - srcConn.Close() - return err - } + epFD := C.epoll_create1(C.EPOLL_CLOEXEC) + if epFD < 0 { + return fmt.Errorf("Failed to create new epoll instance") + } - if (revents & (shared.POLLERR | shared.POLLHUP | shared.POLLRDHUP | shared.POLLNVAL)) > 0 { - err := fmt.Errorf("Invalid UDP socket file descriptor") - fmt.Printf("%s\n", err) - srcConn.Close() - return err - } + // Wait for SIGTERM and close the listener in order to exit the loop below + self := syscall.Getpid() + go func() { + <-sigs - // Connect to the target - dstConn, err := getDestConn(connectAddr) - if err != nil { - fmt.Printf("error: Failed to connect to target: %v\n", err) - srcConn.Close() - return err + for _, f := range files { + C.epoll_ctl(epFD, C.EPOLL_CTL_DEL, C.int(f.Fd()), nil) + f.Close() + } + syscall.Close(int(epFD)) + + if isUDPListener { + for _, l := range udpConnMap { + (*l).Close() + } + } else { + for _, l := range listenerMap { + 
(*l).Close() } + } + syscall.Kill(self, syscall.SIGKILL) + }() + defer syscall.Kill(self, syscall.SIGTERM) - genericRelay(srcConn, dstConn, false) + for _, f := range files { + var ev C.struct_epoll_event + ev.events = C.EPOLLIN + if isUDPListener { + ev.events |= C.EPOLLET } - } else { - // begin proxying - for { - // Accept a new client - srcConn, err = (*listeners[0]).Accept() - if err != nil { - if terminate { - break - } - fmt.Printf("error: Failed to accept new connection: %v\n", err) - continue - } - fmt.Printf("Accepted a new connection\n") + *(*C.int)(unsafe.Pointer(uintptr(unsafe.Pointer(&ev)) + unsafe.Sizeof(ev.events))) = C.int(f.Fd()) + ret := C.epoll_ctl(epFD, C.EPOLL_CTL_ADD, C.int(f.Fd()), &ev) + if ret < 0 { + return fmt.Errorf("Failed to add listener fd to epoll instance") + } + fmt.Printf("Added listener socket file descriptor %d to epoll instance\n", int(f.Fd())) + } - // Connect to the target - dstConn, err := getDestConn(connectAddr) - if err != nil { - fmt.Printf("error: Failed to connect to target: %v\n", err) - if lAddr.connType != "udp" { - srcConn.Close() - } + for { + var events [10]C.struct_epoll_event + + nfds := C.epoll_wait(epFD, &events[0], 10, -1) + if nfds < 0 { + fmt.Printf("Failed to wait on epoll instance") + break + } + for i := C.int(0); i < nfds; i++ { + var listener *net.Listener + var udpListener *net.Conn + var ok bool + + curFD := *(*C.int)(unsafe.Pointer(uintptr(unsafe.Pointer(&events[i])) + unsafe.Sizeof(events[i].events))) + if isUDPListener { + udpListener, ok = udpConnMap[int(curFD)] + } else { + listener, ok = listenerMap[int(curFD)] + } + if !ok { continue } - if cAddr.connType == "unix" && lAddr.connType == "unix" { - // Handle OOB if both src and dst are using unix sockets - go unixRelay(srcConn, dstConn) - } else { - go genericRelay(srcConn, dstConn, true) + err := listenerInstance(lAddr, cAddr, connectAddr, udpListener, listener) + if err != nil { + fmt.Printf("Failed to prepare new listener instance: %s", 
err) } } } fmt.Printf("Stopping proxy\n") - return nil } @@ -746,7 +788,7 @@ func parseAddr(addr string) (*proxyAddress, error) { } for i := int64(0); i < portRange; i++ { - newAddr := fmt.Sprintf("%s:%d", addrParts[0], portFirst + i) + newAddr := fmt.Sprintf("%s:%d", addrParts[0], portFirst+i) newProxyAddr.addr = append(newProxyAddr.addr, newAddr) } } From 01f19eb8c9309f3b16613dbe143eed66a617d678 Mon Sep 17 00:00:00 2001 From: Christian Brauner <[email protected]> Date: Mon, 18 Jun 2018 11:46:28 +0200 Subject: [PATCH 4/4] proxy: dump traffic to the same connection The old implementation used to call connect() every time a new client got accepted. Iiuc, this is not what we want. Ideally, we'd want all clients to dump their traffic to the same connect()ion. This is especially true when we are forwarding multiple ports. Unfortunately, this makes the actual implementation more complex. In any case, I might be mistaken and what we want is that each new accepted client on the forwarded port also causes a new connect() call. Closes #4601. 
Signed-off-by: Christian Brauner <[email protected]> --- lxd/main_forkproxy.go | 228 ++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 165 insertions(+), 63 deletions(-) diff --git a/lxd/main_forkproxy.go b/lxd/main_forkproxy.go index 946a87450..7c3d40ccf 100644 --- a/lxd/main_forkproxy.go +++ b/lxd/main_forkproxy.go @@ -8,6 +8,7 @@ import ( "os/signal" "strconv" "strings" + "sync" "syscall" "time" "unsafe" @@ -285,22 +286,10 @@ func (c *cmdForkproxy) Command() *cobra.Command { return cmd } -func listenerInstance(lAddr *proxyAddress, cAddr *proxyAddress, connectAddr string, udpSrcConn *net.Conn, listener *net.Listener) error { - fmt.Printf("Starting %s <-> %s proxy\n", lAddr.connType, cAddr.connType) - if lAddr.connType == "udp" { - go func() error { - // Connect to the target - dstConn, err := getDestConn(connectAddr) - if err != nil { - fmt.Printf("Error: Failed to connect to target: %v\n", err) - (*udpSrcConn).Close() - return err - } - - genericRelay((*udpSrcConn), dstConn, false) - - return nil - }() +func listenerInstance(lProtocol string, cProtocol string, udpSrcConn *net.Conn, listener *net.Listener, dst net.Conn) error { + fmt.Printf("Starting %s <-> %s proxy\n", lProtocol, cProtocol) + if lProtocol == "udp" { + go genericRelay((*udpSrcConn), dst, true) return nil } @@ -311,29 +300,22 @@ func listenerInstance(lAddr *proxyAddress, cAddr *proxyAddress, connectAddr stri fmt.Printf("Error: Failed to accept new connection: %v\n", err) return err } - fmt.Printf("Accepted a new connection\n") - // Connect to the target - dstConn, err := getDestConn(connectAddr) - if err != nil { - fmt.Printf("Error: Failed to connect to target: %v\n", err) - if lAddr.connType != "udp" { - srcConn.Close() - } + if lProtocol == "unix" && cProtocol == "unix" { + // Handle OOB if both src and dst are using unix sockets + go unixRelay(srcConn, dst) - return err + return nil } - if cAddr.connType == "unix" && lAddr.connType == "unix" { - // Handle OOB if both src and 
dst are using unix sockets - go unixRelay(srcConn, dstConn) - } else { - go genericRelay(srcConn, dstConn, true) - } + go genericRelay(srcConn, dst, false) return nil } +var dstConnLock sync.Mutex +var dstConn *net.Conn + func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error { // Only root should run this if os.Geteuid() != 0 { @@ -523,7 +505,21 @@ func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error { continue } - err := listenerInstance(lAddr, cAddr, connectAddr, udpListener, listener) + dstConnLock.Lock() + if dstConn == nil { + // Connect to the target + tmp, err := getDestConn(connectAddr) + if err != nil { + fmt.Printf("Error: Failed to connect to target: %s\n", err) + dstConnLock.Unlock() + continue + } + + dstConn = &tmp + } + dstConnLock.Unlock() + + err := listenerInstance(lAddr.connType, cAddr.connType, udpListener, listener, *dstConn) if err != nil { fmt.Printf("Failed to prepare new listener instance: %s", err) } @@ -534,36 +530,112 @@ func (c *cmdForkproxy) Run(cmd *cobra.Command, args []string) error { return nil } -func genericRelay(dst io.ReadWriteCloser, src io.ReadWriteCloser, closeDst bool) { - relayer := func(dst io.Writer, src io.Reader, ch chan error) { - _, err := io.Copy(eagain.Writer{Writer: dst}, eagain.Reader{Reader: src}) - ch <- err +func copyBuffer(dst io.Writer, src io.Reader) (written int64, dstErr error, srcErr error) { + // If the reader has a WriteTo method, use it to do the copy. + // Avoids an allocation and a copy. + if wt, ok := src.(io.WriterTo); ok { + written, dstErr = wt.WriteTo(dst) + return written, dstErr, nil + } + + // Similarly, if the writer has a ReadFrom method, use it to do the copy. 
+ if rt, ok := dst.(io.ReaderFrom); ok { + written, srcErr = rt.ReadFrom(src) + return written, srcErr, nil + } + + size := 32 * 1024 + if l, ok := src.(*io.LimitedReader); ok && int64(size) > l.N { + if l.N < 1 { + size = 1 + } else { + size = int(l.N) + } + } + + buf := make([]byte, size) + for { + nr, er := src.Read(buf) + if nr > 0 { + nw, ew := dst.Write(buf[0:nr]) + if nw > 0 { + written += int64(nw) + } + if ew != nil { + dstErr = ew + break + } + if nr != nw { + dstErr = io.ErrShortWrite + break + } + } + if er != nil { + if er != io.EOF { + srcErr = er + } + break + } + } + + return written, dstErr, srcErr +} + +func genericRelay(src io.ReadWriteCloser, dst io.ReadWriteCloser, udp bool) { + relayer := func(src io.Writer, dst io.Reader, srcCh chan error, dstCh chan error, udp bool) { + var srcErr, dstErr error + + if udp { + // EPOLLET behavior requires us to stop reading at + // EAGAIN so don't handle this error + _, srcErr, dstErr = copyBuffer(src, dst) + } else { + _, srcErr, dstErr = copyBuffer(eagain.Writer{Writer: src}, eagain.Reader{Reader: dst}) + } + srcCh <- srcErr + dstCh <- dstErr } - chSend := make(chan error) - go relayer(dst, src, chSend) + chSrcSend := make(chan error) + chDstSend := make(chan error) + go relayer(src, dst, chSrcSend, chDstSend, udp) - chRecv := make(chan error) - go relayer(src, dst, chRecv) + chSrcRecv := make(chan error) + chDstRecv := make(chan error) + go relayer(dst, src, chSrcRecv, chDstRecv, udp) - errSnd := <-chSend - errRcv := <-chRecv + errSrcSnd := <-chSrcSend + errDstSnd := <-chDstSend + errSrcRcv := <-chSrcRecv + errDstRcv := <-chDstRecv - src.Close() - if closeDst { - dst.Close() + if !udp { + src.Close() } - if errSnd != nil { - fmt.Printf("Error while sending data %s\n", errSnd) + if chDstSend != nil || chDstRecv != nil { + dstConnLock.Lock() + dstConn = nil + dstConnLock.Unlock() + fmt.Println("Resetting target port") } - if errRcv != nil { - fmt.Printf("Error while reading data %s\n", errRcv) + if errSrcSnd 
!= nil { + fmt.Printf("Error: Sending data failed with %s\n", errSrcSnd) + } + if errDstSnd != nil { + fmt.Printf("Error: Sending data failed with %s\n", errDstSnd) + } + + if errSrcRcv != nil { + fmt.Printf("Error: Reading data failed with %s\n", errSrcRcv) + } + if errDstRcv != nil { + fmt.Printf("Error: Reading data failed with %s\n", errDstRcv) } } -func unixRelayer(src *net.UnixConn, dst *net.UnixConn, ch chan bool) { +func unixRelayer(src *net.UnixConn, dst *net.UnixConn, srcCh chan error, dstCh chan error) { dataBuf := make([]byte, 4096) oobBuf := make([]byte, 4096) @@ -577,7 +649,8 @@ func unixRelayer(src *net.UnixConn, dst *net.UnixConn, ch chan bool) { goto readAgain } fmt.Printf("Disconnected during read: %v\n", err) - ch <- true + srcCh <- err + dstCh <- nil return } @@ -586,7 +659,8 @@ func unixRelayer(src *net.UnixConn, dst *net.UnixConn, ch chan bool) { entries, err := syscall.ParseSocketControlMessage(oobBuf[:sOob]) if err != nil { fmt.Printf("Failed to parse control message: %v\n", err) - ch <- true + srcCh <- nil + dstCh <- nil return } @@ -594,7 +668,8 @@ func unixRelayer(src *net.UnixConn, dst *net.UnixConn, ch chan bool) { fds, err = syscall.ParseUnixRights(&msg) if err != nil { fmt.Printf("Failed to get fd list for control message: %v\n", err) - ch <- true + srcCh <- nil + dstCh <- nil return } } @@ -609,13 +684,15 @@ func unixRelayer(src *net.UnixConn, dst *net.UnixConn, ch chan bool) { goto writeAgain } fmt.Printf("Disconnected during write: %v\n", err) - ch <- true + srcCh <- nil + dstCh <- err return } if sData != tData || sOob != tOob { fmt.Printf("Some data got lost during transfer, disconnecting.") - ch <- true + srcCh <- nil + dstCh <- nil return } @@ -625,7 +702,8 @@ func unixRelayer(src *net.UnixConn, dst *net.UnixConn, ch chan bool) { err := syscall.Close(fd) if err != nil { fmt.Printf("Failed to close fd %d: %v\n", fd, err) - ch <- true + srcCh <- nil + dstCh <- nil return } } @@ -634,17 +712,41 @@ func unixRelayer(src 
*net.UnixConn, dst *net.UnixConn, ch chan bool) { } func unixRelay(dst io.ReadWriteCloser, src io.ReadWriteCloser) { - chSend := make(chan bool) - go unixRelayer(dst.(*net.UnixConn), src.(*net.UnixConn), chSend) + chSrcSend := make(chan error) + chDstSend := make(chan error) + go unixRelayer(dst.(*net.UnixConn), src.(*net.UnixConn), chSrcSend, chDstSend) + + chSrcRecv := make(chan error) + chDstRecv := make(chan error) + go unixRelayer(src.(*net.UnixConn), dst.(*net.UnixConn), chSrcRecv, chDstRecv) - chRecv := make(chan bool) - go unixRelayer(src.(*net.UnixConn), dst.(*net.UnixConn), chRecv) + errSrcSnd := <-chSrcSend + errDstSnd := <-chDstSend + errSrcRcv := <-chSrcRecv + errDstRcv := <-chDstRecv - <-chSend - <-chRecv + if chDstSend != nil || chDstRecv != nil { + dstConnLock.Lock() + dstConn = nil + dstConnLock.Unlock() + fmt.Println("Resetting target port") + } src.Close() - dst.Close() + + if errSrcSnd != nil { + fmt.Printf("Error: Sending data failed with %s\n", errSrcSnd) + } + if errDstSnd != nil { + fmt.Printf("Error: Sending data failed with %s\n", errDstSnd) + } + + if errSrcRcv != nil { + fmt.Printf("Error: Reading data failed with %s\n", errSrcRcv) + } + if errDstRcv != nil { + fmt.Printf("Error: Reading data failed with %s\n", errDstRcv) + } } func tryListen(protocol string, addr string) (net.Listener, error) {
_______________________________________________ lxc-devel mailing list [email protected] http://lists.linuxcontainers.org/listinfo/lxc-devel
