commit 11f0846264d4033e7a7dc7824febb6ad7140762f
Author: Cecylia Bocovich <coh...@torproject.org>
Date:   Sat Mar 20 18:24:00 2021 -0400

    Implement server as a v2.1 PT Go API
---
 server/lib/http.go                   | 211 +++++++++++++++++
 server/lib/server_test.go            |  55 +++++
 server/lib/snowflake.go              | 242 ++++++++++++++++++++
 server/{ => lib}/turbotunnel.go      |   2 +-
 server/{ => lib}/turbotunnel_test.go |   2 +-
 server/server.go                     | 426 ++++-------------------------------
 server/server_test.go                | 153 -------------
 7 files changed, 551 insertions(+), 540 deletions(-)

diff --git a/server/lib/http.go b/server/lib/http.go
new file mode 100644
index 0000000..b1c453c
--- /dev/null
+++ b/server/lib/http.go
@@ -0,0 +1,211 @@
+package lib
+
+import (
+       "bufio"
+       "bytes"
+       "fmt"
+       "io"
+       "log"
+       "net"
+       "net/http"
+       "time"
+
+       
"git.torproject.org/pluggable-transports/snowflake.git/common/encapsulation"
+       
"git.torproject.org/pluggable-transports/snowflake.git/common/turbotunnel"
+       
"git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn"
+       "github.com/gorilla/websocket"
+)
+
+const (
+       // requestTimeout is the HTTP server's ReadTimeout.
+       requestTimeout = 10 * time.Second
+
+       // clientMapTimeout is how long to remember outgoing packets for a
+       // client when we don't currently have an active WebSocket connection
+       // corresponding to that client. Because a client session may span
+       // multiple WebSocket connections, we keep packets we aren't able to
+       // send immediately in memory, for a little while but not indefinitely.
+       clientMapTimeout = 1 * time.Minute
+
+       // clientIDAddrMapCapacity is how big to make the map of ClientIDs to
+       // IP addresses. The map is used in turbotunnelMode to store a
+       // reasonable IP address for a client session that may outlive any
+       // single WebSocket connection.
+       clientIDAddrMapCapacity = 1024
+
+       // listenAndServeErrorTimeout is how long to wait for ListenAndServe or
+       // ListenAndServeTLS to return an error before deciding that it's not
+       // going to return.
+       listenAndServeErrorTimeout = 100 * time.Millisecond
+)
+
+// upgrader turns incoming HTTP requests into WebSocket connections. Its
+// CheckOrigin accepts every request regardless of the Origin header.
+var upgrader = websocket.Upgrader{
+       CheckOrigin: func(_ *http.Request) bool { return true },
+}
+
+// clientIDAddrMap stores short-term mappings from ClientIDs to IP addresses.
+// The server must report a USERADDR string that represents the remote IP
+// address of the client (for metrics purposes, etc.) alongside each client
+// connection. This data structure bridges the gap between ServeHTTP (via
+// turbotunnelMode), which knows about IP addresses, and acceptStreams, which
+// attaches an address to each smux stream it hands to the listener. The
+// common piece of information linking both ends of the chain is the
+// ClientID, which is attached to the WebSocket connection and every session.
+// Values are stored as the String() of the ClientMapAddr returned by
+// clientAddr, i.e. either "" or an already-formatted host:port.
+var clientIDAddrMap = newClientIDMap(clientIDAddrMapCapacity)
+
+// overrideReadConn is a net.Conn whose reads come from Reader instead of from
+// the embedded Conn; every other method is the embedded Conn's. Compare to
+// recordingConn at
+// https://dave.cheney.net/2015/05/22/struct-composition-with-go.
+type overrideReadConn struct {
+       net.Conn
+       io.Reader
+}
+
+// Read reads from the embedded io.Reader. The explicit method is required
+// because both embedded fields provide a Read method, which would otherwise
+// be an ambiguous selector.
+func (c *overrideReadConn) Read(buf []byte) (int, error) {
+       src := c.Reader
+       return src.Read(buf)
+}
+
+// HTTPHandler accepts WebSocket connections from Snowflake proxies and hands
+// each one either to turbotunnel mode or to the legacy one-shot mode.
+type HTTPHandler struct {
+       // pconn is the adapter layer between stream-oriented WebSocket
+       // connections and the packet-oriented KCP layer.
+       pconn *turbotunnel.QueuePacketConn
+
+       // ln is the listener that receives one-shot (non-turbotunnel)
+       // connections.
+       ln *SnowflakeListener
+}
+
+// ServeHTTP upgrades the request to a WebSocket, reads the leading token to
+// decide which protocol the client speaks, and dispatches accordingly. Errors
+// are logged, not returned; the WebSocket is closed when ServeHTTP returns.
+func (h *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+       ws, err := upgrader.Upgrade(w, r, nil)
+       if err != nil {
+               log.Println(err)
+               return
+       }
+       conn := websocketconn.New(ws)
+       defer conn.Close()
+
+       // The proxy passes the client's IP in the client_ip query parameter;
+       // use it as the remote address of the incoming connection.
+       addr := clientAddr(r.URL.Query().Get("client_ip"))
+
+       var token [len(turbotunnel.Token)]byte
+       if _, err := io.ReadFull(conn, token[:]); err != nil {
+               // EOF is routine: clients maintaining a pool of proxies often
+               // open connections they never use, so don't log it.
+               if err != io.EOF {
+                       log.Printf("reading token: %v", err)
+               }
+               return
+       }
+
+       if bytes.Equal(token[:], turbotunnel.Token[:]) {
+               err = turbotunnelMode(conn, addr, h.pconn)
+       } else {
+               // No token: this client predates turbotunnel. "Unread" the
+               // bytes we consumed by replaying them ahead of the rest of the
+               // stream, and fall back to one-session-per-WebSocket mode.
+               replay := io.MultiReader(bytes.NewReader(token[:]), conn)
+               err = oneshotMode(&overrideReadConn{Conn: conn, Reader: replay}, addr, h.ln)
+       }
+       if err != nil {
+               log.Println(err)
+       }
+}
+
+// oneshotMode handles clients that did not send turbotunnel.Token at the start
+// of their stream. These clients use the WebSocket as a raw pipe, and expect
+// their session to begin and end when this single WebSocket does.
+//
+// The connection is placed on ln's accept queue. The caller (ServeHTTP)
+// closes the WebSocket as soon as this function returns, so instead of
+// returning immediately we block until whoever accepted the connection has
+// closed it, or until ln itself is closed. Returning early would tear down a
+// connection that the acceptor is still using.
+func oneshotMode(conn net.Conn, addr net.Addr, ln *SnowflakeListener) error {
+       notifier := &closeNotifyConn{Conn: conn, done: make(chan struct{}, 1)}
+       err := ln.QueueConn(&SnowflakeClientConn{Conn: notifier, address: addr})
+       if err != nil {
+               return err
+       }
+       select {
+       case <-notifier.done:
+       case <-ln.closed:
+       }
+       return nil
+}
+
+// closeNotifyConn is a net.Conn that signals on done when Close is called, so
+// that the goroutine owning the underlying connection can wait for the
+// consumer to finish with it.
+type closeNotifyConn struct {
+       net.Conn
+
+       // done has capacity 1; Close performs a non-blocking send on it.
+       done chan struct{}
+}
+
+// Close closes the underlying connection and signals done. The non-blocking
+// send makes repeated or concurrent calls to Close safe.
+func (c *closeNotifyConn) Close() error {
+       err := c.Conn.Close()
+       select {
+       case c.done <- struct{}{}:
+       default:
+       }
+       return err
+}
+
+// turbotunnelMode handles clients that sent turbotunnel.Token at the start of
+// their stream. These clients expect to send and receive encapsulated packets,
+// with a long-lived session identified by ClientID. It returns once either
+// direction of the WebSocket stream stops; the error that stopped it is
+// discarded and nil is returned. A failure to read the ClientID is returned.
+func turbotunnelMode(conn net.Conn, addr net.Addr, pconn *turbotunnel.QueuePacketConn) error {
+       // Read the ClientID prefix. Every packet encapsulated in this WebSocket
+       // connection pertains to the same ClientID.
+       var clientID turbotunnel.ClientID
+       _, err := io.ReadFull(conn, clientID[:])
+       if err != nil {
+               return fmt.Errorf("reading ClientID: %v", err)
+       }
+
+       // Store a short-term mapping from the ClientID to the client IP
+       // address attached to this WebSocket connection. tor will want us to
+       // provide a client IP address when we call pt.DialOr. But a KCP session
+       // does not necessarily correspond to any single IP address--it's
+       // composed of packets that are carried in possibly multiple WebSocket
+       // streams. We apply the heuristic that the IP address of the most
+       // recent WebSocket connection that has had to do with a session, at the
+       // time the session is established, is the IP address that should be
+       // credited for the entire KCP session.
+       clientIDAddrMap.Set(clientID, addr.String())
+
+       // Each goroutine below sends at most one value on errCh, but only the
+       // first value is ever received. Give the channel room for both so the
+       // goroutine that stops second can complete its send and exit instead
+       // of blocking (and leaking) forever.
+       errCh := make(chan error, 2)
+
+       // The remainder of the WebSocket stream consists of encapsulated
+       // packets. We read them one by one and feed them into the
+       // QueuePacketConn on which kcp.ServeConn was set up, which eventually
+       // leads to KCP-level sessions in the acceptSessions function.
+       go func() {
+               for {
+                       p, err := encapsulation.ReadData(conn)
+                       if err != nil {
+                               errCh <- err
+                               break
+                       }
+                       pconn.QueueIncoming(p, clientID)
+               }
+       }()
+
+       // At the same time, grab packets addressed to this ClientID and
+       // encapsulate them into the downstream.
+       go func() {
+               // Buffer encapsulation.WriteData operations to keep length
+               // prefixes in the same send as the data that follows.
+               bw := bufio.NewWriter(conn)
+               for p := range pconn.OutgoingQueue(clientID) {
+                       _, err := encapsulation.WriteData(bw, p)
+                       if err == nil {
+                               err = bw.Flush()
+                       }
+                       if err != nil {
+                               errCh <- err
+                               break
+                       }
+               }
+       }()
+
+       // Wait until one of the above loops terminates. The closing of the
+       // WebSocket connection (by the caller) will terminate the other one.
+       <-errCh
+
+       return nil
+}
+
+// ClientMapAddr is a net.Addr wrapping a client address string: either a
+// host:port suitable for USERADDR, or "" when the client address is unknown.
+type ClientMapAddr string
+
+// Network always reports "snowflake".
+func (a ClientMapAddr) Network() string { return "snowflake" }
+
+// String returns the wrapped address string unchanged.
+func (a ClientMapAddr) String() string { return string(a) }
+
+// clientAddr converts the client_ip query parameter reported by a proxy into
+// a net.Addr whose String() is usable as a USERADDR. It returns an empty
+// ClientMapAddr when the parameter is missing, is not a valid IP address, or
+// is the unspecified address.
+func clientAddr(clientIPParam string) net.Addr {
+       ip := net.ParseIP(clientIPParam)
+       switch {
+       case ip == nil:
+               // Empty or not a valid IP.
+               return ClientMapAddr("")
+       case ip.IsUnspecified():
+               // 0.0.0.0 or [::]. Some proxies erroneously report an address
+               // of 0.0.0.0: https://bugs.torproject.org/33157.
+               return ClientMapAddr("")
+       }
+       // USERADDR requires a port number, so attach a stub one.
+       stub := net.TCPAddr{IP: ip, Port: 1}
+       return ClientMapAddr(stub.String())
+}
diff --git a/server/lib/server_test.go b/server/lib/server_test.go
new file mode 100644
index 0000000..65d31d1
--- /dev/null
+++ b/server/lib/server_test.go
@@ -0,0 +1,55 @@
+package lib
+
+import (
+       "net"
+       "strconv"
+       "testing"
+
+       . "github.com/smartystreets/goconvey/convey"
+)
+
+func TestClientAddr(t *testing.T) {
+       Convey("Testing clientAddr", t, func() {
+               // Valid, specified IPs must come back as host:port with the
+               // same host and a nonzero port.
+               goodCases := []struct {
+                       input    string
+                       expected net.IP
+               }{
+                       {"1.2.3.4", net.ParseIP("1.2.3.4")},
+                       {"1:2::3:4", net.ParseIP("1:2::3:4")},
+               }
+               for _, tc := range goodCases {
+                       checkGoodClientAddr(t, tc.input, tc.expected)
+               }
+
+               // Invalid or unspecified inputs must map to the empty string.
+               badInputs := []string{"", "abc", "1.2.3.4.5", "[12::34]", "0.0.0.0", "[::]"}
+               for _, in := range badInputs {
+                       if got := clientAddr(in).String(); got != "" {
+                               t.Errorf("clientAddr(%q) → %q, not %q", in, got, "")
+                       }
+               }
+       })
+}
+
+// checkGoodClientAddr verifies that clientAddr(input) splits into a host equal
+// to expected and a nonzero numeric port.
+func checkGoodClientAddr(t *testing.T, input string, expected net.IP) {
+       useraddr := clientAddr(input).String()
+       host, port, err := net.SplitHostPort(useraddr)
+       if err != nil {
+               t.Errorf("clientAddr(%q) → SplitHostPort error %v", input, err)
+               return
+       }
+       if !expected.Equal(net.ParseIP(host)) {
+               t.Errorf("clientAddr(%q) → host %q, not %v", input, host, expected)
+       }
+       portNo, err := strconv.Atoi(port)
+       if err != nil {
+               t.Errorf("clientAddr(%q) → port %q", input, port)
+               return
+       }
+       if portNo == 0 {
+               t.Errorf("clientAddr(%q) → port %d", input, portNo)
+       }
+}
diff --git a/server/lib/snowflake.go b/server/lib/snowflake.go
new file mode 100644
index 0000000..319acd8
--- /dev/null
+++ b/server/lib/snowflake.go
@@ -0,0 +1,242 @@
+package lib
+
+import (
+       "crypto/tls"
+       "fmt"
+       "io"
+       "log"
+       "net"
+       "net/http"
+       "sync"
+       "time"
+
+       
"git.torproject.org/pluggable-transports/snowflake.git/common/turbotunnel"
+       "github.com/xtaci/kcp-go/v5"
+       "github.com/xtaci/smux"
+       "golang.org/x/net/http2"
+)
+
+// Transport is the Snowflake server transport. Its methods conform to the Go
+// PT v2.1 API:
+// https://github.com/Pluggable-Transports/Pluggable-Transports-spec/blob/master/releases/PTSpecV2.1/Pluggable%20Transport%20Specification%20v2.1%20-%20Go%20Transport%20API.pdf
+type Transport struct {
+       // getCertificate supplies TLS certificates. When nil, Listen serves
+       // plain HTTP instead of HTTPS.
+       getCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error)
+}
+
+// NewSnowflakeServer returns a Transport that obtains TLS certificates from
+// getCertificate, or serves plain HTTP if getCertificate is nil.
+func NewSnowflakeServer(getCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error)) *Transport {
+       t := new(Transport)
+       t.getCertificate = getCertificate
+       return t
+}
+
+// Listen starts serving Snowflake WebSocket connections on addr and returns a
+// SnowflakeListener from which client connections can be accepted. It serves
+// HTTPS when the Transport has a getCertificate function and plain HTTP
+// otherwise. Both turbotunnel (KCP/smux) streams and legacy one-shot
+// WebSocket connections are delivered through the listener's Accept method.
+//
+// An error is returned if the HTTP server reports a failure (such as
+// "address already in use") within listenAndServeErrorTimeout, or if the KCP
+// engine cannot be started.
+func (t *Transport) Listen(addr net.Addr) (*SnowflakeListener, error) {
+       // closed must be initialized: Close closes it, and Accept and QueueConn
+       // select on it.
+       listener := &SnowflakeListener{
+               addr:   addr,
+               queue:  make(chan net.Conn, 65534),
+               closed: make(chan struct{}),
+       }
+
+       handler := HTTPHandler{
+               // pconn is shared among all connections to this server. It
+               // overlays packet-based client sessions on top of ephemeral
+               // WebSocket connections.
+               pconn: turbotunnel.NewQueuePacketConn(addr, clientMapTimeout),
+
+               // ln receives legacy one-shot connections from ServeHTTP.
+               ln: listener,
+       }
+       server := &http.Server{
+               Addr:        addr.String(),
+               Handler:     &handler,
+               ReadTimeout: requestTimeout,
+       }
+       // We need to override server.TLSConfig.GetCertificate--but first
+       // server.TLSConfig needs to be non-nil. If we just create our own new
+       // &tls.Config, it will lack the default settings that the net/http
+       // package sets up for things like HTTP/2. Therefore we first call
+       // http2.ConfigureServer for its side effect of initializing
+       // server.TLSConfig properly. An alternative would be to make a dummy
+       // net.Listener, call Serve on it, and let it return.
+       // https://github.com/golang/go/issues/16588#issuecomment-237386446
+       err := http2.ConfigureServer(server, nil)
+       if err != nil {
+               handler.pconn.Close()
+               return nil, err
+       }
+       server.TLSConfig.GetCertificate = t.getCertificate
+
+       // Another unfortunate effect of the inseparable net/http ListenAndServe
+       // is that we can't check for Listen errors like "permission denied" and
+       // "address already in use" without potentially entering the infinite
+       // loop of Serve. The hack we apply here is to wait a short time,
+       // listenAndServeErrorTimeout, to see if an error is returned (because
+       // it's better if the error message goes to the tor log through
+       // SMETHOD-ERROR than if it only goes to the snowflake log).
+       //
+       // errChan is buffered so that an error arriving after the timeout
+       // does not block the serving goroutine forever.
+       errChan := make(chan error, 1)
+       go func() {
+               if t.getCertificate == nil {
+                       // TLS is disabled
+                       log.Printf("listening with plain HTTP on %s", addr)
+                       err := server.ListenAndServe()
+                       if err != nil {
+                               log.Printf("error in ListenAndServe: %s", err)
+                       }
+                       errChan <- err
+               } else {
+                       log.Printf("listening with HTTPS on %s", addr)
+                       err := server.ListenAndServeTLS("", "")
+                       if err != nil {
+                               log.Printf("error in ListenAndServeTLS: %s", err)
+                       }
+                       errChan <- err
+               }
+       }()
+
+       select {
+       case err = <-errChan:
+       case <-time.After(listenAndServeErrorTimeout):
+       }
+       if err != nil {
+               // The HTTP server failed to start; report it to the caller
+               // rather than returning a listener that will never accept.
+               handler.pconn.Close()
+               return nil, err
+       }
+
+       listener.server = server
+
+       // Start a KCP engine, set up to read and write its packets over the
+       // WebSocket connections that arrive at the web server.
+       // handler.ServeHTTP is responsible for encapsulation/decapsulation of
+       // packets on behalf of KCP. KCP takes those packets and turns them into
+       // sessions which appear in the acceptSessions function.
+       ln, err := kcp.ServeConn(nil, 0, 0, handler.pconn)
+       if err != nil {
+               server.Close()
+               handler.pconn.Close()
+               return nil, err
+       }
+       listener.ln = ln
+
+       go func() {
+               defer ln.Close()
+               err := listener.acceptSessions(ln)
+               if err != nil {
+                       log.Printf("acceptSessions: %v", err)
+               }
+       }()
+
+       return listener, nil
+}
+
+// SnowflakeListener is returned by Transport.Listen and implements
+// net.Listener. Connections from turbotunnel sessions and from legacy one-shot
+// WebSockets are all delivered through its Accept method.
+type SnowflakeListener struct {
+       // addr is the address passed to Transport.Listen; Addr returns it.
+       addr net.Addr
+
+       // queue holds connections waiting to be returned by Accept.
+       queue chan net.Conn
+
+       // server is the HTTP(S) server receiving WebSocket connections.
+       server *http.Server
+
+       // ln is the KCP listener carrying turbotunnel sessions.
+       ln *kcp.Listener
+
+       // closed is closed by Close to signal Accept and QueueConn.
+       closed chan struct{}
+
+       // closeOnce ensures the shutdown work in Close runs only once.
+       closeOnce sync.Once
+}
+
+// Accept waits for and returns the next client connection. Connections are
+// taken from a single queue so that both incoming smux streams and legacy
+// non-turbotunnel connections can be served. Once the listener has been
+// closed, Accept returns io.ErrClosedPipe.
+func (l *SnowflakeListener) Accept() (net.Conn, error) {
+       var next net.Conn
+       select {
+       case next = <-l.queue:
+       case <-l.closed:
+               // The listener no longer accepts connections.
+               return nil, io.ErrClosedPipe
+       }
+       return next, nil
+}
+
+// Addr returns the address the listener was created with.
+func (l *SnowflakeListener) Addr() net.Addr {
+       a := l.addr
+       return a
+}
+
+// Close stops the listener. It unblocks pending Accept calls (which then
+// return io.ErrClosedPipe), makes further QueueConn calls fail, and shuts
+// down the HTTP server and the KCP listener. Only the first call does any
+// work. That call returns the first error from closing the HTTP server or the
+// KCP listener; later calls return nil.
+func (l *SnowflakeListener) Close() error {
+       var err error
+       l.closeOnce.Do(func() {
+               close(l.closed)
+               // Attempt both closes even if the first one fails.
+               serverErr := l.server.Close()
+               lnErr := l.ln.Close()
+               if serverErr != nil {
+                       err = serverErr
+               } else {
+                       err = lnErr
+               }
+       })
+       return err
+}
+
+// acceptStreams layers an smux.Session on the KCP connection and awaits
+// streams on it. Each stream is passed to our SnowflakeListener accept queue,
+// wrapped so that its RemoteAddr reports the client address recorded for the
+// session's ClientID. It returns when the session fails or the listener is
+// closed.
+func (l *SnowflakeListener) acceptStreams(conn *kcp.UDPSession) error {
+       // Look up the IP address associated with this KCP session, via the
+       // ClientID that is returned by the session's RemoteAddr method.
+       addr, ok := clientIDAddrMap.Get(conn.RemoteAddr().(turbotunnel.ClientID))
+       if !ok {
+               // This means that the map is tending to run over capacity, not
+               // just that there was not client_ip on the incoming connection.
+               // We store "" in the map in the absence of client_ip. This log
+               // message means you should increase clientIDAddrMapCapacity.
+               log.Printf("no address in clientID-to-IP map (capacity %d)", clientIDAddrMapCapacity)
+       }
+
+       smuxConfig := smux.DefaultConfig()
+       smuxConfig.Version = 2
+       smuxConfig.KeepAliveTimeout = 10 * time.Minute
+       sess, err := smux.Server(conn, smuxConfig)
+       if err != nil {
+               return err
+       }
+       defer sess.Close()
+
+       // turbotunnelMode stored the String() of a ClientMapAddr, so addr is
+       // already either "" or a host:port USERADDR. Wrap it directly: feeding
+       // it back through clientAddr would fail to parse "host:port" as an IP
+       // and silently discard every client address.
+       remote := ClientMapAddr(addr)
+
+       for {
+               stream, err := sess.AcceptStream()
+               if err != nil {
+                       if err, ok := err.(net.Error); ok && err.Temporary() {
+                               continue
+                       }
+                       return err
+               }
+               err = l.QueueConn(&SnowflakeClientConn{Conn: stream, address: remote})
+               if err != nil {
+                       // The listener is closed, so nothing will ever accept
+                       // (and close) this stream; close it here.
+                       stream.Close()
+                       return err
+               }
+       }
+}
+
+// acceptSessions accepts incoming KCP sessions on ln, tunes each one, and
+// serves its streams in a separate goroutine. It is handler.ServeHTTP that
+// provides the network interface that drives this function. It returns the
+// first non-temporary error from ln.
+func (l *SnowflakeListener) acceptSessions(ln *kcp.Listener) error {
+       for {
+               sess, err := ln.AcceptKCP()
+               if err != nil {
+                       if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
+                               continue
+                       }
+                       return err
+               }
+               // Permit coalescing the payloads of consecutive sends.
+               sess.SetStreamMode(true)
+               // Raise the maximum send and receive window sizes to remove KCP
+               // bottlenecks: https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/40026
+               sess.SetWindowSize(65535, 65535)
+               // nodelay, interval and resend keep their defaults (0); nc=1
+               // turns off the dynamic congestion window, so sending is
+               // limited only by the local and remote static windows.
+               sess.SetNoDelay(0, 0, 0, 1)
+
+               go l.serveSession(sess)
+       }
+}
+
+// serveSession runs acceptStreams on one KCP session, logs any unexpected
+// error, and closes the session afterwards.
+func (l *SnowflakeListener) serveSession(sess *kcp.UDPSession) {
+       defer sess.Close()
+       if err := l.acceptStreams(sess); err != nil && err != io.ErrClosedPipe {
+               log.Printf("acceptStreams: %v", err)
+       }
+}
+
+// QueueConn hands conn to the accept queue, to be returned by a later call to
+// Accept. It blocks while the queue is full, and returns an error if the
+// listener is closed before conn can be queued.
+func (l *SnowflakeListener) QueueConn(conn net.Conn) error {
+       var err error
+       select {
+       case l.queue <- conn:
+       case <-l.closed:
+               err = fmt.Errorf("accepted connection on closed listener")
+       }
+       return err
+}
+
+// SnowflakeClientConn wraps the underlying one-shot or turbotunnel connection
+// so that RemoteAddr reports the client address determined by the server
+// (from the proxy-reported client IP, via the ClientID mapping for
+// turbotunnel) rather than the address of the underlying transport.
+type SnowflakeClientConn struct {
+       net.Conn
+       address net.Addr
+}
+
+// RemoteAddr returns the client address recorded when the connection was
+// queued. Its String() may be empty when no usable client IP was reported.
+func (c *SnowflakeClientConn) RemoteAddr() net.Addr {
+       a := c.address
+       return a
+}
diff --git a/server/turbotunnel.go b/server/lib/turbotunnel.go
similarity index 99%
rename from server/turbotunnel.go
rename to server/lib/turbotunnel.go
index 1d00897..bb16fa3 100644
--- a/server/turbotunnel.go
+++ b/server/lib/turbotunnel.go
@@ -1,4 +1,4 @@
-package main
+package lib
 
 import (
        "sync"
diff --git a/server/turbotunnel_test.go b/server/lib/turbotunnel_test.go
similarity index 99%
rename from server/turbotunnel_test.go
rename to server/lib/turbotunnel_test.go
index c4bf02b..ba4cf60 100644
--- a/server/turbotunnel_test.go
+++ b/server/lib/turbotunnel_test.go
@@ -1,4 +1,4 @@
-package main
+package lib
 
 import (
        "encoding/binary"
diff --git a/server/server.go b/server/server.go
index 620cd50..b61d5b4 100644
--- a/server/server.go
+++ b/server/server.go
@@ -3,9 +3,6 @@
 package main
 
 import (
-       "bufio"
-       "bytes"
-       "crypto/tls"
        "flag"
        "fmt"
        "io"
@@ -19,38 +16,15 @@ import (
        "strings"
        "sync"
        "syscall"
-       "time"
 
-       pt "git.torproject.org/pluggable-transports/goptlib.git"
-       
"git.torproject.org/pluggable-transports/snowflake.git/common/encapsulation"
        "git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
-       
"git.torproject.org/pluggable-transports/snowflake.git/common/turbotunnel"
-       
"git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn"
-       "github.com/gorilla/websocket"
-       "github.com/xtaci/kcp-go/v5"
-       "github.com/xtaci/smux"
        "golang.org/x/crypto/acme/autocert"
-       "golang.org/x/net/http2"
+
+       pt "git.torproject.org/pluggable-transports/goptlib.git"
+       sf "git.torproject.org/pluggable-transports/snowflake.git/server/lib"
 )
 
 const ptMethodName = "snowflake"
-const requestTimeout = 10 * time.Second
-
-// How long to remember outgoing packets for a client, when we don't currently
-// have an active WebSocket connection corresponding to that client. Because a
-// client session may span multiple WebSocket connections, we keep packets we
-// aren't able to send immediately in memory, for a little while but not
-// indefinitely.
-const clientMapTimeout = 1 * time.Minute
-
-// How big to make the map of ClientIDs to IP addresses. The map is used in
-// turbotunnelMode to store a reasonable IP address for a client session that
-// may outlive any single WebSocket connection.
-const clientIDAddrMapCapacity = 1024
-
-// How long to wait for ListenAndServe or ListenAndServeTLS to return an error
-// before deciding that it's not going to return.
-const listenAndServeErrorTimeout = 100 * time.Millisecond
 
 var ptInfo pt.ServerInfo
 
@@ -92,366 +66,30 @@ func proxy(local *net.TCPConn, conn net.Conn) {
        wg.Wait()
 }
 
-// Return an address string suitable to pass into pt.DialOr.
-func clientAddr(clientIPParam string) string {
-       if clientIPParam == "" {
-               return ""
-       }
-       // Check if client addr is a valid IP
-       clientIP := net.ParseIP(clientIPParam)
-       if clientIP == nil {
-               return ""
-       }
-       // Check if client addr is 0.0.0.0 or [::]. Some proxies erroneously
-       // report an address of 0.0.0.0: https://bugs.torproject.org/33157.
-       if clientIP.IsUnspecified() {
-               return ""
-       }
-       // Add a dummy port number. USERADDR requires a port number.
-       return (&net.TCPAddr{IP: clientIP, Port: 1, Zone: ""}).String()
-}
-
-var upgrader = websocket.Upgrader{
-       CheckOrigin: func(r *http.Request) bool { return true },
-}
-
-// clientIDAddrMap stores short-term mappings from ClientIDs to IP addresses.
-// When we call pt.DialOr, tor wants us to provide a USERADDR string that
-// represents the remote IP address of the client (for metrics purposes, etc.).
-// This data structure bridges the gap between ServeHTTP, which knows about IP
-// addresses, and handleStream, which is what calls pt.DialOr. The common piece
-// of information linking both ends of the chain is the ClientID, which is
-// attached to the WebSocket connection and every session.
-var clientIDAddrMap = newClientIDMap(clientIDAddrMapCapacity)
-
-// overrideReadConn is a net.Conn with an overridden Read method. Compare to
-// recordingConn at
-// https://dave.cheney.net/2015/05/22/struct-composition-with-go.
-type overrideReadConn struct {
-       net.Conn
-       io.Reader
-}
-
-func (conn *overrideReadConn) Read(p []byte) (int, error) {
-       return conn.Reader.Read(p)
-}
-
-type HTTPHandler struct {
-       // pconn is the adapter layer between stream-oriented WebSocket
-       // connections and the packet-oriented KCP layer.
-       pconn *turbotunnel.QueuePacketConn
-}
-
-func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-       ws, err := upgrader.Upgrade(w, r, nil)
-       if err != nil {
-               log.Println(err)
-               return
-       }
-
-       conn := websocketconn.New(ws)
-       defer conn.Close()
-
-       // Pass the address of client as the remote address of incoming 
connection
-       clientIPParam := r.URL.Query().Get("client_ip")
-       addr := clientAddr(clientIPParam)
-
-       var token [len(turbotunnel.Token)]byte
-       _, err = io.ReadFull(conn, token[:])
-       if err != nil {
-               // Don't bother logging EOF: that happens with an unused
-               // connection, which clients make frequently as they maintain a
-               // pool of proxies.
-               if err != io.EOF {
-                       log.Printf("reading token: %v", err)
-               }
-               return
-       }
-
-       switch {
-       case bytes.Equal(token[:], turbotunnel.Token[:]):
-               err = turbotunnelMode(conn, addr, handler.pconn)
-       default:
-               // We didn't find a matching token, which means that we are
-               // dealing with a client that doesn't know about such things.
-               // "Unread" the token by constructing a new Reader and pass it
-               // to the old one-session-per-WebSocket mode.
-               conn2 := &overrideReadConn{Conn: conn, Reader: 
io.MultiReader(bytes.NewReader(token[:]), conn)}
-               err = oneshotMode(conn2, addr)
-       }
-       if err != nil {
-               log.Println(err)
-               return
-       }
-}
-
-// oneshotMode handles clients that did not send turbotunnel.Token at the start
-// of their stream. These clients use the WebSocket as a raw pipe, and expect
-// their session to begin and end when this single WebSocket does.
-func oneshotMode(conn net.Conn, addr string) error {
-       statsChannel <- addr != ""
-       or, err := pt.DialOr(&ptInfo, addr, ptMethodName)
-       if err != nil {
-               return fmt.Errorf("failed to connect to ORPort: %s", err)
-       }
-       defer or.Close()
-
-       proxy(or, conn)
-
-       return nil
-}
-
-// turbotunnelMode handles clients that sent turbotunnel.Token at the start of
-// their stream. These clients expect to send and receive encapsulated packets,
-// with a long-lived session identified by ClientID.
-func turbotunnelMode(conn net.Conn, addr string, pconn 
*turbotunnel.QueuePacketConn) error {
-       // Read the ClientID prefix. Every packet encapsulated in this WebSocket
-       // connection pertains to the same ClientID.
-       var clientID turbotunnel.ClientID
-       _, err := io.ReadFull(conn, clientID[:])
-       if err != nil {
-               return fmt.Errorf("reading ClientID: %v", err)
-       }
-
-       // Store a a short-term mapping from the ClientID to the client IP
-       // address attached to this WebSocket connection. tor will want us to
-       // provide a client IP address when we call pt.DialOr. But a KCP session
-       // does not necessarily correspond to any single IP address--it's
-       // composed of packets that are carried in possibly multiple WebSocket
-       // streams. We apply the heuristic that the IP address of the most
-       // recent WebSocket connection that has had to do with a session, at the
-       // time the session is established, is the IP address that should be
-       // credited for the entire KCP session.
-       clientIDAddrMap.Set(clientID, addr)
-
-       errCh := make(chan error)
-
-       // The remainder of the WebSocket stream consists of encapsulated
-       // packets. We read them one by one and feed them into the
-       // QueuePacketConn on which kcp.ServeConn was set up, which eventually
-       // leads to KCP-level sessions in the acceptSessions function.
-       go func() {
-               for {
-                       p, err := encapsulation.ReadData(conn)
-                       if err != nil {
-                               errCh <- err
-                               break
-                       }
-                       pconn.QueueIncoming(p, clientID)
-               }
-       }()
-
-       // At the same time, grab packets addressed to this ClientID and
-       // encapsulate them into the downstream.
-       go func() {
-               // Buffer encapsulation.WriteData operations to keep length
-               // prefixes in the same send as the data that follows.
-               bw := bufio.NewWriter(conn)
-               for p := range pconn.OutgoingQueue(clientID) {
-                       _, err := encapsulation.WriteData(bw, p)
-                       if err == nil {
-                               err = bw.Flush()
-                       }
-                       if err != nil {
-                               errCh <- err
-                               break
-                       }
-               }
-       }()
-
-       // Wait until one of the above loops terminates. The closing of the
-       // WebSocket connection will terminate the other one.
-       <-errCh
-
-       return nil
-}
-
-// handleStream bidirectionally connects a client stream with the ORPort.
-func handleStream(stream net.Conn, addr string) error {
-       statsChannel <- addr != ""
-       or, err := pt.DialOr(&ptInfo, addr, ptMethodName)
-       if err != nil {
-               return fmt.Errorf("connecting to ORPort: %v", err)
-       }
-       defer or.Close()
-
-       proxy(or, stream)
-
-       return nil
-}
-
-// acceptStreams layers an smux.Session on the KCP connection and awaits 
streams
-// on it. Passes each stream to handleStream.
-func acceptStreams(conn *kcp.UDPSession) error {
-       // Look up the IP address associated with this KCP session, via the
-       // ClientID that is returned by the session's RemoteAddr method.
-       addr, ok := 
clientIDAddrMap.Get(conn.RemoteAddr().(turbotunnel.ClientID))
-       if !ok {
-               // This means that the map is tending to run over capacity, not
-               // just that there was not client_ip on the incoming connection.
-               // We store "" in the map in the absence of client_ip. This log
-               // message means you should increase clientIDAddrMapCapacity.
-               log.Printf("no address in clientID-to-IP map (capacity %d)", clientIDAddrMapCapacity)
-       }
-
-       smuxConfig := smux.DefaultConfig()
-       smuxConfig.Version = 2
-       smuxConfig.KeepAliveTimeout = 10 * time.Minute
-       sess, err := smux.Server(conn, smuxConfig)
-       if err != nil {
-               return err
-       }
-
+func acceptLoop(ln net.Listener) {
        for {
-               stream, err := sess.AcceptStream()
+               conn, err := ln.Accept()
                if err != nil {
                        if err, ok := err.(net.Error); ok && err.Temporary() {
                                continue
                        }
-                       return err
+                       log.Printf("Snowflake accept error: %s", err)
+                       break
                }
-               go func() {
-                       defer stream.Close()
-                       err := handleStream(stream, addr)
-                       if err != nil {
-                               log.Printf("handleStream: %v", err)
-                       }
-               }()
-       }
-}
+               defer conn.Close()
 
-// acceptSessions listens for incoming KCP connections and passes them to
-// acceptStreams. It is handler.ServeHTTP that provides the network interface
-// that drives this function.
-func acceptSessions(ln *kcp.Listener) error {
-       for {
-               conn, err := ln.AcceptKCP()
+               addr := conn.RemoteAddr().String()
+               statsChannel <- addr != ""
+               or, err := pt.DialOr(&ptInfo, addr, ptMethodName)
                if err != nil {
-                       if err, ok := err.(net.Error); ok && err.Temporary() {
-                               continue
-                       }
-                       return err
+                       log.Printf("failed to connect to ORPort: %s", err)
+                       continue
                }
-               // Permit coalescing the payloads of consecutive sends.
-               conn.SetStreamMode(true)
-               // Set the maximum send and receive window sizes to a high number
-               // Removes KCP bottlenecks: https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/40026
-               conn.SetWindowSize(65535, 65535)
-               // Disable the dynamic congestion window (limit only by the
-               // maximum of local and remote static windows).
-               conn.SetNoDelay(
-                       0, // default nodelay
-                       0, // default interval
-                       0, // default resend
-                       1, // nc=1 => congestion window off
-               )
-               go func() {
-                       defer conn.Close()
-                       err := acceptStreams(conn)
-                       if err != nil && err != io.ErrClosedPipe {
-                               log.Printf("acceptStreams: %v", err)
-                       }
-               }()
+               defer or.Close()
+               go proxy(or, conn)
        }
 }
 
-func initServer(addr *net.TCPAddr,
-       getCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error),
-       listenAndServe func(*http.Server, chan<- error)) (*http.Server, error) {
-       // We're not capable of listening on port 0 (i.e., an ephemeral port
-       // unknown in advance). The reason is that while the net/http package
-       // exposes ListenAndServe and ListenAndServeTLS, those functions never
-       // return, so there's no opportunity to find out what the port number
-       // is, in between the Listen and Serve steps.
-       // https://groups.google.com/d/msg/Golang-nuts/3F1VRCCENp8/3hcayZiwYM8J
-       if addr.Port == 0 {
-               return nil, fmt.Errorf("cannot listen on port %d; configure a port using ServerTransportListenAddr", addr.Port)
-       }
-
-       handler := HTTPHandler{
-               // pconn is shared among all connections to this server. It
-               // overlays packet-based client sessions on top of ephemeral
-               // WebSocket connections.
-               pconn: turbotunnel.NewQueuePacketConn(addr, clientMapTimeout),
-       }
-       server := &http.Server{
-               Addr:        addr.String(),
-               Handler:     &handler,
-               ReadTimeout: requestTimeout,
-       }
-       // We need to override server.TLSConfig.GetCertificate--but first
-       // server.TLSConfig needs to be non-nil. If we just create our own new
-       // &tls.Config, it will lack the default settings that the net/http
-       // package sets up for things like HTTP/2. Therefore we first call
-       // http2.ConfigureServer for its side effect of initializing
-       // server.TLSConfig properly. An alternative would be to make a dummy
-       // net.Listener, call Serve on it, and let it return.
-       // https://github.com/golang/go/issues/16588#issuecomment-237386446
-       err := http2.ConfigureServer(server, nil)
-       if err != nil {
-               return server, err
-       }
-       server.TLSConfig.GetCertificate = getCertificate
-
-       // Another unfortunate effect of the inseparable net/http ListenAndServe
-       // is that we can't check for Listen errors like "permission denied" and
-       // "address already in use" without potentially entering the infinite
-       // loop of Serve. The hack we apply here is to wait a short time,
-       // listenAndServeErrorTimeout, to see if an error is returned (because
-       // it's better if the error message goes to the tor log through
-       // SMETHOD-ERROR than if it only goes to the snowflake log).
-       errChan := make(chan error)
-       go listenAndServe(server, errChan)
-       select {
-       case err = <-errChan:
-               break
-       case <-time.After(listenAndServeErrorTimeout):
-               break
-       }
-
-       // Start a KCP engine, set up to read and write its packets over the
-       // WebSocket connections that arrive at the web server.
-       // handler.ServeHTTP is responsible for encapsulation/decapsulation of
-       // packets on behalf of KCP. KCP takes those packets and turns them into
-       // sessions which appear in the acceptSessions function.
-       ln, err := kcp.ServeConn(nil, 0, 0, handler.pconn)
-       if err != nil {
-               server.Close()
-               return server, err
-       }
-       go func() {
-               defer ln.Close()
-               err := acceptSessions(ln)
-               if err != nil {
-                       log.Printf("acceptSessions: %v", err)
-               }
-       }()
-
-       return server, err
-}
-
-func startServer(addr *net.TCPAddr) (*http.Server, error) {
-       return initServer(addr, nil, func(server *http.Server, errChan chan<- error) {
-               log.Printf("listening with plain HTTP on %s", addr)
-               err := server.ListenAndServe()
-               if err != nil {
-                       log.Printf("error in ListenAndServe: %s", err)
-               }
-               errChan <- err
-       })
-}
-
-func startServerTLS(addr *net.TCPAddr, getCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error)) (*http.Server, error) {
-       return initServer(addr, getCertificate, func(server *http.Server, errChan chan<- error) {
-               log.Printf("listening with HTTPS on %s", addr)
-               err := server.ListenAndServeTLS("", "")
-               if err != nil {
-                       log.Printf("error in ListenAndServeTLS: %s", err)
-               }
-               errChan <- err
-       })
-}
-
 func getCertificateCacheDir() (string, error) {
        stateDir, err := pt.MakeStateDir()
        if err != nil {
@@ -535,7 +173,7 @@ func main() {
        // https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#http-challenge
        needHTTP01Listener := !disableTLS
 
-       servers := make([]*http.Server, 0)
+       listeners := make([]net.Listener, 0)
        for _, bindaddr := range ptInfo.Bindaddrs {
                if bindaddr.MethodName != ptMethodName {
                        pt.SmethodError(bindaddr.MethodName, "no such method")
@@ -560,29 +198,47 @@ func main() {
                        go func() {
                                log.Fatal(server.Serve(lnHTTP01))
                        }()
-                       servers = append(servers, server)
+                       listeners = append(listeners, lnHTTP01)
                        needHTTP01Listener = false
                }
 
-               var server *http.Server
+               // We're not capable of listening on port 0 (i.e., an ephemeral port
+               // unknown in advance). The reason is that while the net/http package
+               // exposes ListenAndServe and ListenAndServeTLS, those functions never
+               // return, so there's no opportunity to find out what the port number
+               // is, in between the Listen and Serve steps.
+               // https://groups.google.com/d/msg/Golang-nuts/3F1VRCCENp8/3hcayZiwYM8J
+               if bindaddr.Addr.Port == 0 {
+                       err := fmt.Errorf(
+                               "cannot listen on port %d; configure a port using ServerTransportListenAddr",
+                               bindaddr.Addr.Port)
+                       log.Printf("error opening listener: %s", err)
+                       pt.SmethodError(bindaddr.MethodName, err.Error())
+                       continue
+               }
+
+               var transport *sf.Transport
                args := pt.Args{}
                if disableTLS {
                        args.Add("tls", "no")
-                       server, err = startServer(bindaddr.Addr)
+                       transport = sf.NewSnowflakeServer(nil)
                } else {
                        args.Add("tls", "yes")
                        for _, hostname := range acmeHostnames {
                                args.Add("hostname", hostname)
                        }
-                       server, err = startServerTLS(bindaddr.Addr, certManager.GetCertificate)
+                       transport = sf.NewSnowflakeServer(certManager.GetCertificate)
                }
+               ln, err := transport.Listen(bindaddr.Addr)
                if err != nil {
                        log.Printf("error opening listener: %s", err)
                        pt.SmethodError(bindaddr.MethodName, err.Error())
                        continue
                }
+               defer ln.Close()
+               go acceptLoop(ln)
                pt.SmethodArgs(bindaddr.MethodName, bindaddr.Addr, args)
-               servers = append(servers, server)
+               listeners = append(listeners, ln)
        }
        pt.SmethodsDone()
 
@@ -606,7 +262,7 @@ func main() {
 
        // Signal received, shut down.
        log.Printf("caught signal %q, exiting", sig)
-       for _, server := range servers {
-               server.Close()
+       for _, ln := range listeners {
+               ln.Close()
        }
 }
diff --git a/server/server_test.go b/server/server_test.go
deleted file mode 100644
index ba00d16..0000000
--- a/server/server_test.go
+++ /dev/null
@@ -1,153 +0,0 @@
-package main
-
-import (
-       "net"
-       "net/http"
-       "strconv"
-       "testing"
-
-       "git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn"
-       "github.com/gorilla/websocket"
-       . "github.com/smartystreets/goconvey/convey"
-)
-
-func TestClientAddr(t *testing.T) {
-       Convey("Testing clientAddr", t, func() {
-               // good tests
-               for _, test := range []struct {
-                       input    string
-                       expected net.IP
-               }{
-                       {"1.2.3.4", net.ParseIP("1.2.3.4")},
-                       {"1:2::3:4", net.ParseIP("1:2::3:4")},
-               } {
-                       useraddr := clientAddr(test.input)
-                       host, port, err := net.SplitHostPort(useraddr)
-                       if err != nil {
-                               t.Errorf("clientAddr(%q) → SplitHostPort error %v", test.input, err)
-                               continue
-                       }
-                       if !test.expected.Equal(net.ParseIP(host)) {
-                               t.Errorf("clientAddr(%q) → host %q, not %v", test.input, host, test.expected)
-                       }
-                       portNo, err := strconv.Atoi(port)
-                       if err != nil {
-                               t.Errorf("clientAddr(%q) → port %q", test.input, port)
-                               continue
-                       }
-                       if portNo == 0 {
-                               t.Errorf("clientAddr(%q) → port %d", test.input, portNo)
-                       }
-               }
-
-               // bad tests
-               for _, input := range []string{
-                       "",
-                       "abc",
-                       "1.2.3.4.5",
-                       "[12::34]",
-                       "0.0.0.0",
-                       "[::]",
-               } {
-                       useraddr := clientAddr(input)
-                       if useraddr != "" {
-                               t.Errorf("clientAddr(%q) → %q, not %q", input, useraddr, "")
-                       }
-               }
-       })
-}
-
-type StubHandler struct{}
-
-func (handler *StubHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-       ws, _ := upgrader.Upgrade(w, r, nil)
-
-       conn := websocketconn.New(ws)
-       defer conn.Close()
-
-       //dial stub OR
-       or, _ := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.ParseIP("localhost"), Port: 8889})
-
-       proxy(or, conn)
-}
-
-func Test(t *testing.T) {
-       Convey("Websocket server", t, func() {
-               //Set up the snowflake web server
-               ipStr, portStr, _ := net.SplitHostPort(":8888")
-               port, _ := strconv.ParseUint(portStr, 10, 16)
-               addr := &net.TCPAddr{IP: net.ParseIP(ipStr), Port: int(port)}
-               Convey("We don't listen on port 0", func() {
-                       addr = &net.TCPAddr{IP: net.ParseIP(ipStr), Port: 0}
-                       server, err := initServer(addr, nil,
-                               func(server *http.Server, errChan chan<- error) {
-                                       return
-                               })
-                       So(err, ShouldNotBeNil)
-                       So(server, ShouldBeNil)
-               })
-
-               Convey("Plain HTTP server accepts connections", func(c C) {
-                       server, err := startServer(addr)
-                       So(err, ShouldBeNil)
-
-                       ws, _, err := websocket.DefaultDialer.Dial("ws://localhost:8888", nil)
-                       wsConn := websocketconn.New(ws)
-                       So(err, ShouldEqual, nil)
-                       So(wsConn, ShouldNotEqual, nil)
-
-                       server.Close()
-                       wsConn.Close()
-
-               })
-               Convey("Handler proxies data", func(c C) {
-
-                       laddr := &net.TCPAddr{IP: net.ParseIP("localhost"), Port: 8889}
-
-                       go func() {
-
-                               //stub OR
-                               listener, err := net.ListenTCP("tcp", laddr)
-                               c.So(err, ShouldBeNil)
-                               conn, err := listener.Accept()
-                               c.So(err, ShouldBeNil)
-
-                               b := make([]byte, 5)
-                               n, err := conn.Read(b)
-                               c.So(err, ShouldBeNil)
-                               c.So(n, ShouldEqual, 5)
-                               c.So(b, ShouldResemble, []byte("Hello"))
-
-                               n, err = conn.Write([]byte("world!"))
-                               c.So(n, ShouldEqual, 6)
-                               c.So(err, ShouldBeNil)
-                       }()
-
-                       //overwite handler
-                       server, err := initServer(addr, nil,
-                               func(server *http.Server, errChan chan<- error) {
-                                       server.ListenAndServe()
-                               })
-                       So(err, ShouldBeNil)
-
-                       var handler StubHandler
-                       server.Handler = &handler
-
-                       ws, _, err := websocket.DefaultDialer.Dial("ws://localhost:8888", nil)
-                       So(err, ShouldEqual, nil)
-                       wsConn := websocketconn.New(ws)
-                       So(wsConn, ShouldNotEqual, nil)
-
-                       wsConn.Write([]byte("Hello"))
-                       b := make([]byte, 6)
-                       n, err := wsConn.Read(b)
-                       So(n, ShouldEqual, 6)
-                       So(b, ShouldResemble, []byte("world!"))
-
-                       wsConn.Close()
-                       server.Close()
-
-               })
-
-       })
-}

_______________________________________________
tor-commits mailing list
tor-commits@lists.torproject.org
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits

Reply via email to