commit e53705519e1aace78da1fad441c49114c90b858b
Author: David Fifield <[email protected]>
Date:   Sat Mar 26 12:05:06 2016 -0700

    Add a copy of websocket-server in the server directory.
    
    From commit ea3ee6b5cf035a9465c02b4a2b38882c9c6a7050 in
    https://git.torproject.org/pluggable-transports/websocket.git
---
 server/server.go | 291 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 291 insertions(+)

diff --git a/server/server.go b/server/server.go
new file mode 100644
index 0000000..6fb5a0b
--- /dev/null
+++ b/server/server.go
@@ -0,0 +1,291 @@
+// Tor websocket server transport plugin.
+//
+// Usage in torrc:
+//     ExtORPort 6669
+//     ServerTransportPlugin websocket exec ./websocket-server --port 9901
+package main
+
+import (
+       "encoding/base64"
+       "errors"
+       "flag"
+       "fmt"
+       "io"
+       "log"
+       "net"
+       "net/http"
+       "os"
+       "os/signal"
+       "sync"
+       "syscall"
+       "time"
+
+       "../websocket"
+
+       "git.torproject.org/pluggable-transports/goptlib.git"
+)
+
// ptMethodName is the pluggable-transport method name this server provides;
// main rejects bind addresses that request any other method.
const ptMethodName = "websocket"

// requestTimeout is used as the http.Server ReadTimeout while the HTTP
// request carrying the WebSocket upgrade is read; webSocketHandler clears the
// connection deadline once the upgrade is done.
const requestTimeout = 10 * time.Second

// maxMessageSize is passed to websocket.Config.MaxMessageSize. It is sized to
// hold a 64 KiB payload even after base64 encoding: base64 emits 4 output
// bytes per 3 input bytes, with the input length rounded up to a multiple of
// 3, so 64 KiB encodes to 4*ceil(65536/3) = 87384 bytes. (The former
// expression 64*1024*4/3 + 1 evaluated to 87382, two bytes too small.)
const maxMessageSize = (64*1024 + 2) / 3 * 4
+
+var logFile = os.Stderr
+
+var ptInfo pt.ServerInfo
+
+// When a connection handler starts, +1 is written to this channel; when it
+// ends, -1 is written.
+var handlerChan = make(chan int)
+
// usage prints command-line help to standard output. It is installed as
// flag.Usage in main.
func usage() {
	fmt.Printf(`Usage: %s [OPTIONS]
WebSocket server pluggable transport for Tor.
Works only as a managed proxy.

  -h, --help   show this help.
  --log FILE   log messages to FILE (default stderr).
  --port PORT  listen on PORT (overrides Tor's requested port).
`, os.Args[0])
}
+
+// An abstraction that makes an underlying WebSocket connection look like an
+// io.ReadWriteCloser. It internally takes care of things like base64 encoding
+// and decoding.
+type webSocketConn struct {
+       Ws         *websocket.WebSocket
+       Base64     bool
+       messageBuf []byte
+}
+
+// Implements io.Reader.
+func (conn *webSocketConn) Read(b []byte) (n int, err error) {
+       for len(conn.messageBuf) == 0 {
+               var m websocket.Message
+               m, err = conn.Ws.ReadMessage()
+               if err != nil {
+                       return
+               }
+               if m.Opcode == 8 {
+                       err = io.EOF
+                       return
+               }
+               if conn.Base64 {
+                       if m.Opcode != 1 {
+                               err = errors.New(fmt.Sprintf("got non-text 
opcode %d with the base64 subprotocol", m.Opcode))
+                               return
+                       }
+                       conn.messageBuf = make([]byte, 
base64.StdEncoding.DecodedLen(len(m.Payload)))
+                       var num int
+                       num, err = base64.StdEncoding.Decode(conn.messageBuf, 
m.Payload)
+                       if err != nil {
+                               return
+                       }
+                       conn.messageBuf = conn.messageBuf[:num]
+               } else {
+                       if m.Opcode != 2 {
+                               err = errors.New(fmt.Sprintf("got non-binary 
opcode %d with no subprotocol", m.Opcode))
+                               return
+                       }
+                       conn.messageBuf = m.Payload
+               }
+       }
+
+       n = copy(b, conn.messageBuf)
+       conn.messageBuf = conn.messageBuf[n:]
+
+       return
+}
+
+// Implements io.Writer.
+func (conn *webSocketConn) Write(b []byte) (n int, err error) {
+       if conn.Base64 {
+               buf := make([]byte, base64.StdEncoding.EncodedLen(len(b)))
+               base64.StdEncoding.Encode(buf, b)
+               err = conn.Ws.WriteMessage(1, buf)
+               if err != nil {
+                       return
+               }
+               n = len(b)
+       } else {
+               err = conn.Ws.WriteMessage(2, b)
+               n = len(b)
+       }
+       return
+}
+
+// Implements io.Closer.
+func (conn *webSocketConn) Close() error {
+       // Ignore any error in trying to write a Close frame.
+       _ = conn.Ws.WriteFrame(8, nil)
+       return conn.Ws.Conn.Close()
+}
+
+// Create a new webSocketConn.
+func newWebSocketConn(ws *websocket.WebSocket) webSocketConn {
+       var conn webSocketConn
+       conn.Ws = ws
+       conn.Base64 = (ws.Subprotocol == "base64")
+       return conn
+}
+
+// Copy from WebSocket to socket and vice versa.
+func proxy(local *net.TCPConn, conn *webSocketConn) {
+       var wg sync.WaitGroup
+       wg.Add(2)
+
+       go func() {
+               _, err := io.Copy(conn, local)
+               if err != nil {
+                       log.Printf("error copying ORPort to WebSocket")
+               }
+               local.CloseRead()
+               conn.Close()
+               wg.Done()
+       }()
+       go func() {
+               _, err := io.Copy(local, conn)
+               if err != nil {
+                       log.Printf("error copying WebSocket to ORPort")
+               }
+               local.CloseWrite()
+               conn.Close()
+               wg.Done()
+       }()
+
+       wg.Wait()
+}
+
+func webSocketHandler(ws *websocket.WebSocket) {
+       // Undo timeouts on HTTP request handling.
+       ws.Conn.SetDeadline(time.Time{})
+       conn := newWebSocketConn(ws)
+       defer conn.Close()
+
+       handlerChan <- 1
+       defer func() {
+               handlerChan <- -1
+       }()
+
+       or, err := pt.DialOr(&ptInfo, ws.Conn.RemoteAddr().String(), 
ptMethodName)
+       if err != nil {
+               log.Printf("Failed to connect to ORPort: " + err.Error())
+               return
+       }
+       defer or.Close()
+
+       proxy(or, &conn)
+}
+
+func startListener(addr *net.TCPAddr) (*net.TCPListener, error) {
+       ln, err := net.ListenTCP("tcp", addr)
+       if err != nil {
+               return nil, err
+       }
+       go func() {
+               defer ln.Close()
+               var config websocket.Config
+               config.Subprotocols = []string{"base64"}
+               config.MaxMessageSize = maxMessageSize
+               s := &http.Server{
+                       Handler:     config.Handler(webSocketHandler),
+                       ReadTimeout: requestTimeout,
+               }
+               err = s.Serve(ln)
+               if err != nil {
+                       log.Printf("http.Serve: " + err.Error())
+               }
+       }()
+       return ln, nil
+}
+
+func main() {
+       var logFilename string
+       var port int
+
+       flag.Usage = usage
+       flag.StringVar(&logFilename, "log", "", "log file to write to")
+       flag.IntVar(&port, "port", 0, "port to listen on if unspecified by Tor")
+       flag.Parse()
+
+       if logFilename != "" {
+               f, err := os.OpenFile(logFilename, 
os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
+               if err != nil {
+                       fmt.Fprintf(os.Stderr, "Can't open log file %q: %s.\n", 
logFilename, err.Error())
+                       os.Exit(1)
+               }
+               log.SetOutput(f)
+       }
+
+       log.SetFlags(log.LstdFlags | log.LUTC)
+       log.Printf("starting")
+       var err error
+       ptInfo, err = pt.ServerSetup(nil)
+       if err != nil {
+               log.Printf("error in setup: %s", err)
+               os.Exit(1)
+       }
+
+       listeners := make([]*net.TCPListener, 0)
+       for _, bindaddr := range ptInfo.Bindaddrs {
+               // Override tor's requested port (which is 0 if this transport
+               // has not been run before) with the one requested by the --port
+               // option.
+               if port != 0 {
+                       bindaddr.Addr.Port = port
+               }
+
+               switch bindaddr.MethodName {
+               case ptMethodName:
+                       ln, err := startListener(bindaddr.Addr)
+                       if err != nil {
+                               pt.SmethodError(bindaddr.MethodName, 
err.Error())
+                               break
+                       }
+                       pt.Smethod(bindaddr.MethodName, ln.Addr())
+                       log.Printf("listening on %s", ln.Addr().String())
+                       listeners = append(listeners, ln)
+               default:
+                       pt.SmethodError(bindaddr.MethodName, "no such method")
+               }
+       }
+       pt.SmethodsDone()
+
+       var numHandlers int = 0
+       var sig os.Signal
+       sigChan := make(chan os.Signal, 1)
+       signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
+
+       // wait for first signal
+       sig = nil
+       for sig == nil {
+               select {
+               case n := <-handlerChan:
+                       numHandlers += n
+               case sig = <-sigChan:
+               }
+       }
+       log.Printf("Got first signal %q with %d running handlers.", sig, 
numHandlers)
+       for _, ln := range listeners {
+               ln.Close()
+       }
+
+       if sig == syscall.SIGTERM {
+               log.Printf("Caught signal %q, exiting.", sig)
+               return
+       }
+
+       // wait for second signal or no more handlers
+       sig = nil
+       for sig == nil && numHandlers != 0 {
+               select {
+               case n := <-handlerChan:
+                       numHandlers += n
+                       log.Printf("%d remaining handlers.", numHandlers)
+               case sig = <-sigChan:
+               }
+       }
+       if sig != nil {
+               log.Printf("Got second signal %q with %d running handlers.", 
sig, numHandlers)
+       }
+}



_______________________________________________
tor-commits mailing list
[email protected]
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits

Reply via email to