commit 30b5ef8a9e9c7a5b306e9285d1a8db323f8f22b2
Author: Arlo Breault <[email protected]>
Date:   Wed Nov 20 19:33:28 2019 -0500

    Use gorilla websocket in proxy-go too
    
    Trac: 32465
---
 common/websocketconn/websocketconn.go      | 89 +++++++++++++++++++++++++++
 common/websocketconn/websocketconn_test.go | 30 +++++++++
 proxy-go/proxy-go_test.go                  | 19 ------
 proxy-go/snowflake.go                      | 25 ++------
 server/server.go                           | 99 ++----------------------------
 5 files changed, 128 insertions(+), 134 deletions(-)

diff --git a/common/websocketconn/websocketconn.go b/common/websocketconn/websocketconn.go
new file mode 100644
index 0000000..399cbaa
--- /dev/null
+++ b/common/websocketconn/websocketconn.go
@@ -0,0 +1,89 @@
+package websocketconn
+
+import (
+       "io"
+       "log"
+       "sync"
+       "time"
+
+       "github.com/gorilla/websocket"
+)
+
+// WebSocketConn adapts an underlying WebSocket connection so that it can be
+// used as an io.ReadWriteCloser (via its Read, Write, and Close methods).
+type WebSocketConn struct {
+       // Ws is the wrapped gorilla WebSocket connection.
+       Ws *websocket.Conn
+       // r reads the payload of the message Read is currently consuming;
+       // it is nil between messages.
+       r  io.Reader
+}
+
+// Read implements io.Reader. It presents the payloads of consecutive binary
+// and text messages as one continuous byte stream; message boundaries are
+// not preserved.
+func (conn *WebSocketConn) Read(b []byte) (n int, err error) {
+       for {
+               if conn.r == nil {
+                       // Start of a new message. Anything that is not a
+                       // binary or text message is skipped.
+                       msgType, msgReader, nextErr := conn.Ws.NextReader()
+                       if nextErr != nil {
+                               return 0, nextErr
+                       }
+                       if msgType != websocket.BinaryMessage && msgType != websocket.TextMessage {
+                               continue
+                       }
+                       conn.r = msgReader
+               }
+
+               n, err = conn.r.Read(b)
+               if err == io.EOF {
+                       // Message finished; the next read starts a new one.
+                       conn.r = nil
+                       err = nil
+                       if n == 0 && len(b) > 0 {
+                               // The io.Reader contract discourages returning
+                               // (0, nil), so move straight on to the next
+                               // message instead.
+                               continue
+                       }
+               }
+               return n, err
+       }
+}
+
+// Write implements io.Writer. Every call sends b as exactly one complete
+// binary WebSocket message.
+func (conn *WebSocketConn) Write(b []byte) (n int, err error) {
+       w, err := conn.Ws.NextWriter(websocket.BinaryMessage)
+       if err != nil {
+               return 0, err
+       }
+       n, err = w.Write(b)
+       if err != nil {
+               return n, err
+       }
+       // Closing the writer ends the message; its error is the result.
+       return n, w.Close()
+}
+
+// Close implements io.Closer. It first tries to send a WebSocket Close frame
+// (empty payload, one-second write deadline), then closes the underlying
+// connection and returns that result.
+func (conn *WebSocketConn) Close() error {
+       // Best effort: the peer may already be gone, so any error from writing
+       // the Close frame is deliberately ignored.
+       _ = conn.Ws.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(time.Second))
+       return conn.Ws.Close()
+}
+
+// NewWebSocketConn wraps ws in a WebSocketConn. The methods use pointer
+// receivers, so callers pass the address of the returned value wherever an
+// io.ReadWriteCloser is needed.
+func NewWebSocketConn(ws *websocket.Conn) WebSocketConn {
+       return WebSocketConn{Ws: ws}
+}
+
+// CopyLoop copies data between c1 and c2 in both directions. When either
+// direction stops (EOF or error), both connections are closed so that the
+// opposite direction can also terminate; this relies on Close unblocking
+// pending I/O, as it does for net.Conn. CopyLoop returns only after both
+// directions have finished.
+func CopyLoop(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) {
+       var wg sync.WaitGroup
+       copyer := func(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
+               defer wg.Done()
+               if _, err := io.Copy(dst, src); err != nil {
+                       log.Printf("io.Copy inside CopyLoop generated an error: %v", err)
+               }
+               // Close errors are ignored: the other goroutine closes the
+               // same pair too, so one of the two Close calls is expected to
+               // be redundant.
+               dst.Close()
+               src.Close()
+       }
+       wg.Add(2)
+       go copyer(c1, c2)
+       go copyer(c2, c1)
+       wg.Wait()
+}
diff --git a/common/websocketconn/websocketconn_test.go b/common/websocketconn/websocketconn_test.go
new file mode 100644
index 0000000..3293165
--- /dev/null
+++ b/common/websocketconn/websocketconn_test.go
@@ -0,0 +1,30 @@
+package websocketconn
+
+import (
+       "net"
+       "testing"
+
+       . "github.com/smartystreets/goconvey/convey"
+)
+
+// TestWebsocketConn exercises CopyLoop over in-memory net.Pipe connections.
+func TestWebsocketConn(t *testing.T) {
+       Convey("CopyLoop", t, func() {
+               // CopyLoop relays between the inner ends s1 and s2, so bytes
+               // written to c1 should come out of c2.
+               c1, s1 := net.Pipe()
+               c2, s2 := net.Pipe()
+               go CopyLoop(s1, s2)
+               go func() {
+                       bytes := []byte("Hello!")
+                       c1.Write(bytes)
+               }()
+               bytes := make([]byte, 6)
+               n, err := c2.Read(bytes)
+               So(n, ShouldEqual, 6)
+               So(err, ShouldEqual, nil)
+               So(bytes, ShouldResemble, []byte("Hello!"))
+               s1.Close()
+
+               // Check that copy loop has closed other connection. Nothing
+               // reads from c2 and net.Pipe is unbuffered (per its docs), so
+               // this Write can only return, with an error, once CopyLoop has
+               // closed s2.
+               // NOTE(review): c1 and c2 are never closed by this test.
+               _, err = s2.Write(bytes)
+               So(err, ShouldNotBeNil)
+       })
+}
diff --git a/proxy-go/proxy-go_test.go b/proxy-go/proxy-go_test.go
index ebe4381..538957b 100644
--- a/proxy-go/proxy-go_test.go
+++ b/proxy-go/proxy-go_test.go
@@ -374,23 +374,4 @@ func TestUtilityFuncs(t *testing.T) {
                sid2 := genSessionID()
                So(sid1, ShouldNotEqual, sid2)
        })
-       Convey("CopyLoop", t, func() {
-               c1, s1 := net.Pipe()
-               c2, s2 := net.Pipe()
-               go CopyLoop(s1, s2)
-               go func() {
-                       bytes := []byte("Hello!")
-                       c1.Write(bytes)
-               }()
-               bytes := make([]byte, 6)
-               n, err := c2.Read(bytes)
-               So(n, ShouldEqual, 6)
-               So(err, ShouldEqual, nil)
-               So(bytes, ShouldResemble, []byte("Hello!"))
-               s1.Close()
-
-               //Check that copy loop has closed other connection
-               _, err = s2.Write(bytes)
-               So(err, ShouldNotBeNil)
-       })
 }
diff --git a/proxy-go/snowflake.go b/proxy-go/snowflake.go
index c4b2f0b..0e14eb2 100644
--- a/proxy-go/snowflake.go
+++ b/proxy-go/snowflake.go
@@ -21,8 +21,9 @@ import (
 
        "git.torproject.org/pluggable-transports/snowflake.git/common/messages"
        "git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
+       
"git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn"
+       "github.com/gorilla/websocket"
        "github.com/pion/webrtc"
-       "golang.org/x/net/websocket"
 )
 
 const defaultBrokerURL = "https://snowflake-broker.bamsoftware.com/";
@@ -239,22 +240,6 @@ func (b *Broker) sendAnswer(sid string, pc *webrtc.PeerConnection) error {
        return nil
 }
 
-func CopyLoop(c1 net.Conn, c2 net.Conn) {
-       var wg sync.WaitGroup
-       copyer := func(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
-               defer wg.Done()
-               if _, err := io.Copy(dst, src); err != nil {
-                       log.Printf("io.Copy inside CopyLoop generated an error: 
%v", err)
-               }
-               dst.Close()
-               src.Close()
-       }
-       wg.Add(2)
-       go copyer(c1, c2)
-       go copyer(c2, c1)
-       wg.Wait()
-}
-
 // We pass conn.RemoteAddr() as an additional parameter, rather than calling
 // conn.RemoteAddr() inside this function, as a workaround for a hang that
 // otherwise occurs inside of conn.pc.RemoteDescription() (called by
@@ -279,15 +264,15 @@ func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) {
                log.Printf("no remote address given in websocket")
        }
 
-       wsConn, err := websocket.Dial(u.String(), "", relayURL)
+       ws, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
        if err != nil {
                log.Printf("error dialing relay: %s", err)
                return
        }
+       wsConn := websocketconn.NewWebSocketConn(ws)
        log.Printf("connected to relay")
        defer wsConn.Close()
-       wsConn.PayloadType = websocket.BinaryFrame
-       CopyLoop(conn, wsConn)
+       websocketconn.CopyLoop(conn, &wsConn)
        log.Printf("datachannelHandler ends")
 }
 
diff --git a/server/server.go b/server/server.go
index ce804fc..d950ddc 100644
--- a/server/server.go
+++ b/server/server.go
@@ -15,12 +15,12 @@ import (
        "os/signal"
        "path/filepath"
        "strings"
-       "sync"
        "syscall"
        "time"
 
        pt "git.torproject.org/pluggable-transports/goptlib.git"
        "git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
+       
"git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn"
        "github.com/gorilla/websocket"
        "golang.org/x/crypto/acme/autocert"
        "golang.org/x/net/http2"
@@ -50,97 +50,6 @@ additional HTTP listener on port 80 to work with ACME.
        flag.PrintDefaults()
 }
 
-// An abstraction that makes an underlying WebSocket connection look like an
-// io.ReadWriteCloser.
-type webSocketConn struct {
-       Ws *websocket.Conn
-       r  io.Reader
-}
-
-// Implements io.Reader.
-func (conn *webSocketConn) Read(b []byte) (n int, err error) {
-       var opCode int
-       if conn.r == nil {
-               // New message
-               var r io.Reader
-               for {
-                       if opCode, r, err = conn.Ws.NextReader(); err != nil {
-                               return
-                       }
-                       if opCode != websocket.BinaryMessage && opCode != 
websocket.TextMessage {
-                               continue
-                       }
-
-                       conn.r = r
-                       break
-               }
-       }
-
-       n, err = conn.r.Read(b)
-       if err == io.EOF {
-               // Message finished
-               conn.r = nil
-               err = nil
-       }
-       return
-}
-
-// Implements io.Writer.
-func (conn *webSocketConn) Write(b []byte) (n int, err error) {
-       var w io.WriteCloser
-       if w, err = conn.Ws.NextWriter(websocket.BinaryMessage); err != nil {
-               return
-       }
-       if n, err = w.Write(b); err != nil {
-               return
-       }
-       err = w.Close()
-       return
-}
-
-// Implements io.Closer.
-func (conn *webSocketConn) Close() error {
-       // Ignore any error in trying to write a Close frame.
-       _ = conn.Ws.WriteControl(websocket.CloseMessage, []byte{}, 
time.Now().Add(time.Second))
-       return conn.Ws.Close()
-}
-
-// Create a new webSocketConn.
-func newWebSocketConn(ws *websocket.Conn) webSocketConn {
-       var conn webSocketConn
-       conn.Ws = ws
-       return conn
-}
-
-// Copy from WebSocket to socket and vice versa.
-func proxy(local *net.TCPConn, conn *webSocketConn) {
-       var wg sync.WaitGroup
-       wg.Add(2)
-
-       go func() {
-               if _, err := io.Copy(conn, local); err != nil {
-                       log.Printf("error copying ORPort to WebSocket %v", err)
-               }
-               if err := local.CloseRead(); err != nil {
-                       log.Printf("error closing read after copying ORPort to 
WebSocket %v", err)
-               }
-               conn.Close()
-               wg.Done()
-       }()
-       go func() {
-               if _, err := io.Copy(local, conn); err != nil {
-                       log.Printf("error copying WebSocket to ORPort")
-               }
-               if err := local.CloseWrite(); err != nil {
-                       log.Printf("error closing write after copying WebSocket 
to ORPort %v", err)
-               }
-               conn.Close()
-               wg.Done()
-       }()
-
-       wg.Wait()
-}
-
 // Return an address string suitable to pass into pt.DialOr.
 func clientAddr(clientIPParam string) string {
        if clientIPParam == "" {
@@ -166,8 +75,8 @@ func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
                return
        }
 
-       conn := newWebSocketConn(ws)
-       defer conn.Close()
+       wsConn := websocketconn.NewWebSocketConn(ws)
+       defer wsConn.Close()
 
        // Pass the address of client as the remote address of incoming 
connection
        clientIPParam := r.URL.Query().Get("client_ip")
@@ -184,7 +93,7 @@ func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
        }
        defer or.Close()
 
-       proxy(or, &conn)
+       websocketconn.CopyLoop(or, &wsConn)
 }
 
 func initServer(addr *net.TCPAddr,



_______________________________________________
tor-commits mailing list
[email protected]
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits

Reply via email to