This is an automated email from the ASF dual-hosted git repository.

yuxuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new 999e6e3  THRIFT-5490: Use pooled buffer for TFramedTransport
999e6e3 is described below

commit 999e6e3bce217acb35b44440fd656cf169d47ed8
Author: Yuxuan 'fishy' Wang <[email protected]>
AuthorDate: Fri Dec 17 10:39:07 2021 -0800

    THRIFT-5490: Use pooled buffer for TFramedTransport
    
    Client: go
    
    Follow up on d582a8614, do the same thing on TFramedTransport.
    
    Also update the test on the implementation of THeaderTransport to make
    sure that small reads are not broken.
---
 lib/go/thrift/framed_transport.go      | 61 ++++++++++++++++++++--------
 lib/go/thrift/framed_transport_test.go | 73 ++++++++++++++++++++++++++++++++++
 lib/go/thrift/header_transport_test.go | 39 +++++++++++++++---
 3 files changed, 150 insertions(+), 23 deletions(-)

diff --git a/lib/go/thrift/framed_transport.go b/lib/go/thrift/framed_transport.go
index 2156dd7..c8bd35e 100644
--- a/lib/go/thrift/framed_transport.go
+++ b/lib/go/thrift/framed_transport.go
@@ -36,10 +36,10 @@ type TFramedTransport struct {
 
        cfg *TConfiguration
 
-       writeBuf bytes.Buffer
+       writeBuf *bytes.Buffer
 
        reader  *bufio.Reader
-       readBuf bytes.Buffer
+       readBuf *bytes.Buffer
 
        buffer [4]byte
 }
@@ -129,18 +129,29 @@ func (p *TFramedTransport) Close() error {
 }
 
 func (p *TFramedTransport) Read(buf []byte) (read int, err error) {
-       read, err = p.readBuf.Read(buf)
-       if err != io.EOF {
-               return
-       }
+       defer func() {
+               // Make sure we return the read buffer back to pool
+               // after we finished reading from it.
+               if p.readBuf != nil && p.readBuf.Len() == 0 {
+                       returnBufToPool(&p.readBuf)
+               }
+       }()
+
+       if p.readBuf != nil {
 
-       // For bytes.Buffer.Read, EOF would only happen when read is zero,
-       // but still, do a sanity check,
-       // in case that behavior is changed in a future version of go stdlib.
-       // When that happens, just return nil error,
-       // and let the caller call Read again to read the next frame.
-       if read > 0 {
-               return read, nil
+               read, err = p.readBuf.Read(buf)
+               if err != io.EOF {
+                       return
+               }
+
+               // For bytes.Buffer.Read, EOF would only happen when read is zero,
+               // but still, do a sanity check,
+               // in case that behavior is changed in a future version of go stdlib.
+               // When that happens, just return nil error,
+               // and let the caller call Read again to read the next frame.
+               if read > 0 {
+                       return read, nil
+               }
        }
 
        // Reaching here means that the last Read finished the last frame,
@@ -162,31 +173,39 @@ func (p *TFramedTransport) ReadByte() (c byte, err error) {
        return
 }
 
+func (p *TFramedTransport) ensureWriteBufferBeforeWrite() {
+       if p.writeBuf == nil {
+               p.writeBuf = getBufFromPool()
+       }
+}
+
 func (p *TFramedTransport) Write(buf []byte) (int, error) {
+       p.ensureWriteBufferBeforeWrite()
        n, err := p.writeBuf.Write(buf)
        return n, NewTTransportExceptionFromError(err)
 }
 
 func (p *TFramedTransport) WriteByte(c byte) error {
+       p.ensureWriteBufferBeforeWrite()
        return p.writeBuf.WriteByte(c)
 }
 
 func (p *TFramedTransport) WriteString(s string) (n int, err error) {
+       p.ensureWriteBufferBeforeWrite()
        return p.writeBuf.WriteString(s)
 }
 
 func (p *TFramedTransport) Flush(ctx context.Context) error {
+       defer returnBufToPool(&p.writeBuf)
        size := p.writeBuf.Len()
        buf := p.buffer[:4]
        binary.BigEndian.PutUint32(buf, uint32(size))
        _, err := p.transport.Write(buf)
        if err != nil {
-               p.writeBuf.Reset()
                return NewTTransportExceptionFromError(err)
        }
        if size > 0 {
-               if _, err := io.Copy(p.transport, &p.writeBuf); err != nil {
-                       p.writeBuf.Reset()
+               if _, err := io.Copy(p.transport, p.writeBuf); err != nil {
                        return NewTTransportExceptionFromError(err)
                }
        }
@@ -195,6 +214,11 @@ func (p *TFramedTransport) Flush(ctx context.Context) error {
 }
 
 func (p *TFramedTransport) readFrame() error {
+       if p.readBuf != nil {
+               returnBufToPool(&p.readBuf)
+       }
+       p.readBuf = getBufFromPool()
+
        buf := p.buffer[:4]
        if _, err := io.ReadFull(p.reader, buf); err != nil {
                return err
@@ -203,11 +227,14 @@ func (p *TFramedTransport) readFrame() error {
        if size > uint32(p.cfg.GetMaxFrameSize()) {
                return NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, 
fmt.Sprintf("Incorrect frame size (%d)", size))
        }
-       _, err := io.CopyN(&p.readBuf, p.reader, int64(size))
+       _, err := io.CopyN(p.readBuf, p.reader, int64(size))
        return NewTTransportExceptionFromError(err)
 }
 
 func (p *TFramedTransport) RemainingBytes() (num_bytes uint64) {
+       if p.readBuf == nil {
+               return 0
+       }
        return uint64(p.readBuf.Len())
 }
 
diff --git a/lib/go/thrift/framed_transport_test.go b/lib/go/thrift/framed_transport_test.go
index 8f683ef..4e7d9ca 100644
--- a/lib/go/thrift/framed_transport_test.go
+++ b/lib/go/thrift/framed_transport_test.go
@@ -20,6 +20,9 @@
 package thrift
 
 import (
+       "context"
+       "io"
+       "strings"
        "testing"
 )
 
@@ -27,3 +30,73 @@ func TestFramedTransport(t *testing.T) {
        trans := NewTFramedTransport(NewTMemoryBuffer())
        TransportTest(t, trans, trans)
 }
+
+func TestTFramedTransportReuseTransport(t *testing.T) {
+       const (
+               content = "Hello, world!"
+               n       = 10
+       )
+       trans := NewTMemoryBuffer()
+       reader := NewTFramedTransport(trans)
+       writer := NewTFramedTransport(trans)
+
+       t.Run("pair", func(t *testing.T) {
+               for i := 0; i < n; i++ {
+                       // write
+                       if _, err := io.Copy(writer, 
strings.NewReader(content)); err != nil {
+                               t.Fatalf("Failed to write on #%d: %v", i, err)
+                       }
+                       if err := writer.Flush(context.Background()); err != 
nil {
+                               t.Fatalf("Failed to flush on #%d: %v", i, err)
+                       }
+
+                       // read
+                       read, err := io.ReadAll(oneAtATimeReader{reader})
+                       if err != nil {
+                               t.Errorf("Failed to read on #%d: %v", i, err)
+                       }
+                       if string(read) != content {
+                               t.Errorf("Read #%d: want %q, got %q", i, 
content, read)
+                       }
+               }
+       })
+
+       t.Run("batched", func(t *testing.T) {
+               // write
+               for i := 0; i < n; i++ {
+                       if _, err := io.Copy(writer, 
strings.NewReader(content)); err != nil {
+                               t.Fatalf("Failed to write on #%d: %v", i, err)
+                       }
+                       if err := writer.Flush(context.Background()); err != 
nil {
+                               t.Fatalf("Failed to flush on #%d: %v", i, err)
+                       }
+               }
+
+               // read
+               for i := 0; i < n; i++ {
+                       const (
+                               size = len(content)
+                       )
+                       var buf []byte
+                       var err error
+                       if i%2 == 0 {
+                               // on even calls, use oneAtATimeReader to make
+                               // sure that small reads are fine
+                               buf, err = 
io.ReadAll(io.LimitReader(oneAtATimeReader{reader}, int64(size)))
+                       } else {
+                               // on odd calls, make sure that we don't read
+                               // more than written per frame
+                               buf = make([]byte, size*2)
+                               var n int
+                               n, err = reader.Read(buf)
+                               buf = buf[:n]
+                       }
+                       if err != nil {
+                               t.Errorf("Failed to read on #%d: %v", i, err)
+                       }
+                       if string(buf) != content {
+                               t.Errorf("Read #%d: want %q, got %q", i, 
content, buf)
+                       }
+               }
+       })
+}
diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go
index 25ba8d3..44d0284 100644
--- a/lib/go/thrift/header_transport_test.go
+++ b/lib/go/thrift/header_transport_test.go
@@ -325,7 +325,7 @@ func TestTHeaderTransportReuseTransport(t *testing.T) {
                        }
 
                        // read
-                       read, err := io.ReadAll(reader)
+                       read, err := io.ReadAll(oneAtATimeReader{reader})
                        if err != nil {
                                t.Errorf("Failed to read on #%d: %v", i, err)
                        }
@@ -348,15 +348,42 @@ func TestTHeaderTransportReuseTransport(t *testing.T) {
 
                // read
                for i := 0; i < n; i++ {
-                       buf := make([]byte, len(content))
-                       n, err := reader.Read(buf)
+                       const (
+                               size = len(content)
+                       )
+                       var buf []byte
+                       var err error
+                       if i%2 == 0 {
+                               // on even calls, use oneAtATimeReader to make
+                               // sure that small reads are fine
+                               buf, err = 
io.ReadAll(io.LimitReader(oneAtATimeReader{reader}, int64(size)))
+                       } else {
+                               // on odd calls, make sure that we don't read
+                               // more than written per frame
+                               buf = make([]byte, size*2)
+                               var n int
+                               n, err = reader.Read(buf)
+                               buf = buf[:n]
+                       }
                        if err != nil {
                                t.Errorf("Failed to read on #%d: %v", i, err)
                        }
-                       read := string(buf[:n])
-                       if string(read) != content {
-                               t.Errorf("Read #%d: want %q, got %q", i, 
content, read)
+                       if string(buf) != content {
+                               t.Errorf("Read #%d: want %q, got %q", i, 
content, buf)
                        }
                }
        })
 }
+
+type oneAtATimeReader struct {
+       io.Reader
+}
+
+// oneAtATimeReader forces every Read call to only read 1 byte out,
+// thus forces the underlying reader's Read to be called multiple times.
+func (o oneAtATimeReader) Read(buf []byte) (int, error) {
+       if len(buf) < 1 {
+               return o.Reader.Read(buf)
+       }
+       return o.Reader.Read(buf[:1])
+}

Reply via email to