commit 2bf0e5457e88f54344f895f4754cf98c05c1d7b0
Author: Serene Han <[email protected]>
Date:   Tue Jun 14 17:07:21 2016 -0700

    pull copyLoop out of goroutine, better pop and reset
---
 client/client_test.go | 38 +++++++++++++++++++++++++++++---------
 client/interfaces.go  |  7 +++----
 client/peers.go       | 30 +++++++++++++++++++-----------
 client/snowflake.go   | 21 ++++++++++++---------
 client/webrtc.go      |  9 +++++++--
 5 files changed, 70 insertions(+), 35 deletions(-)

diff --git a/client/client_test.go b/client/client_test.go
index 41d4870..b5236a0 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -65,9 +65,9 @@ func (f FakeSocksConn) Grant(addr *net.TCPAddr) error { return nil }
 
 type FakePeers struct{ toRelease *webRTCConn }
 
-func (f FakePeers) Collect() error          { return nil }
-func (f FakePeers) Pop() Snowflake          { return nil }
-func (f FakePeers) Melted() <-chan struct{} { return nil }
+func (f FakePeers) Collect() (Snowflake, error) { return &webRTCConn{}, nil }
+func (f FakePeers) Pop() Snowflake              { return nil }
+func (f FakePeers) Melted() <-chan struct{}     { return nil }
 
 func TestSnowflakeClient(t *testing.T) {
 
@@ -81,16 +81,16 @@ func TestSnowflakeClient(t *testing.T) {
 
                Convey("Collecting a Snowflake requires a Tongue.", func() {
                        p := NewPeers(1)
-                       err := p.Collect()
+                       _, err := p.Collect()
                        So(err, ShouldNotBeNil)
                        So(p.Count(), ShouldEqual, 0)
                        // Set the dialer so that collection is possible.
                        p.Tongue = FakeDialer{}
-                       err = p.Collect()
+                       _, err = p.Collect()
                        So(err, ShouldBeNil)
                        So(p.Count(), ShouldEqual, 1)
                        // S
-                       err = p.Collect()
+                       _, err = p.Collect()
                })
 
                Convey("Collection continues until capacity.", func() {
@@ -100,13 +100,13 @@ func TestSnowflakeClient(t *testing.T) {
                        // Fill up to capacity.
                        for i := 0; i < c; i++ {
                                fmt.Println("Adding snowflake ", i)
-                               err := p.Collect()
+                               _, err := p.Collect()
                                So(err, ShouldBeNil)
                                So(p.Count(), ShouldEqual, i+1)
                        }
                        // But adding another gives an error.
                        So(p.Count(), ShouldEqual, c)
-                       err := p.Collect()
+                       _, err := p.Collect()
                        So(err, ShouldNotBeNil)
                        So(p.Count(), ShouldEqual, c)
 
@@ -116,7 +116,7 @@ func TestSnowflakeClient(t *testing.T) {
                        So(s, ShouldNotBeNil)
                        So(p.Count(), ShouldEqual, c-1)
 
-                       err = p.Collect()
+                       _, err = p.Collect()
                        So(err, ShouldBeNil)
                        So(p.Count(), ShouldEqual, c)
                })
@@ -149,6 +149,26 @@ func TestSnowflakeClient(t *testing.T) {
                        So(p.Count(), ShouldEqual, 0)
                })
 
+               Convey("Pop skips over closed peers.", func() {
+                       p := NewPeers(4)
+                       p.Tongue = FakeDialer{}
+                       wc1, _ := p.Collect()
+                       wc2, _ := p.Collect()
+                       wc3, _ := p.Collect()
+                       So(wc1, ShouldNotBeNil)
+                       So(wc2, ShouldNotBeNil)
+                       So(wc3, ShouldNotBeNil)
+                       wc1.Close()
+                       r := p.Pop()
+                       So(p.Count(), ShouldEqual, 2)
+                       So(r, ShouldEqual, wc2)
+                       wc4, _ := p.Collect()
+                       wc2.Close()
+                       wc3.Close()
+                       r = p.Pop()
+                       So(r, ShouldEqual, wc4)
+               })
+
        })
 
        Convey("Snowflake", t, func() {
diff --git a/client/interfaces.go b/client/interfaces.go
index 502eefb..f18987a 100644
--- a/client/interfaces.go
+++ b/client/interfaces.go
@@ -17,10 +17,9 @@ type Resetter interface {
 // Interface for a single remote WebRTC peer.
 // In the Client context, "Snowflake" refers to the remote browser proxy.
 type Snowflake interface {
-       io.ReadWriter
+       io.ReadWriteCloser
        Resetter
        Connector
-       Close() error
 }
 
 // Interface for catching Snowflakes. (aka the remote dialer)
@@ -34,7 +33,7 @@ type SnowflakeCollector interface {
 
        // Add a Snowflake to the collection.
        // Implementation should decide how to connect and maintain the webRTCConn.
-       Collect() error
+       Collect() (Snowflake, error)
 
        // Remove and return the most available Snowflake from the collection.
        Pop() Snowflake
@@ -52,6 +51,6 @@ type SocksConnector interface {
 
 // Interface for the Snowflake's transport. (Typically just webrtc.DataChannel)
 type SnowflakeDataChannel interface {
+       io.Closer
        Send([]byte)
-       Close() error
 }
diff --git a/client/peers.go b/client/peers.go
index 74b804f..098dd81 100644
--- a/client/peers.go
+++ b/client/peers.go
@@ -40,34 +40,43 @@ func NewPeers(max int) *Peers {
 }
 
 // As part of |SnowflakeCollector| interface.
-func (p *Peers) Collect() error {
+func (p *Peers) Collect() (Snowflake, error) {
        cnt := p.Count()
        s := fmt.Sprintf("Currently at [%d/%d]", cnt, p.capacity)
        if cnt >= p.capacity {
                s := fmt.Sprintf("At capacity [%d/%d]", cnt, p.capacity)
-               return errors.New(s)
+               return nil, errors.New(s)
        }
        log.Println("WebRTC: Collecting a new Snowflake.", s)
        // Engage the Snowflake Catching interface, which must be available.
        if nil == p.Tongue {
-               return errors.New("Missing Tongue to catch Snowflakes with.")
+               return nil, errors.New("Missing Tongue to catch Snowflakes with.")
        }
+       // BUG: some broker conflict here.
        connection, err := p.Tongue.Catch()
-       if nil == connection || nil != err {
-               return err
+       if nil != err {
+               return nil, err
        }
        // Track new valid Snowflake in internal collection and pass along.
        p.activePeers.PushBack(connection)
        p.snowflakeChan <- connection
-       return nil
+       return connection, nil
 }
 
 // As part of |SnowflakeCollector| interface.
 func (p *Peers) Pop() Snowflake {
-       // Blocks until an available snowflake appears.
-       snowflake, ok := <-p.snowflakeChan
-       if !ok {
-               return nil
+       // Blocks until an available, valid snowflake appears.
+       var snowflake Snowflake
+       var ok bool
+       for nil == snowflake {
+               snowflake, ok = <-p.snowflakeChan
+               conn := snowflake.(*webRTCConn)
+               if !ok {
+                       return nil
+               }
+               if conn.closed {
+                       snowflake = nil
+               }
        }
        // Set to use the same rate-limited traffic logger to keep consistency.
        snowflake.(*webRTCConn).BytesLogger = p.BytesLogger
@@ -105,7 +114,6 @@ func (p *Peers) End() {
        p.melt <- struct{}{}
        cnt := p.Count()
        for e := p.activePeers.Front(); e != nil; {
-               log.Println(e, e.Value)
                next := e.Next()
                conn := e.Value.(*webRTCConn)
                conn.Close()
diff --git a/client/snowflake.go b/client/snowflake.go
index d122494..f172611 100644
--- a/client/snowflake.go
+++ b/client/snowflake.go
@@ -31,7 +31,7 @@ var handlerChan = make(chan int)
 func ConnectLoop(snowflakes SnowflakeCollector) {
        for {
                // Check if ending is necessary.
-               err := snowflakes.Collect()
+               _, err := snowflakes.Collect()
                if nil != err {
                        log.Println("WebRTC:", err,
                                " Retrying in", ReconnectTimeout, "seconds...")
@@ -51,6 +51,7 @@ func socksAcceptLoop(ln *pt.SocksListener, snowflakes SnowflakeCollector) error
        defer ln.Close()
        log.Println("Started SOCKS listener.")
        for {
+               log.Println("SOCKS listening...")
                conn, err := ln.AcceptSocks()
                log.Println("SOCKS accepted: ", conn.Req)
                if err != nil {
@@ -81,20 +82,22 @@ func handler(socks SocksConnector, snowflakes SnowflakeCollector) error {
                return errors.New("handler: Received invalid Snowflake")
        }
        defer socks.Close()
-       log.Println("---- Snowflake assigned ----")
+       log.Println("---- Handler: snowflake assigned ----")
        err := socks.Grant(&net.TCPAddr{IP: net.IPv4zero, Port: 0})
        if err != nil {
                return err
        }
 
-       // Begin exchanging data.
-       // BUG(serene): There's a leak here when multiplexed.
-       go copyLoop(socks, snowflake)
+       go func() {
+               // When WebRTC resets, close the SOCKS connection, which ends
+               // the copyLoop below and induces new handler.
+               snowflake.WaitForReset()
+               socks.Close()
+       }()
 
-       // When WebRTC resets, close the SOCKS connection, which induces new handler.
-       // TODO: Double check this / fix it.
-       snowflake.WaitForReset()
-       log.Println("---- Closed ---")
+       // Begin exchanging data.
+       copyLoop(socks, snowflake)
+       log.Println("---- Handler: closed ---")
        return nil
 }
 
diff --git a/client/webrtc.go b/client/webrtc.go
index 6cd5da6..4c7a3c8 100644
--- a/client/webrtc.go
+++ b/client/webrtc.go
@@ -31,7 +31,7 @@ type webRTCConn struct {
        BytesLogger
 }
 
-// Read bytes from remote WebRTC.
+// Read bytes from local SOCKS.
 // As part of |io.ReadWriter|
 func (c *webRTCConn) Read(b []byte) (int, error) {
        return c.recvPipe.Read(b)
@@ -62,11 +62,11 @@ func (c *webRTCConn) Close() error {
 
 // As part of |Resetter|
 func (c *webRTCConn) Reset() {
+       c.Close()
        go func() {
                c.reset <- struct{}{}
                log.Println("WebRTC resetting...")
        }()
-       c.Close()
 }
 
 // As part of |Resetter|
@@ -282,6 +282,11 @@ func (c *webRTCConn) cleanup() {
        if nil != c.errorChannel {
                close(c.errorChannel)
        }
+       // Close this side of the SOCKS pipe.
+       if nil != c.writePipe {
+               c.writePipe.Close()
+               c.writePipe = nil
+       }
        if nil != c.transport {
                log.Printf("WebRTC: closing DataChannel")
                dataChannel := c.transport

_______________________________________________
tor-commits mailing list
[email protected]
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits

Reply via email to