This is an automated email from the ASF dual-hosted git repository.

yuxuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new d582a86  THRIFT-5490: Use pooled buffer for THeaderTransport
d582a86 is described below

commit d582a861426c43c869e71d8d6ce598a33cbab316
Author: Yuxuan 'fishy' Wang <[email protected]>
AuthorDate: Thu Dec 16 14:44:47 2021 -0800

    THRIFT-5490: Use pooled buffer for THeaderTransport
    
    Client: go
    
    Instead of binding two buffers (one for reads, one for writes) to each
    THeaderTransport, take a buffer from a shared pool for the duration of
    each read or write, and return it to the pool once that read or write
    is done. This helps reduce the memory footprint of idle connections.
---
 lib/go/thrift/buf_pool.go              | 52 ++++++++++++++++++++++++++++++
 lib/go/thrift/header_transport.go      | 42 +++++++++++++-----------
 lib/go/thrift/header_transport_test.go | 59 ++++++++++++++++++++++++++++++++--
 3 files changed, 133 insertions(+), 20 deletions(-)

diff --git a/lib/go/thrift/buf_pool.go b/lib/go/thrift/buf_pool.go
new file mode 100644
index 0000000..9708ea0
--- /dev/null
+++ b/lib/go/thrift/buf_pool.go
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+       "bytes"
+       "sync"
+)
+
+// bufPool backs getBufFromPool/returnBufToPool; when it is empty, its New
+// hook supplies a fresh, empty *bytes.Buffer.
+var bufPool = sync.Pool{
+       New: func() interface{} { return &bytes.Buffer{} },
+}
+
+// getBufFromPool hands out a *bytes.Buffer from bufPool. The buffer is always
+// reset first, so bytes left behind by a previous user are discarded.
+func getBufFromPool() *bytes.Buffer {
+       pooled := bufPool.Get().(*bytes.Buffer)
+       pooled.Reset()
+       return pooled
+}
+
+// returnBufToPool returns *buf to the pool and sets *buf to nil to avoid
+// accidental usage after it's returned, e.g. via defer returnBufToPool(&buf).
+//
+// A nil *buf is ignored: otherwise a double return would put a typed nil
+// into the pool, and a later getBufFromPool would panic calling Reset on it.
+func returnBufToPool(buf **bytes.Buffer) {
+       if *buf == nil {
+               return
+       }
+       bufPool.Put(*buf)
+       *buf = nil
+}
diff --git a/lib/go/thrift/header_transport.go b/lib/go/thrift/header_transport.go
index f5736df..5ec0454 100644
--- a/lib/go/thrift/header_transport.go
+++ b/lib/go/thrift/header_transport.go
@@ -28,7 +28,6 @@ import (
        "errors"
        "fmt"
        "io"
-       "io/ioutil"
 )
 
 // Size in bytes for 32-bit ints.
@@ -253,14 +252,14 @@ type THeaderTransport struct {
        // Reading related variables.
        reader *bufio.Reader
        // When frame is detected, we read the frame fully into frameBuffer.
-       frameBuffer bytes.Buffer
+       frameBuffer *bytes.Buffer
        // When it's non-nil, Read should read from frameReader instead of
        // reader, and EOF error indicates end of frame instead of end of all
        // transport.
        frameReader io.ReadCloser
 
        // Writing related variables
-       writeBuffer     bytes.Buffer
+       writeBuffer     *bytes.Buffer
        writeTransforms []THeaderTransformID
 
        clientType clientType
@@ -370,11 +369,14 @@ func (t *THeaderTransport) ReadFrame(ctx context.Context) error {
        t.reader.Discard(size32)
 
        // Read the frame fully into frameBuffer.
-       _, err = io.CopyN(&t.frameBuffer, t.reader, int64(frameSize))
+       if t.frameBuffer == nil {
+               t.frameBuffer = getBufFromPool()
+       }
+       _, err = io.CopyN(t.frameBuffer, t.reader, int64(frameSize))
        if err != nil {
                return err
        }
-       t.frameReader = ioutil.NopCloser(&t.frameBuffer)
+       t.frameReader = io.NopCloser(t.frameBuffer)
 
        // Peek and handle the next 32 bits.
        buf = t.frameBuffer.Bytes()[:size32]
@@ -405,7 +407,7 @@ func (t *THeaderTransport) ReadFrame(ctx context.Context) error {
 // It closes frameReader, and also resets frame related states.
 func (t *THeaderTransport) endOfFrame() error {
        defer func() {
-               t.frameBuffer.Reset()
+               returnBufToPool(&t.frameBuffer)
                t.frameReader = nil
        }()
        return t.frameReader.Close()
@@ -418,7 +420,7 @@ func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) e
 
        var err error
        var meta headerMeta
-       if err = binary.Read(&t.frameBuffer, binary.BigEndian, &meta); err != 
nil {
+       if err = binary.Read(t.frameBuffer, binary.BigEndian, &meta); err != 
nil {
                return err
        }
        frameSize -= headerMetaSize
@@ -432,7 +434,7 @@ func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) e
                )
        }
        headerBuf := NewTMemoryBuffer()
-       _, err = io.CopyN(headerBuf, &t.frameBuffer, headerLength)
+       _, err = io.CopyN(headerBuf, t.frameBuffer, headerLength)
        if err != nil {
                return err
        }
@@ -454,7 +456,7 @@ func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) e
        }
        if transformCount > 0 {
                reader := NewTransformReaderWithCapacity(
-                       &t.frameBuffer,
+                       t.frameBuffer,
                        int(transformCount),
                )
                t.frameReader = reader
@@ -569,16 +571,19 @@ func (t *THeaderTransport) Read(p []byte) (read int, err error) {
 //
 // You need to call Flush to actually write them to the transport.
 func (t *THeaderTransport) Write(p []byte) (int, error) {
+       if t.writeBuffer == nil {
+               t.writeBuffer = getBufFromPool()
+       }
        return t.writeBuffer.Write(p)
 }
 
 // Flush writes the appropriate header and the write buffer to the underlying 
transport.
 func (t *THeaderTransport) Flush(ctx context.Context) error {
-       if t.writeBuffer.Len() == 0 {
+       if t.writeBuffer == nil || t.writeBuffer.Len() == 0 {
                return nil
        }
 
-       defer t.writeBuffer.Reset()
+       defer returnBufToPool(&t.writeBuffer)
 
        switch t.clientType {
        default:
@@ -628,24 +633,25 @@ func (t *THeaderTransport) Flush(ctx context.Context) error {
                        }
                }
 
-               var payload bytes.Buffer
+               payload := getBufFromPool()
+               defer returnBufToPool(&payload)
                meta := headerMeta{
                        MagicFlags:   THeaderHeaderMagic + 
t.Flags&THeaderFlagsMask,
                        SequenceID:   t.SequenceID,
                        HeaderLength: uint16(headers.Len() / 4),
                }
-               if err := binary.Write(&payload, binary.BigEndian, meta); err 
!= nil {
+               if err := binary.Write(payload, binary.BigEndian, meta); err != 
nil {
                        return NewTTransportExceptionFromError(err)
                }
-               if _, err := io.Copy(&payload, headers); err != nil {
+               if _, err := io.Copy(payload, headers); err != nil {
                        return NewTTransportExceptionFromError(err)
                }
 
-               writer, err := NewTransformWriter(&payload, t.writeTransforms)
+               writer, err := NewTransformWriter(payload, t.writeTransforms)
                if err != nil {
                        return NewTTransportExceptionFromError(err)
                }
-               if _, err := io.Copy(writer, &t.writeBuffer); err != nil {
+               if _, err := io.Copy(writer, t.writeBuffer); err != nil {
                        return NewTTransportExceptionFromError(err)
                }
                if err := writer.Close(); err != nil {
@@ -659,7 +665,7 @@ func (t *THeaderTransport) Flush(ctx context.Context) error {
                        return NewTTransportExceptionFromError(err)
                }
                // Then write the payload
-               if _, err := io.Copy(t.transport, &payload); err != nil {
+               if _, err := io.Copy(t.transport, payload); err != nil {
                        return NewTTransportExceptionFromError(err)
                }
 
@@ -671,7 +677,7 @@ func (t *THeaderTransport) Flush(ctx context.Context) error {
                }
                fallthrough
        case clientUnframedBinary, clientUnframedCompact:
-               if _, err := io.Copy(t.transport, &t.writeBuffer); err != nil {
+               if _, err := io.Copy(t.transport, t.writeBuffer); err != nil {
                        return NewTTransportExceptionFromError(err)
                }
        }
diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go
index 65e69ee..25ba8d3 100644
--- a/lib/go/thrift/header_transport_test.go
+++ b/lib/go/thrift/header_transport_test.go
@@ -23,7 +23,6 @@ import (
        "context"
        "fmt"
        "io"
-       "io/ioutil"
        "strings"
        "testing"
        "testing/quick"
@@ -87,7 +86,7 @@ func testTHeaderHeadersReadWriteProtocolID(t *testing.T, protoID THeaderProtocol
        if err := reader.ReadFrame(context.Background()); err != nil {
                t.Errorf("reader.ReadFrame returned error: %v", err)
        }
-       read, err := ioutil.ReadAll(reader)
+       read, err := io.ReadAll(reader)
        if err != nil {
                t.Errorf("Read returned error: %v", err)
        }
@@ -305,3 +304,59 @@ func TestSetTHeaderTransportProtocolID(t *testing.T) {
                t.Errorf("Expected protocol id %v, got %v", expected, actual)
        }
 }
+
+// TestTHeaderTransportReuseTransport round-trips n frames through a reader and
+// a writer sharing one memory buffer, first interleaved, then batched.
+func TestTHeaderTransportReuseTransport(t *testing.T) {
+       const (
+               content = "Hello, world!"
+               n       = 10
+       )
+       trans := NewTMemoryBuffer()
+       reader := NewTHeaderTransport(trans)
+       writer := NewTHeaderTransport(trans)
+
+       t.Run("pair", func(t *testing.T) {
+               for i := 0; i < n; i++ {
+                       // write one frame
+                       if _, err := io.Copy(writer, strings.NewReader(content)); err != nil {
+                               t.Fatalf("Failed to write on #%d: %v", i, err)
+                       }
+                       if err := writer.Flush(context.Background()); err != nil {
+                               t.Fatalf("Failed to flush on #%d: %v", i, err)
+                       }
+                       // read it back; ReadAll should stop at the end of the frame
+                       read, err := io.ReadAll(reader)
+                       if err != nil {
+                               t.Errorf("Failed to read on #%d: %v", i, err)
+                       }
+                       if string(read) != content {
+                               t.Errorf("Read #%d: want %q, got %q", i, content, read)
+                       }
+               }
+       })
+
+       t.Run("batched", func(t *testing.T) {
+               // write all n frames first
+               for i := 0; i < n; i++ {
+                       if _, err := io.Copy(writer, strings.NewReader(content)); err != nil {
+                               t.Fatalf("Failed to write on #%d: %v", i, err)
+                       }
+                       if err := writer.Flush(context.Background()); err != nil {
+                               t.Fatalf("Failed to flush on #%d: %v", i, err)
+                       }
+               }
+               // then read them back, expecting one whole frame per Read call
+               for i := 0; i < n; i++ {
+                       buf := make([]byte, len(content))
+                       nr, err := reader.Read(buf)
+                       if err != nil {
+                               t.Errorf("Failed to read on #%d: %v", i, err)
+                       }
+                       read := string(buf[:nr])
+                       if read != content {
+                               t.Errorf("Read #%d: want %q, got %q", i, content, read)
+                       }
+               }
+       })
+}

Reply via email to