This is an automated email from the ASF dual-hosted git repository.

dcelasun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new 397645a  THRIFT-5069: Make TDeserializer resource pool friendly
397645a is described below

commit 397645ac24874b6f54d88b2700e56be090753825
Author: Yuxuan 'fishy' Wang <[email protected]>
AuthorDate: Sat Jan 18 12:55:51 2020 -0800

    THRIFT-5069: Make TDeserializer resource pool friendly
    
    Client: go
    
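    This change improves performance when using TDeserializer with a
    resource pool. TDeserializer.Transport is now typed *TMemoryBuffer, and
    Read and ReadString reset it before writing their input, so a reused
    TDeserializer always starts from an empty buffer. See
    https://issues.apache.org/jira/browse/THRIFT-5069 for more context.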
    
    Also add TSerializerPool and TDeserializerPool, which are thread-safe
    versions of TSerializer and TDeserializer. Benchmark results show that
    they are both faster and use less memory than creating a new
    TSerializer/TDeserializer for every call (the "plain" case below):
    
        $ go test -bench Serializer -benchmem
        goos: darwin
        goarch: amd64
        BenchmarkSerializer/baseline-8            577558              1930 ns/op             512 B/op          6 allocs/op
        BenchmarkSerializer/plain-8               452712              2638 ns/op            2976 B/op         16 allocs/op
        BenchmarkSerializer/pool-8                591698              2032 ns/op             512 B/op          6 allocs/op
        PASS
---
 CHANGES.md                       |   6 +
 lib/go/thrift/deserializer.go    |  46 +++++-
 lib/go/thrift/serializer.go      |  34 ++++
 lib/go/thrift/serializer_test.go | 325 ++++++++++++++++++++++++++++++---------
 4 files changed, 332 insertions(+), 79 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index e179a63..1dddab9 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -7,10 +7,16 @@
 - [THRIFT-4990](https://issues.apache.org/jira/browse/THRIFT-4990) - Upgrade to .NET Core 3.1 (LTS)
 - [THRIFT-4981](https://issues.apache.org/jira/browse/THRIFT-4981) - Remove deprecated netcore bindings from the code base
 - [THRIFT-5006](https://issues.apache.org/jira/browse/THRIFT-5006) - Implement DEFAULT_MAX_LENGTH at TFramedTransport
+- [THRIFT-5069](https://issues.apache.org/jira/browse/THRIFT-5069) - In the Go library, TDeserializer.Transport is now of type \*TMemoryBuffer instead of TTransport
 
 ### Java
 
 - [THRIFT-5022](https://issues.apache.org/jira/browse/THRIFT-5022) - TIOStreamTransport.isOpen returns true for one-sided transports (see THRIFT-2530).
+
+### Go
+
+- [THRIFT-5069](https://issues.apache.org/jira/browse/THRIFT-5069) - Add TSerializerPool and TDeserializerPool, which are thread-safe versions of TSerializer and TDeserializer.
+
 ## 0.13.0
 
 ### New Languages
diff --git a/lib/go/thrift/deserializer.go b/lib/go/thrift/deserializer.go
index 91a0983..2ab8214 100644
--- a/lib/go/thrift/deserializer.go
+++ b/lib/go/thrift/deserializer.go
@@ -19,14 +19,17 @@
 
 package thrift
 
+import (
+       "sync"
+)
+
 type TDeserializer struct {
-       Transport TTransport
+       Transport *TMemoryBuffer
        Protocol  TProtocol
 }
 
 func NewTDeserializer() *TDeserializer {
-       var transport TTransport
-       transport = NewTMemoryBufferLen(1024)
+       transport := NewTMemoryBufferLen(1024)
 
        protocol := NewTBinaryProtocolFactoryDefault().GetProtocol(transport)
 
@@ -36,6 +39,8 @@ func NewTDeserializer() *TDeserializer {
 }
 
 func (t *TDeserializer) ReadString(msg TStruct, s string) (err error) {
+       t.Transport.Reset()
+
        err = nil
        if _, err = t.Transport.Write([]byte(s)); err != nil {
                return
@@ -47,6 +52,8 @@ func (t *TDeserializer) ReadString(msg TStruct, s string) 
(err error) {
 }
 
 func (t *TDeserializer) Read(msg TStruct, b []byte) (err error) {
+       t.Transport.Reset()
+
        err = nil
        if _, err = t.Transport.Write(b); err != nil {
                return
@@ -56,3 +63,36 @@ func (t *TDeserializer) Read(msg TStruct, b []byte) (err 
error) {
        }
        return
 }
+
+// TDeserializerPool is the thread-safe version of TDeserializer,
+// it uses resource pool of TDeserializer under the hood.
+//
+// It must be initialized with NewTDeserializerPool.
+type TDeserializerPool struct {
+       pool sync.Pool
+}
+
+// NewTDeserializerPool creates a new TDeserializerPool.
+//
+// NewTDeserializer can be used as the arg here.
+func NewTDeserializerPool(f func() *TDeserializer) *TDeserializerPool {
+       return &TDeserializerPool{
+               pool: sync.Pool{
+                       New: func() interface{} {
+                               return f()
+                       },
+               },
+       }
+}
+
+func (t *TDeserializerPool) ReadString(msg TStruct, s string) error {
+       d := t.pool.Get().(*TDeserializer)
+       defer t.pool.Put(d)
+       return d.ReadString(msg, s)
+}
+
+func (t *TDeserializerPool) Read(msg TStruct, b []byte) error {
+       d := t.pool.Get().(*TDeserializer)
+       defer t.pool.Put(d)
+       return d.Read(msg, b)
+}
diff --git a/lib/go/thrift/serializer.go b/lib/go/thrift/serializer.go
index 1ff4d37..d85d204 100644
--- a/lib/go/thrift/serializer.go
+++ b/lib/go/thrift/serializer.go
@@ -21,6 +21,7 @@ package thrift
 
 import (
        "context"
+       "sync"
 )
 
 type TSerializer struct {
@@ -77,3 +78,36 @@ func (t *TSerializer) Write(ctx context.Context, msg 
TStruct) (b []byte, err err
        b = append(b, t.Transport.Bytes()...)
        return
 }
+
+// TSerializerPool is the thread-safe version of TSerializer, it uses resource
+// pool of TSerializer under the hood.
+//
+// It must be initialized with NewTSerializerPool.
+type TSerializerPool struct {
+       pool sync.Pool
+}
+
+// NewTSerializerPool creates a new TSerializerPool.
+//
+// NewTSerializer can be used as the arg here.
+func NewTSerializerPool(f func() *TSerializer) *TSerializerPool {
+       return &TSerializerPool{
+               pool: sync.Pool{
+                       New: func() interface{} {
+                               return f()
+                       },
+               },
+       }
+}
+
+func (t *TSerializerPool) WriteString(ctx context.Context, msg TStruct) 
(string, error) {
+       s := t.pool.Get().(*TSerializer)
+       defer t.pool.Put(s)
+       return s.WriteString(ctx, msg)
+}
+
+func (t *TSerializerPool) Write(ctx context.Context, msg TStruct) ([]byte, 
error) {
+       s := t.pool.Get().(*TSerializer)
+       defer t.pool.Put(s)
+       return s.Write(ctx, msg)
+}
diff --git a/lib/go/thrift/serializer_test.go b/lib/go/thrift/serializer_test.go
index 32227ef..52ebdca 100644
--- a/lib/go/thrift/serializer_test.go
+++ b/lib/go/thrift/serializer_test.go
@@ -23,122 +23,193 @@ import (
        "context"
        "errors"
        "fmt"
+       "sync"
+       "sync/atomic"
        "testing"
+       "testing/quick"
 )
 
 type ProtocolFactory interface {
        GetProtocol(t TTransport) TProtocol
 }
 
-func compareStructs(m, m1 MyTestStruct) (bool, error) {
+func compareStructs(m, m1 MyTestStruct) error {
        switch {
        case m.On != m1.On:
-               return false, errors.New("Boolean not equal")
+               return errors.New("Boolean not equal")
        case m.B != m1.B:
-               return false, errors.New("Byte not equal")
+               return errors.New("Byte not equal")
        case m.Int16 != m1.Int16:
-               return false, errors.New("Int16 not equal")
+               return errors.New("Int16 not equal")
        case m.Int32 != m1.Int32:
-               return false, errors.New("Int32 not equal")
+               return errors.New("Int32 not equal")
        case m.Int64 != m1.Int64:
-               return false, errors.New("Int64 not equal")
+               return errors.New("Int64 not equal")
        case m.D != m1.D:
-               return false, errors.New("Double not equal")
+               return errors.New("Double not equal")
        case m.St != m1.St:
-               return false, errors.New("String not equal")
+               return errors.New("String not equal")
 
        case len(m.Bin) != len(m1.Bin):
-               return false, errors.New("Binary size not equal")
+               return errors.New("Binary size not equal")
        case len(m.Bin) == len(m1.Bin):
                for i := range m.Bin {
                        if m.Bin[i] != m1.Bin[i] {
-                               return false, errors.New("Binary not equal")
+                               return errors.New("Binary not equal")
                        }
                }
        case len(m.StringMap) != len(m1.StringMap):
-               return false, errors.New("StringMap size not equal")
+               return errors.New("StringMap size not equal")
        case len(m.StringList) != len(m1.StringList):
-               return false, errors.New("StringList size not equal")
+               return errors.New("StringList size not equal")
        case len(m.StringSet) != len(m1.StringSet):
-               return false, errors.New("StringSet size not equal")
+               return errors.New("StringSet size not equal")
 
        case m.E != m1.E:
-               return false, errors.New("MyTestEnum not equal")
+               return errors.New("MyTestEnum not equal")
 
        default:
-               return true, nil
+               return nil
 
        }
-       return true, nil
+       return nil
 }
 
-func ProtocolTest1(test *testing.T, pf ProtocolFactory) (bool, error) {
+type serializer interface {
+       WriteString(context.Context, TStruct) (string, error)
+}
+
+type deserializer interface {
+       ReadString(TStruct, string) error
+}
+
+func plainSerializer(pf ProtocolFactory) serializer {
        t := NewTSerializer()
        t.Protocol = pf.GetProtocol(t.Transport)
-       var m = MyTestStruct{}
-       m.On = true
-       m.B = int8(0)
-       m.Int16 = 1
-       m.Int32 = 2
-       m.Int64 = 3
-       m.D = 4.1
-       m.St = "Test"
-       m.Bin = make([]byte, 10)
-       m.StringMap = make(map[string]string, 5)
-       m.StringList = make([]string, 5)
-       m.StringSet = make(map[string]struct{}, 5)
-       m.E = 2
-
-       s, err := t.WriteString(context.Background(), &m)
-       if err != nil {
-               return false, errors.New(fmt.Sprintf("Unable to Serialize 
struct\n\t %s", err))
-       }
+       return t
+}
 
-       t1 := NewTDeserializer()
-       t1.Protocol = pf.GetProtocol(t1.Transport)
-       var m1 = MyTestStruct{}
-       if err = t1.ReadString(&m1, s); err != nil {
-               return false, errors.New(fmt.Sprintf("Unable to Deserialize 
struct\n\t %s", err))
+func poolSerializer(pf ProtocolFactory) serializer {
+       return NewTSerializerPool(
+               func() *TSerializer {
+                       return plainSerializer(pf).(*TSerializer)
+               },
+       )
+}
 
-       }
+func plainDeserializer(pf ProtocolFactory) deserializer {
+       d := NewTDeserializer()
+       d.Protocol = pf.GetProtocol(d.Transport)
+       return d
+}
 
-       return compareStructs(m, m1)
+func poolDeserializer(pf ProtocolFactory) deserializer {
+       return NewTDeserializerPool(
+               func() *TDeserializer {
+                       return plainDeserializer(pf).(*TDeserializer)
+               },
+       )
+}
 
+type constructors struct {
+       Label        string
+       Serializer   func(pf ProtocolFactory) serializer
+       Deserializer func(pf ProtocolFactory) deserializer
 }
 
-func ProtocolTest2(test *testing.T, pf ProtocolFactory) (bool, error) {
-       t := NewTSerializer()
-       t.Protocol = pf.GetProtocol(t.Transport)
-       var m = MyTestStruct{}
-       m.On = false
-       m.B = int8(0)
-       m.Int16 = 1
-       m.Int32 = 2
-       m.Int64 = 3
-       m.D = 4.1
-       m.St = "Test"
-       m.Bin = make([]byte, 10)
-       m.StringMap = make(map[string]string, 5)
-       m.StringList = make([]string, 5)
-       m.StringSet = make(map[string]struct{}, 5)
-       m.E = 2
-
-       s, err := t.WriteString(context.Background(), &m)
-       if err != nil {
-               return false, errors.New(fmt.Sprintf("Unable to Serialize 
struct\n\t %s", err))
+var implementations = []constructors{
+       {
+               Label:        "plain",
+               Serializer:   plainSerializer,
+               Deserializer: plainDeserializer,
+       },
+       {
+               Label:        "pool",
+               Serializer:   poolSerializer,
+               Deserializer: poolDeserializer,
+       },
+}
 
-       }
+func ProtocolTest1(t *testing.T, pf ProtocolFactory) {
+       for _, impl := range implementations {
+               t.Run(
+                       impl.Label,
+                       func(test *testing.T) {
+                               t := impl.Serializer(pf)
+                               var m = MyTestStruct{}
+                               m.On = true
+                               m.B = int8(0)
+                               m.Int16 = 1
+                               m.Int32 = 2
+                               m.Int64 = 3
+                               m.D = 4.1
+                               m.St = "Test"
+                               m.Bin = make([]byte, 10)
+                               m.StringMap = make(map[string]string, 5)
+                               m.StringList = make([]string, 5)
+                               m.StringSet = make(map[string]struct{}, 5)
+                               m.E = 2
+
+                               s, err := t.WriteString(context.Background(), 
&m)
+                               if err != nil {
+                                       test.Fatalf("Unable to Serialize 
struct: %v", err)
+
+                               }
+
+                               t1 := impl.Deserializer(pf)
+                               var m1 MyTestStruct
+                               if err = t1.ReadString(&m1, s); err != nil {
+                                       test.Fatalf("Unable to Deserialize 
struct: %v", err)
 
-       t1 := NewTDeserializer()
-       t1.Protocol = pf.GetProtocol(t1.Transport)
-       var m1 = MyTestStruct{}
-       if err = t1.ReadString(&m1, s); err != nil {
-               return false, errors.New(fmt.Sprintf("Unable to Deserialize 
struct\n\t %s", err))
+                               }
 
+                               if err := compareStructs(m, m1); err != nil {
+                                       test.Error(err)
+                               }
+                       },
+               )
        }
+}
+
+func ProtocolTest2(t *testing.T, pf ProtocolFactory) {
+       for _, impl := range implementations {
+               t.Run(
+                       impl.Label,
+                       func(test *testing.T) {
+                               t := impl.Serializer(pf)
+                               var m = MyTestStruct{}
+                               m.On = false
+                               m.B = int8(0)
+                               m.Int16 = 1
+                               m.Int32 = 2
+                               m.Int64 = 3
+                               m.D = 4.1
+                               m.St = "Test"
+                               m.Bin = make([]byte, 10)
+                               m.StringMap = make(map[string]string, 5)
+                               m.StringList = make([]string, 5)
+                               m.StringSet = make(map[string]struct{}, 5)
+                               m.E = 2
+
+                               s, err := t.WriteString(context.Background(), 
&m)
+                               if err != nil {
+                                       test.Fatalf("Unable to Serialize 
struct: %v", err)
+
+                               }
 
-       return compareStructs(m, m1)
+                               t1 := impl.Deserializer(pf)
+                               var m1 MyTestStruct
+                               if err = t1.ReadString(&m1, s); err != nil {
+                                       test.Fatalf("Unable to Deserialize 
struct: %v", err)
 
+                               }
+
+                               if err := compareStructs(m, m1); err != nil {
+                                       test.Error(err)
+                               }
+                       },
+               )
+       }
 }
 
 func TestSerializer(t *testing.T) {
@@ -150,21 +221,123 @@ func TestSerializer(t *testing.T) {
        //protocol_factories["SimpleJSON"] = NewTSimpleJSONProtocolFactory() - 
write only, can't be read back by design
        protocol_factories["JSON"] = NewTJSONProtocolFactory()
 
-       var tests map[string]func(*testing.T, ProtocolFactory) (bool, error)
-       tests = make(map[string]func(*testing.T, ProtocolFactory) (bool, error))
+       tests := make(map[string]func(*testing.T, ProtocolFactory))
        tests["Test 1"] = ProtocolTest1
        tests["Test 2"] = ProtocolTest2
        //tests["Test 3"] = ProtocolTest3 // Example of how to add additional 
tests
 
        for name, pf := range protocol_factories {
+               t.Run(
+                       name,
+                       func(t *testing.T) {
+                               for label, f := range tests {
+                                       t.Run(
+                                               label,
+                                               func(t *testing.T) {
+                                                       f(t, pf)
+                                               },
+                                       )
+                               }
+                       },
+               )
+       }
+
+}
 
-               for test, f := range tests {
+func TestSerializerPoolAsync(t *testing.T) {
+       var wg sync.WaitGroup
+       var counter int64
+       s := NewTSerializerPool(NewTSerializer)
+       d := NewTDeserializerPool(NewTDeserializer)
+       f := func(i int64) bool {
+               wg.Add(1)
+               go func() {
+                       defer wg.Done()
+                       t.Run(
+                               fmt.Sprintf("#%d-%d", atomic.AddInt64(&counter, 
1), i),
+                               func(t *testing.T) {
+                                       m := MyTestStruct{
+                                               Int64: i,
+                                       }
+                                       str, err := 
s.WriteString(context.Background(), &m)
+                                       if err != nil {
+                                               t.Fatal("serialize:", err)
+                                       }
+                                       var m1 MyTestStruct
+                                       if err = d.ReadString(&m1, str); err != 
nil {
+                                               t.Fatal("deserialize:", err)
 
-                       if s, err := f(t, pf); !s || err != nil {
-                               t.Errorf("%s Failed for %s protocol\n\t %s", 
test, name, err)
-                       }
+                                       }
 
-               }
+                                       if err := compareStructs(m, m1); err != 
nil {
+                                               t.Error(err)
+                                       }
+                               },
+                       )
+               }()
+               return true
+       }
+       quick.Check(f, nil)
+       wg.Wait()
+}
+
+func BenchmarkSerializer(b *testing.B) {
+       sharedSerializer := NewTSerializer()
+       poolSerializer := NewTSerializerPool(NewTSerializer)
+       sharedDeserializer := NewTDeserializer()
+       poolDeserializer := NewTDeserializerPool(NewTDeserializer)
+
+       cases := []struct {
+               Label        string
+               Serializer   func() serializer
+               Deserializer func() deserializer
+       }{
+               {
+                       // Baseline uses shared plain serializer/deserializer
+                       Label: "baseline",
+                       Serializer: func() serializer {
+                               return sharedSerializer
+                       },
+                       Deserializer: func() deserializer {
+                               return sharedDeserializer
+                       },
+               },
+               {
+                       // Plain creates new serializer/deserializer on every 
run,
+                       // as that's how it's used in real world
+                       Label: "plain",
+                       Serializer: func() serializer {
+                               return NewTSerializer()
+                       },
+                       Deserializer: func() deserializer {
+                               return NewTDeserializer()
+                       },
+               },
+               {
+                       // Pool uses the shared pool serializer/deserializer
+                       Label: "pool",
+                       Serializer: func() serializer {
+                               return poolSerializer
+                       },
+                       Deserializer: func() deserializer {
+                               return poolDeserializer
+                       },
+               },
        }
 
+       for _, c := range cases {
+               b.Run(
+                       c.Label,
+                       func(b *testing.B) {
+                               for i := 0; i < b.N; i++ {
+                                       s := c.Serializer()
+                                       m := MyTestStruct{}
+                                       str, _ := 
s.WriteString(context.Background(), &m)
+                                       var m1 MyTestStruct
+                                       d := c.Deserializer()
+                                       d.ReadString(&m1, str)
+                               }
+                       },
+               )
+       }
 }

Reply via email to