This is an automated email from the ASF dual-hosted git repository.

spacewander pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/apisix-go-plugin-runner.git

commit e3d2617b44efc9268ac19a0498aea82c03ed567b
Author: spacewander <[email protected]>
AuthorDate: Thu May 20 17:01:05 2021 +0800

    fix: concurrent issues
---
 internal/plugin/conf.go                    | 19 ++++-------
 internal/plugin/conf_test.go               | 54 ++++++++++++++++++++++++++----
 internal/server/error.go                   | 11 +++---
 internal/server/error_test.go              |  6 ++--
 internal/server/server.go                  | 14 ++++++--
 internal/{server/error.go => util/pool.go} | 36 ++++++++------------
 6 files changed, 87 insertions(+), 53 deletions(-)

diff --git a/internal/plugin/conf.go b/internal/plugin/conf.go
index 6c79cb6..8e18ecc 100644
--- a/internal/plugin/conf.go
+++ b/internal/plugin/conf.go
@@ -16,9 +16,11 @@ package plugin
 
 import (
        "strconv"
+       "sync/atomic"
        "time"
 
        "github.com/ReneKroon/ttlcache/v2"
+       "github.com/apache/apisix-go-plugin-runner/internal/util"
        A6 "github.com/api7/ext-plugin-proto/go/A6"
        pc "github.com/api7/ext-plugin-proto/go/A6/PrepareConf"
        flatbuffers "github.com/google/flatbuffers/go"
@@ -31,8 +33,6 @@ type ConfEntry struct {
 type RuleConf []ConfEntry
 
 var (
-       builder = flatbuffers.NewBuilder(1024)
-
        cache        *ttlcache.Cache
        cacheCounter uint32 = 0
 )
@@ -41,20 +41,14 @@ func InitConfCache(ttl time.Duration) {
        cache = ttlcache.NewCache()
        cache.SetTTL(ttl)
        cache.SkipTTLExtensionOnHit(false)
+       cacheCounter = 0
 }
 
 func genCacheToken() uint32 {
-       cacheCounter++
-       if cacheCounter == 0 {
-               // overflow, skip 0 which means none
-               cacheCounter++
-       }
-       return cacheCounter
+       return atomic.AddUint32(&cacheCounter, 1)
 }
 
-func PrepareConf(buf []byte) ([]byte, error) {
-       builder.Reset()
-
+func PrepareConf(buf []byte) (*flatbuffers.Builder, error) {
        req := pc.GetRootAsReq(buf, 0)
        entries := make(RuleConf, req.ConfLength())
 
@@ -72,11 +66,12 @@ func PrepareConf(buf []byte) ([]byte, error) {
                return nil, err
        }
 
+       builder := util.GetBuilder()
        pc.RespStart(builder)
        pc.RespAddConfToken(builder, token)
        root := pc.RespEnd(builder)
        builder.Finish(root)
-       return builder.FinishedBytes(), nil
+       return builder, nil
 }
 
 func GetRuleConf(token uint32) (RuleConf, error) {
diff --git a/internal/plugin/conf_test.go b/internal/plugin/conf_test.go
index b4135a5..aa2ded4 100644
--- a/internal/plugin/conf_test.go
+++ b/internal/plugin/conf_test.go
@@ -15,6 +15,8 @@
 package plugin
 
 import (
+       "sort"
+       "sync"
        "testing"
        "time"
 
@@ -34,15 +36,52 @@ func TestPrepareConf(t *testing.T) {
        builder.Finish(root)
        b := builder.FinishedBytes()
 
-       out, _ := PrepareConf(b)
+       bd, _ := PrepareConf(b)
+       out := bd.FinishedBytes()
        resp := pc.GetRootAsResp(out, 0)
        assert.Equal(t, uint32(1), resp.ConfToken())
 
-       out, _ = PrepareConf(b)
+       bd, _ = PrepareConf(b)
+       out = bd.FinishedBytes()
        resp = pc.GetRootAsResp(out, 0)
        assert.Equal(t, uint32(2), resp.ConfToken())
 }
 
+func TestPrepareConfConcurrently(t *testing.T) {
+       InitConfCache(10 * time.Millisecond)
+
+       builder := flatbuffers.NewBuilder(1024)
+       pc.ReqStart(builder)
+       root := pc.ReqEnd(builder)
+       builder.Finish(root)
+       b := builder.FinishedBytes()
+
+       n := 10
+       var wg sync.WaitGroup
+       res := make([][]byte, n)
+       for i := 0; i < n; i++ {
+               wg.Add(1)
+               go func(i int) {
+                       bd, err := PrepareConf(b)
+                       assert.Nil(t, err)
+                       res[i] = bd.FinishedBytes()[:]
+                       wg.Done()
+               }(i)
+       }
+       wg.Wait()
+
+       tokens := make([]int, n)
+       for i := 0; i < n; i++ {
+               resp := pc.GetRootAsResp(res[i], 0)
+               tokens[i] = int(resp.ConfToken())
+       }
+
+       sort.Ints(tokens)
+       for i := 0; i < n; i++ {
+               assert.Equal(t, i+1, tokens[i])
+       }
+}
+
 func TestGetRuleConf(t *testing.T) {
        InitConfCache(1 * time.Millisecond)
        builder := flatbuffers.NewBuilder(1024)
@@ -51,15 +90,16 @@ func TestGetRuleConf(t *testing.T) {
        builder.Finish(root)
        b := builder.FinishedBytes()
 
-       out, _ := PrepareConf(b)
+       bd, _ := PrepareConf(b)
+       out := bd.FinishedBytes()
        resp := pc.GetRootAsResp(out, 0)
-       assert.Equal(t, uint32(3), resp.ConfToken())
+       assert.Equal(t, uint32(1), resp.ConfToken())
 
-       res, _ := GetRuleConf(3)
+       res, _ := GetRuleConf(1)
        assert.Equal(t, 0, len(res))
 
        time.Sleep(2 * time.Millisecond)
-       _, err := GetRuleConf(3)
+       _, err := GetRuleConf(1)
        assert.Equal(t, ttlcache.ErrNotFound, err)
 }
 
@@ -85,7 +125,7 @@ func TestGetRuleConfCheckConf(t *testing.T) {
        b := builder.FinishedBytes()
 
        PrepareConf(b)
-       res, _ := GetRuleConf(4)
+       res, _ := GetRuleConf(1)
        assert.Equal(t, 1, len(res))
        assert.Equal(t, "echo", res[0].Name)
 }
diff --git a/internal/server/error.go b/internal/server/error.go
index 8ba517c..81d77c3 100644
--- a/internal/server/error.go
+++ b/internal/server/error.go
@@ -18,15 +18,12 @@ import (
        "github.com/ReneKroon/ttlcache/v2"
        A6Err "github.com/api7/ext-plugin-proto/go/A6/Err"
        flatbuffers "github.com/google/flatbuffers/go"
-)
 
-var (
-       builder = flatbuffers.NewBuilder(256)
+       "github.com/apache/apisix-go-plugin-runner/internal/util"
 )
 
-func ReportError(err error) []byte {
-       builder.Reset()
-
+func ReportError(err error) *flatbuffers.Builder {
+       builder := util.GetBuilder()
        A6Err.RespStart(builder)
 
        var code A6Err.Code
@@ -40,5 +37,5 @@ func ReportError(err error) []byte {
        A6Err.RespAddCode(builder, code)
        resp := A6Err.RespEnd(builder)
        builder.Finish(resp)
-       return builder.FinishedBytes()
+       return builder
 }
diff --git a/internal/server/error_test.go b/internal/server/error_test.go
index 4d8bb21..fb5b426 100644
--- a/internal/server/error_test.go
+++ b/internal/server/error_test.go
@@ -29,12 +29,14 @@ func TestReportErrorCacheToken(t *testing.T) {
 
        _, err := plugin.GetRuleConf(uint32(999999))
        b := ReportError(err)
-       resp := A6Err.GetRootAsResp(b, 0)
+       out := b.FinishedBytes()
+       resp := A6Err.GetRootAsResp(out, 0)
        assert.Equal(t, A6Err.CodeCONF_TOKEN_NOT_FOUND, resp.Code())
 }
 
 func TestReportErrorUnknownErr(t *testing.T) {
        b := ReportError(io.EOF)
-       resp := A6Err.GetRootAsResp(b, 0)
+       out := b.FinishedBytes()
+       resp := A6Err.GetRootAsResp(out, 0)
        assert.Equal(t, A6Err.CodeSERVICE_UNAVAILABLE, resp.Code())
 }
diff --git a/internal/server/server.go b/internal/server/server.go
index c2deef9..52c5624 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -28,6 +28,8 @@ import (
 
        "github.com/apache/apisix-go-plugin-runner/internal/log"
        "github.com/apache/apisix-go-plugin-runner/internal/plugin"
+       "github.com/apache/apisix-go-plugin-runner/internal/util"
+       flatbuffers "github.com/google/flatbuffers/go"
 )
 
 const (
@@ -92,14 +94,15 @@ func handleConn(c net.Conn) {
                        break
                }
 
-               var out []byte
+               var bd *flatbuffers.Builder
                switch ty {
                case RPCPrepareConf:
-                       out, err = plugin.PrepareConf(buf)
+                       bd, err = plugin.PrepareConf(buf)
                default:
                        err = fmt.Errorf("unknown type %d", ty)
                }
 
+               out := bd.FinishedBytes()
                size := len(out)
                if size > MaxDataSize {
                        err = fmt.Errorf("the max length of data is %d but got %d", MaxDataSize, size)
@@ -108,7 +111,9 @@ func handleConn(c net.Conn) {
 
                if err != nil {
                        ty = RPCError
-                       out = ReportError(err)
+                       util.PutBuilder(bd)
+                       bd = ReportError(err)
+                       out = bd.FinishedBytes()
                }
 
                binary.BigEndian.PutUint32(header, uint32(size))
@@ -117,14 +122,17 @@ func handleConn(c net.Conn) {
                n, err = c.Write(header)
                if err != nil {
                        writeErr(n, err)
+                       util.PutBuilder(bd)
                        break
                }
 
                n, err = c.Write(out)
                if err != nil {
                        writeErr(n, err)
+                       util.PutBuilder(bd)
                        break
                }
+               util.PutBuilder(bd)
        }
 }
 
diff --git a/internal/server/error.go b/internal/util/pool.go
similarity index 62%
copy from internal/server/error.go
copy to internal/util/pool.go
index 8ba517c..bf6a746 100644
--- a/internal/server/error.go
+++ b/internal/util/pool.go
@@ -12,33 +12,25 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-package server
+package util
 
 import (
-       "github.com/ReneKroon/ttlcache/v2"
-       A6Err "github.com/api7/ext-plugin-proto/go/A6/Err"
-       flatbuffers "github.com/google/flatbuffers/go"
-)
+       "sync"
 
-var (
-       builder = flatbuffers.NewBuilder(256)
+       flatbuffers "github.com/google/flatbuffers/go"
 )
 
-func ReportError(err error) []byte {
-       builder.Reset()
-
-       A6Err.RespStart(builder)
+var builderPool = sync.Pool{
+       New: func() interface{} {
+               return flatbuffers.NewBuilder(256)
+       },
+}
 
-       var code A6Err.Code
-       switch err {
-       case ttlcache.ErrNotFound:
-               code = A6Err.CodeCONF_TOKEN_NOT_FOUND
-       default:
-               code = A6Err.CodeSERVICE_UNAVAILABLE
-       }
+func GetBuilder() *flatbuffers.Builder {
+       return builderPool.Get().(*flatbuffers.Builder)
+}
 
-       A6Err.RespAddCode(builder, code)
-       resp := A6Err.RespEnd(builder)
-       builder.Finish(resp)
-       return builder.FinishedBytes()
+func PutBuilder(b *flatbuffers.Builder) {
+       b.Reset()
+       builderPool.Put(b)
 }

Reply via email to