commit 50e4f4fd61596bab254cb34e850c9ae63d82f891
Author: idk <hankhill19...@gmail.com>
Date:   Mon Oct 25 22:51:40 2021 -0400

    Turn the proxy code into a library
    
    Allow other go programs to easily import the snowflake proxy library and
    start/stop a snowflake proxy.
---
 proxy/{ => lib}/proxy-go_test.go |   8 +-
 proxy/{ => lib}/snowflake.go     | 205 ++++++++++++++++++++++-----------------
 proxy/{ => lib}/tokens.go        |   2 +-
 proxy/{ => lib}/tokens_test.go   |   2 +-
 proxy/{ => lib}/util.go          |  18 +++-
 proxy/{ => lib}/webrtcconn.go    |   2 +-
 proxy/main.go                    |  48 +++++++++
 7 files changed, 185 insertions(+), 100 deletions(-)

diff --git a/proxy/proxy-go_test.go b/proxy/lib/proxy-go_test.go
similarity index 98%
rename from proxy/proxy-go_test.go
rename to proxy/lib/proxy-go_test.go
index 6fb5a0b9..af71648 100644
--- a/proxy/proxy-go_test.go
+++ b/proxy/lib/proxy-go_test.go
@@ -1,4 +1,4 @@
-package main
+package snowflake
 
 import (
        "bytes"
@@ -365,7 +365,7 @@ func TestBrokerInteractions(t *testing.T) {
                                b,
                        }
 
-                       sdp := broker.pollOffer(sampleOffer)
+                       sdp := broker.pollOffer(sampleOffer, nil)
                        expectedSDP, _ := strconv.Unquote(sampleSDP)
                        So(sdp.SDP, ShouldResemble, expectedSDP)
                })
@@ -379,7 +379,7 @@ func TestBrokerInteractions(t *testing.T) {
                                b,
                        }
 
-                       sdp := broker.pollOffer(sampleOffer)
+                       sdp := broker.pollOffer(sampleOffer, nil)
                        So(sdp, ShouldBeNil)
                })
                Convey("sends answer to broker", func() {
@@ -478,7 +478,7 @@ func TestUtilityFuncs(t *testing.T) {
        Convey("CopyLoop", t, func() {
                c1, s1 := net.Pipe()
                c2, s2 := net.Pipe()
-               go CopyLoop(s1, s2)
+               go copyLoop(s1, s2, nil)
                go func() {
                        bytes := []byte("Hello!")
                        c1.Write(bytes)
diff --git a/proxy/snowflake.go b/proxy/lib/snowflake.go
similarity index 72%
rename from proxy/snowflake.go
rename to proxy/lib/snowflake.go
index 7d7f9a2..e35eabd 100644
--- a/proxy/snowflake.go
+++ b/proxy/lib/snowflake.go
@@ -1,10 +1,9 @@
-package main
+package snowflake
 
 import (
        "bytes"
        "crypto/rand"
        "encoding/base64"
-       "flag"
        "fmt"
        "io"
        "io/ioutil"
@@ -12,27 +11,44 @@ import (
        "net"
        "net/http"
        "net/url"
-       "os"
        "strings"
        "sync"
        "time"
 
        "git.torproject.org/pluggable-transports/snowflake.git/common/messages"
-       "git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
        "git.torproject.org/pluggable-transports/snowflake.git/common/util"
        
"git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn"
        "github.com/gorilla/websocket"
        "github.com/pion/webrtc/v3"
 )
 
-const defaultBrokerURL = "https://snowflake-broker.torproject.net/";
-const defaultProbeURL = "https://snowflake-broker.torproject.net:8443/probe";
-const defaultRelayURL = "wss://snowflake.torproject.net/"
-const defaultSTUNURL = "stun:stun.stunprotocol.org:3478"
+// DefaultBrokerURL is the bamsoftware.com broker, https://snowflake-broker.bamsoftware.com
+// Changing this will change the default broker. The recommended way of changing
+// the broker that gets used is by setting RawBrokerURL on a SnowflakeProxy.
+const DefaultBrokerURL = "https://snowflake-broker.bamsoftware.com/";
+
+// DefaultProbeURL is the torproject.org probe URL, https://snowflake-broker.torproject.net:8443/probe
+// Changing this will change the default Probe URL. SnowflakeProxy has no field
+// to override it: Start always passes this constant to checkNATType.
+const DefaultProbeURL = "https://snowflake-broker.torproject.net:8443/probe";
+
+// DefaultRelayURL is the bamsoftware.com Websocket Relay, wss://snowflake.bamsoftware.com/
+// Changing this will change the default Relay URL. The recommended way of changing
+// the relay that gets used is by setting RelayURL on a SnowflakeProxy.
+const DefaultRelayURL = "wss://snowflake.bamsoftware.com/"
+
+// DefaultSTUNURL is a stunprotocol.org STUN URL, stun:stun.stunprotocol.org:3478
+// Changing this will change the default STUN URL. The recommended way of changing
+// the STUN Server that gets used is by setting StunURL on a SnowflakeProxy.
+const DefaultSTUNURL = "stun:stun.stunprotocol.org:3478"
 const pollInterval = 5 * time.Second
+
 const (
-       NATUnknown      = "unknown"
-       NATRestricted   = "restricted"
+       // NATUnknown represents a NAT type which is unknown.
+       NATUnknown = "unknown"
+       // NATRestricted represents a restricted NAT.
+       NATRestricted = "restricted"
+       // NATUnrestricted represents an unrestricted NAT.
        NATUnrestricted = "unrestricted"
 )
 
@@ -43,7 +59,6 @@ const dataChannelTimeout = 20 * time.Second
 const readLimit = 100000 //Maximum number of bytes to be read from an HTTP 
request
 
 var broker *SignalingServer
-var relayURL string
 
 var currentNATType = NATUnknown
 
@@ -57,6 +72,18 @@ var (
        client http.Client
 )
 
+// SnowflakeProxy is a structure which is used to configure an embedded
+// Snowflake in another Go application.
+type SnowflakeProxy struct {
+       Capacity           uint
+       StunURL            string
+       RawBrokerURL       string
+       KeepLocalAddresses bool
+       RelayURL           string
+       LogOutput          io.Writer
+       shutdown           chan struct{}
+}
+
 // Checks whether an IP address is a remote address for the client
 func isRemoteAddress(ip net.IP) bool {
        return !(util.IsLocal(ip) || ip.IsUnspecified() || ip.IsLoopback())
@@ -81,6 +108,7 @@ func limitedRead(r io.Reader, limit int64) ([]byte, error) {
        return p, err
 }
 
+// SignalingServer holds the URL and HTTP transport used to reach a signaling server (the broker or the probe)
 type SignalingServer struct {
        url                *url.URL
        transport          http.RoundTripper
@@ -102,6 +130,7 @@ func newSignalingServer(rawURL string, keepLocalAddresses bool) (*SignalingServe
        return s, nil
 }
 
+// Post sends a POST request to the SignalingServer
 func (s *SignalingServer) Post(path string, payload io.Reader) ([]byte, error) 
{
 
        req, err := http.NewRequest("POST", path, payload)
@@ -121,7 +150,7 @@ func (s *SignalingServer) Post(path string, payload io.Reader) ([]byte, error) {
        return limitedRead(resp.Body, readLimit)
 }
 
-func (s *SignalingServer) pollOffer(sid string) *webrtc.SessionDescription {
+func (s *SignalingServer) pollOffer(sid string, shutdown chan struct{}) 
*webrtc.SessionDescription {
        brokerPath := s.url.ResolveReference(&url.URL{Path: "proxy"})
 
        ticker := time.NewTicker(pollInterval)
@@ -129,31 +158,36 @@ func (s *SignalingServer) pollOffer(sid string) *webrtc.SessionDescription {
 
        // Run the loop once before hitting the ticker
        for ; true; <-ticker.C {
-               numClients := int((tokens.count() / 8) * 8) // Round down to 8
-               body, err := messages.EncodePollRequest(sid, "standalone", 
currentNATType, numClients)
-               if err != nil {
-                       log.Printf("Error encoding poll message: %s", 
err.Error())
+               select {
+               case <-shutdown:
                        return nil
-               }
-               resp, err := s.Post(brokerPath.String(), bytes.NewBuffer(body))
-               if err != nil {
-                       log.Printf("error polling broker: %s", err.Error())
-               }
+               default:
+                       numClients := int((tokens.count() / 8) * 8) // Round 
down to 8
+                       body, err := messages.EncodePollRequest(sid, 
"standalone", currentNATType, numClients)
+                       if err != nil {
+                               log.Printf("Error encoding poll message: %s", 
err.Error())
+                               return nil
+                       }
+                       resp, err := s.Post(brokerPath.String(), 
bytes.NewBuffer(body))
+                       if err != nil {
+                               log.Printf("error polling broker: %s", 
err.Error())
+                       }
 
-               offer, _, err := messages.DecodePollResponse(resp)
-               if err != nil {
-                       log.Printf("Error reading broker response: %s", 
err.Error())
-                       log.Printf("body: %s", resp)
-                       return nil
-               }
-               if offer != "" {
-                       offer, err := util.DeserializeSessionDescription(offer)
+                       offer, _, err := messages.DecodePollResponse(resp)
                        if err != nil {
-                               log.Printf("Error processing session 
description: %s", err.Error())
+                               log.Printf("Error reading broker response: %s", 
err.Error())
+                               log.Printf("body: %s", resp)
                                return nil
                        }
-                       return offer
+                       if offer != "" {
+                               offer, err := 
util.DeserializeSessionDescription(offer)
+                               if err != nil {
+                                       log.Printf("Error processing session 
description: %s", err.Error())
+                                       return nil
+                               }
+                               return offer
 
+                       }
                }
        }
        return nil
@@ -192,33 +226,41 @@ func (s *SignalingServer) sendAnswer(sid string, pc *webrtc.PeerConnection) erro
        return nil
 }
 
-func CopyLoop(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) {
-       var wg sync.WaitGroup
+func copyLoop(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser, shutdown chan 
struct{}) {
+       var once sync.Once
+       defer c2.Close()
+       defer c1.Close()
+       done := make(chan struct{})
        copyer := func(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
-               defer wg.Done()
                // Ignore io.ErrClosedPipe because it is likely caused by the
                // termination of copyer in the other direction.
                if _, err := io.Copy(dst, src); err != nil && err != 
io.ErrClosedPipe {
                        log.Printf("io.Copy inside CopyLoop generated an error: 
%v", err)
                }
-               dst.Close()
-               src.Close()
+               once.Do(func() {
+                       close(done)
+               })
        }
-       wg.Add(2)
+
        go copyer(c1, c2)
        go copyer(c2, c1)
-       wg.Wait()
+
+       select {
+       case <-done:
+       case <-shutdown:
+       }
+       log.Println("copy loop ended")
 }
 
 // We pass conn.RemoteAddr() as an additional parameter, rather than calling
 // conn.RemoteAddr() inside this function, as a workaround for a hang that
 // otherwise occurs inside of conn.pc.RemoteDescription() (called by
 // RemoteAddr). https://bugs.torproject.org/18628#comment:8
-func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) {
+func (sf *SnowflakeProxy) datachannelHandler(conn *webRTCConn, remoteAddr 
net.Addr) {
        defer conn.Close()
        defer tokens.ret()
 
-       u, err := url.Parse(relayURL)
+       u, err := url.Parse(sf.RelayURL)
        if err != nil {
                log.Fatalf("invalid relay url: %s", err)
        }
@@ -241,7 +283,7 @@ func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) {
        wsConn := websocketconn.New(ws)
        log.Printf("connected to relay")
        defer wsConn.Close()
-       CopyLoop(conn, wsConn)
+       copyLoop(conn, wsConn, sf.shutdown)
        log.Printf("datachannelHandler ends")
 }
 
@@ -249,7 +291,7 @@ func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) {
 // candidates is complete and the answer is available in LocalDescription.
 // Installs an OnDataChannel callback that creates a webRTCConn and passes it to
 // datachannelHandler.
-func makePeerConnectionFromOffer(sdp *webrtc.SessionDescription,
+func (sf *SnowflakeProxy) makePeerConnectionFromOffer(sdp 
*webrtc.SessionDescription,
        config webrtc.Configuration,
        dataChan chan struct{},
        handler func(conn *webRTCConn, remoteAddr net.Addr)) 
(*webrtc.PeerConnection, error) {
@@ -333,7 +375,7 @@ func makePeerConnectionFromOffer(sdp *webrtc.SessionDescription,
 
 // Create a new PeerConnection. Blocks until the gathering of ICE
 // candidates is complete and the answer is available in LocalDescription.
-func makeNewPeerConnection(config webrtc.Configuration,
+func (sf *SnowflakeProxy) makeNewPeerConnection(config webrtc.Configuration,
        dataChan chan struct{}) (*webrtc.PeerConnection, error) {
 
        pc, err := webrtc.NewPeerConnection(config)
@@ -383,15 +425,15 @@ func makeNewPeerConnection(config webrtc.Configuration,
        return pc, nil
 }
 
-func runSession(sid string) {
-       offer := broker.pollOffer(sid)
+func (sf *SnowflakeProxy) runSession(sid string) {
+       offer := broker.pollOffer(sid, sf.shutdown)
        if offer == nil {
                log.Printf("bad offer from broker")
                tokens.ret()
                return
        }
        dataChan := make(chan struct{})
-       pc, err := makePeerConnectionFromOffer(offer, config, dataChan, 
datachannelHandler)
+       pc, err := sf.makePeerConnectionFromOffer(offer, config, dataChan, 
sf.datachannelHandler)
        if err != nil {
                log.Printf("error making WebRTC connection: %s", err)
                tokens.ret()
@@ -421,53 +463,28 @@ func runSession(sid string) {
        }
 }
 
-func main() {
-       var capacity uint
-       var stunURL string
-       var logFilename string
-       var rawBrokerURL string
-       var unsafeLogging bool
-       var keepLocalAddresses bool
-
-       flag.UintVar(&capacity, "capacity", 0, "maximum concurrent clients")
-       flag.StringVar(&rawBrokerURL, "broker", defaultBrokerURL, "broker URL")
-       flag.StringVar(&relayURL, "relay", defaultRelayURL, "websocket relay 
URL")
-       flag.StringVar(&stunURL, "stun", defaultSTUNURL, "stun URL")
-       flag.StringVar(&logFilename, "log", "", "log filename")
-       flag.BoolVar(&unsafeLogging, "unsafe-logging", false, "prevent logs 
from being scrubbed")
-       flag.BoolVar(&keepLocalAddresses, "keep-local-addresses", false, "keep 
local LAN address ICE candidates")
-       flag.Parse()
-
-       var logOutput io.Writer = os.Stderr
+// Start runs the Snowflake proxy with the settings held in sf. It polls the
+// broker for clients in a loop and returns after Stop closes the shutdown
+// channel. The CLI in main.go passes the Default* constants as flag defaults.
+func (sf *SnowflakeProxy) Start() {
+
+       sf.shutdown = make(chan struct{})
+
        log.SetFlags(log.LstdFlags | log.LUTC)
-       if logFilename != "" {
-               f, err := os.OpenFile(logFilename, 
os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
-               if err != nil {
-                       log.Fatal(err)
-               }
-               defer f.Close()
-               logOutput = io.MultiWriter(os.Stderr, f)
-       }
-       if unsafeLogging {
-               log.SetOutput(logOutput)
-       } else {
-               // We want to send the log output through our scrubber first
-               log.SetOutput(&safelog.LogScrubber{Output: logOutput})
-       }
 
        log.Println("starting")
 
        var err error
-       broker, err = newSignalingServer(rawBrokerURL, keepLocalAddresses)
+       broker, err = newSignalingServer(sf.RawBrokerURL, sf.KeepLocalAddresses)
        if err != nil {
                log.Fatal(err)
        }
 
-       _, err = url.Parse(stunURL)
+       _, err = url.Parse(sf.StunURL)
        if err != nil {
                log.Fatalf("invalid stun url: %s", err)
        }
-       _, err = url.Parse(relayURL)
+       _, err = url.Parse(sf.RelayURL)
        if err != nil {
                log.Fatalf("invalid relay url: %s", err)
        }
@@ -475,27 +492,37 @@ func main() {
        config = webrtc.Configuration{
                ICEServers: []webrtc.ICEServer{
                        {
-                               URLs: []string{stunURL},
+                               URLs: []string{sf.StunURL},
                        },
                },
        }
-       tokens = newTokens(capacity)
+       tokens = newTokens(sf.Capacity)
 
        // use probetest to determine NAT compatability
-       checkNATType(config, defaultProbeURL)
+       sf.checkNATType(config, DefaultProbeURL)
        log.Printf("NAT type: %s", currentNATType)
 
        ticker := time.NewTicker(pollInterval)
        defer ticker.Stop()
 
        for ; true; <-ticker.C {
-               tokens.get()
-               sessionID := genSessionID()
-               runSession(sessionID)
+               select {
+               case <-sf.shutdown:
+                       return
+               default:
+                       tokens.get()
+                       sessionID := genSessionID()
+                       sf.runSession(sessionID)
+               }
        }
 }
 
-func checkNATType(config webrtc.Configuration, probeURL string) {
+// Stop calls close on the sf.shutdown channel shutting down the Snowflake.
+func (sf *SnowflakeProxy) Stop() {
+       close(sf.shutdown)
+}
+
+func (sf *SnowflakeProxy) checkNATType(config webrtc.Configuration, probeURL 
string) {
 
        probe, err := newSignalingServer(probeURL, false)
        if err != nil {
@@ -504,7 +531,7 @@ func checkNATType(config webrtc.Configuration, probeURL string) {
 
        // create offer
        dataChan := make(chan struct{})
-       pc, err := makeNewPeerConnection(config, dataChan)
+       pc, err := sf.makeNewPeerConnection(config, dataChan)
        if err != nil {
                log.Printf("error making WebRTC connection: %s", err)
                return
diff --git a/proxy/tokens.go b/proxy/lib/tokens.go
similarity index 97%
rename from proxy/tokens.go
rename to proxy/lib/tokens.go
index fedb8f7..1331778 100644
--- a/proxy/tokens.go
+++ b/proxy/lib/tokens.go
@@ -1,4 +1,4 @@
-package main
+package snowflake
 
 import (
        "sync/atomic"
diff --git a/proxy/tokens_test.go b/proxy/lib/tokens_test.go
similarity index 96%
rename from proxy/tokens_test.go
rename to proxy/lib/tokens_test.go
index 622cc05..702a887 100644
--- a/proxy/tokens_test.go
+++ b/proxy/lib/tokens_test.go
@@ -1,4 +1,4 @@
-package main
+package snowflake
 
 import (
        "testing"
diff --git a/proxy/util.go b/proxy/lib/util.go
similarity index 71%
rename from proxy/util.go
rename to proxy/lib/util.go
index d737056..c6613d9 100644
--- a/proxy/util.go
+++ b/proxy/lib/util.go
@@ -1,21 +1,28 @@
-package main
+package snowflake
 
 import (
        "fmt"
        "time"
 )
 
+// BytesLogger is an interface which is used to allow logging the throughput
+// of the Snowflake. A default BytesLogger(BytesNullLogger) does nothing.
 type BytesLogger interface {
        AddOutbound(int)
        AddInbound(int)
        ThroughputSummary() string
 }
 
-// Default BytesLogger does nothing.
+// BytesNullLogger is the default BytesLogger; it does nothing.
 type BytesNullLogger struct{}
 
-func (b BytesNullLogger) AddOutbound(amount int)    {}
-func (b BytesNullLogger) AddInbound(amount int)     {}
+// AddOutbound in BytesNullLogger does nothing
+func (b BytesNullLogger) AddOutbound(amount int) {}
+
+// AddInbound in BytesNullLogger does nothing
+func (b BytesNullLogger) AddInbound(amount int) {}
+
+// ThroughputSummary in BytesNullLogger returns an empty string
 func (b BytesNullLogger) ThroughputSummary() string { return "" }
 
 // BytesSyncLogger uses channels to safely log from multiple sources with output
@@ -50,14 +57,17 @@ func (b *BytesSyncLogger) log() {
        }
 }
 
+// AddOutbound adds a number of bytes to the outbound total reported by the logger
 func (b *BytesSyncLogger) AddOutbound(amount int) {
        b.outboundChan <- amount
 }
 
+// AddInbound adds a number of bytes to the inbound total reported by the logger
 func (b *BytesSyncLogger) AddInbound(amount int) {
        b.inboundChan <- amount
 }
 
+// ThroughputSummary returns a formatted summary of the throughput totals
 func (b *BytesSyncLogger) ThroughputSummary() string {
        var inUnit, outUnit string
        units := []string{"B", "KB", "MB", "GB"}
diff --git a/proxy/webrtcconn.go b/proxy/lib/webrtcconn.go
similarity index 99%
rename from proxy/webrtcconn.go
rename to proxy/lib/webrtcconn.go
index 5d95919..5c6192b 100644
--- a/proxy/webrtcconn.go
+++ b/proxy/lib/webrtcconn.go
@@ -1,4 +1,4 @@
-package main
+package snowflake
 
 import (
        "fmt"
diff --git a/proxy/main.go b/proxy/main.go
new file mode 100644
index 0000000..12b3752
--- /dev/null
+++ b/proxy/main.go
@@ -0,0 +1,53 @@
+package main
+
+import (
+       "flag"
+       "io"
+       "log"
+       "os"
+
+       "git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
+       "git.torproject.org/pluggable-transports/snowflake.git/proxy/lib"
+)
+
+// main parses the command-line flags, sets up logging (scrubbed unless
+// -unsafe-logging is given), and runs a Snowflake proxy via Start.
+func main() {
+       // flag.Uint rejects negative values at parse time; with flag.Int a negative
+       // capacity would wrap to a huge value when converted to the uint Capacity field.
+       capacity := flag.Uint("capacity", 10, "maximum concurrent clients")
+       stunURL := flag.String("stun", snowflake.DefaultSTUNURL, "stun URL")
+       logFilename := flag.String("log", "", "log filename")
+       rawBrokerURL := flag.String("broker", snowflake.DefaultBrokerURL, "broker URL")
+       unsafeLogging := flag.Bool("unsafe-logging", false, "prevent logs from being scrubbed")
+       keepLocalAddresses := flag.Bool("keep-local-addresses", false, "keep local LAN address ICE candidates")
+       relayURL := flag.String("relay", snowflake.DefaultRelayURL, "websocket relay URL")
+
+       flag.Parse()
+
+       sf := snowflake.SnowflakeProxy{
+               Capacity:           *capacity,
+               StunURL:            *stunURL,
+               RawBrokerURL:       *rawBrokerURL,
+               KeepLocalAddresses: *keepLocalAddresses,
+               RelayURL:           *relayURL,
+               LogOutput:          os.Stderr,
+       }
+
+       if *logFilename != "" {
+               f, err := os.OpenFile(*logFilename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
+               if err != nil {
+                       log.Fatal(err)
+               }
+               defer f.Close()
+               sf.LogOutput = io.MultiWriter(os.Stderr, f)
+       }
+       if *unsafeLogging {
+               log.SetOutput(sf.LogOutput)
+       } else {
+               // Send the log output through the scrubber first.
+               log.SetOutput(&safelog.LogScrubber{Output: sf.LogOutput})
+       }
+
+       sf.Start()
+}

_______________________________________________
tor-commits mailing list
tor-commits@lists.torproject.org
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits

Reply via email to