This is an automated email from the ASF dual-hosted git repository.

gosonzhang pushed a commit to branch INLONG-25
in repository https://gitbox.apache.org/repos/asf/incubator-inlong.git

commit 108fd1fbefedc73fd38aaf699a0cd1e33f3bb03a
Author: Zijie Lu <[email protected]>
AuthorDate: Fri Apr 30 14:37:41 2021 +0800

    [INLONG-600] Multiplexed connection pool for Go SDK
    
    Signed-off-by: Zijie Lu <[email protected]>
---
 .../tubemq-client-go/codec/codec.go                | 107 ++++++
 tubemq-client-twins/tubemq-client-go/go.mod        |   5 +
 .../tubemq-client-go/pool/multiplexed.go           | 386 +++++++++++++++++++++
 .../tubemq-client-go/pool/multlplexed_test.go      | 119 +++++++
 4 files changed, 617 insertions(+)

diff --git a/tubemq-client-twins/tubemq-client-go/codec/codec.go 
b/tubemq-client-twins/tubemq-client-go/codec/codec.go
new file mode 100644
index 0000000..ee27e96
--- /dev/null
+++ b/tubemq-client-twins/tubemq-client-go/codec/codec.go
@@ -0,0 +1,107 @@
+package codec
+
+import (
+       "bufio"
+       "encoding/binary"
+       "errors"
+       "io"
+)
+
const (
	// RPCProtocolBeginToken is the magic number that opens every RPC frame.
	RPCProtocolBeginToken uint32 = 0xFF7FF4FE
	// RPCMaxBufferSize is exported for callers; Framer.Decode does not use it.
	RPCMaxBufferSize uint32 = 8192
	// frameHeadLen is the fixed frame header: begin token + serial number.
	frameHeadLen uint32 = 8
	// maxBufferSize is the size of the bufio.Reader wrapped around the stream.
	maxBufferSize int = 128 * 1024
	// defaultMsgSize is the initial size of the Framer's reusable scratch buffer.
	defaultMsgSize int    = 4096
	dataLen        uint32 = 4 // length prefix of each payload piece
	listSizeLen    uint32 = 4 // number of payload pieces in a frame
	serialNoLen    uint32 = 4
	beginTokenLen  uint32 = 4
)

// Framer decodes RPC frames from a byte stream. Decode mutates the Framer's
// scratch buffer without synchronization, so a Framer must be used by one
// goroutine at a time.
type Framer struct {
	reader io.Reader
	msg    []byte
}

// New returns a Framer reading from reader through a bufio.Reader of
// maxBufferSize bytes.
func New(reader io.Reader) *Framer {
	bufferReader := bufio.NewReaderSize(reader, maxBufferSize)
	return &Framer{
		msg:    make([]byte, defaultMsgSize),
		reader: bufferReader,
	}
}

// Decode reads exactly one frame. Wire layout, all integers big-endian:
//
//	begin token (4) | serial number (4) | list size N (4) | N x { length L (4) | L bytes }
//
// The returned response holds the serial number and the concatenation of the
// N payload pieces. The payload is copied out of the scratch buffer, so it
// remains valid after later Decode calls.
//
// I/O errors are returned unchanged: io.EOF when the stream ends cleanly
// at a read boundary, io.ErrUnexpectedEOF on a truncated read. Callers can
// therefore tell a closed connection apart from a corrupt frame.
func (f *Framer) Decode() (*FrameResponse, error) {
	header := f.msg[:frameHeadLen]
	if _, err := io.ReadFull(f.reader, header); err != nil {
		return nil, err
	}
	if binary.BigEndian.Uint32(header[:beginTokenLen]) != RPCProtocolBeginToken {
		return nil, errors.New("framer: read framer rpc protocol begin token not match")
	}
	// Extract the serial number now; f.msg may be reallocated by grow below.
	serialNo := binary.BigEndian.Uint32(header[beginTokenLen : beginTokenLen+serialNoLen])

	var word [4]byte
	if _, err := io.ReadFull(f.reader, word[:listSizeLen]); err != nil {
		return nil, err
	}
	listSize := binary.BigEndian.Uint32(word[:listSizeLen])

	// Payload pieces are appended to f.msg right after the header.
	totalLen := int(frameHeadLen)
	for i := uint32(0); i < listSize; i++ {
		if _, err := io.ReadFull(f.reader, word[:dataLen]); err != nil {
			return nil, err
		}
		s := int(binary.BigEndian.Uint32(word[:dataLen]))
		// The length comes off the wire: reject values that do not fit in an
		// int (possible on 32-bit platforms) or that would overflow totalLen.
		if s < 0 || totalLen+s < totalLen {
			return nil, errors.New("framer: read invalid size")
		}
		f.grow(totalLen, s)
		if _, err := io.ReadFull(f.reader, f.msg[totalLen:totalLen+s]); err != nil {
			return nil, err
		}
		totalLen += s
	}

	payload := make([]byte, totalLen-int(frameHeadLen))
	copy(payload, f.msg[frameHeadLen:totalLen])
	return &FrameResponse{
		serialNo:    serialNo,
		responseBuf: payload,
	}, nil
}

// grow ensures f.msg can hold extra more bytes after the first used bytes,
// preserving those used bytes. Capacity at least doubles on each growth so a
// frame with many pieces costs amortized linear copying.
func (f *Framer) grow(used, extra int) {
	need := used + extra
	if need <= len(f.msg) {
		return
	}
	newLen := 2 * len(f.msg)
	if newLen < need {
		newLen = need
	}
	buf := make([]byte, newLen)
	copy(buf, f.msg[:used])
	f.msg = buf
}

// FrameRequest pairs a request ID with its encoded bytes.
type FrameRequest struct {
	requestID uint32
	req       []byte
}

// FrameResponse is one decoded frame: its serial number and payload.
type FrameResponse struct {
	serialNo    uint32
	responseBuf []byte
}

// GetSerialNo returns the serial number carried in the frame header.
func (f *FrameResponse) GetSerialNo() uint32 {
	return f.serialNo
}

// GetResponseBuf returns the concatenated payload pieces of the frame.
func (f *FrameResponse) GetResponseBuf() []byte {
	return f.responseBuf
}

// Codec is an empty placeholder type.
type Codec struct{}
diff --git a/tubemq-client-twins/tubemq-client-go/go.mod 
b/tubemq-client-twins/tubemq-client-go/go.mod
new file mode 100644
index 0000000..7c1a676
--- /dev/null
+++ b/tubemq-client-twins/tubemq-client-go/go.mod
@@ -0,0 +1,5 @@
+module github.com/apache/incubator-inlong/tubemq-client-twins/tubemq-client-go
+
+go 1.14
+
+require github.com/stretchr/testify v1.7.0
diff --git a/tubemq-client-twins/tubemq-client-go/pool/multiplexed.go 
b/tubemq-client-twins/tubemq-client-go/pool/multiplexed.go
new file mode 100644
index 0000000..5d38a14
--- /dev/null
+++ b/tubemq-client-twins/tubemq-client-go/pool/multiplexed.go
@@ -0,0 +1,386 @@
+package pool
+
+import (
+       "context"
+       "crypto/tls"
+       "crypto/x509"
+       "errors"
+       "io/ioutil"
+       "net"
+       "sync"
+       "time"
+
+       
"github.com/apache/incubator-inlong/tubemq-client-twins/tubemq-client-go/codec"
+)
+
+var DefaultMultiplexedPool = New()
+
+var (
+       // ErrConnClosed indicates that the connection is closed
+       ErrConnClosed = errors.New("connection is closed")
+       // ErrChanClose indicates the recv chan is closed
+       ErrChanClose = errors.New("unexpected recv chan close")
+       // ErrWriteBufferDone indicates write buffer done
+       ErrWriteBufferDone = errors.New("write buffer done")
+       // ErrAssertConnectionFail indicates connection assertion error
+       ErrAssertConnectionFail = errors.New("assert connection slice fail")
+)
+
+const (
+       Initial int = iota
+       Connected
+       Closing
+       Closed
+)
+
+var queueSize = 10000
+
+func New() *Multiplexed {
+       m := &Multiplexed{
+               connections: new(sync.Map),
+       }
+       return m
+}
+
+type writerBuffer struct {
+       buffer chan []byte
+       done   <-chan struct{}
+}
+
+func (w *writerBuffer) get() ([]byte, error) {
+       select {
+       case req := <-w.buffer:
+               return req, nil
+       case <-w.done:
+               return nil, ErrWriteBufferDone
+       }
+}
+
+type recvReader struct {
+       ctx  context.Context
+       recv chan *codec.FrameResponse
+}
+
+type MultiplexedConnection struct {
+       serialNo uint32
+       conn     *Connection
+       reader   *recvReader
+       done     chan struct{}
+}
+
+func (mc *MultiplexedConnection) Write(b []byte) error {
+       if err := mc.conn.send(b); err != nil {
+               mc.conn.remove(mc.serialNo)
+               return err
+       }
+       return nil
+}
+
+func (mc *MultiplexedConnection) Read() (*codec.FrameResponse, error) {
+       select {
+       case <-mc.reader.ctx.Done():
+               mc.conn.remove(mc.serialNo)
+               return nil, mc.reader.ctx.Err()
+       case v, ok := <-mc.reader.recv:
+               if ok {
+                       return v, nil
+               }
+               if mc.conn.err != nil {
+                       return nil, mc.conn.err
+               }
+               return nil, ErrChanClose
+       case <-mc.done:
+               return nil, mc.conn.err
+       }
+}
+
+func (mc *MultiplexedConnection) recv(rsp *codec.FrameResponse) {
+       mc.reader.recv <- rsp
+       mc.conn.remove(rsp.GetSerialNo())
+}
+
+type DialOptions struct {
+       Network       string
+       Address       string
+       Timeout       time.Duration
+       CACertFile    string
+       TLSCertFile   string
+       TLSKeyFile    string
+       TLSServerName string
+}
+
+type Connection struct {
+       err         error
+       address     string
+       mu          sync.RWMutex
+       connections map[uint32]*MultiplexedConnection
+       framer      *codec.Framer
+       conn        net.Conn
+       done        chan struct{}
+       mDone       chan struct{}
+       buffer      *writerBuffer
+       dialOpts    *DialOptions
+       state       int
+       multiplexed *Multiplexed
+}
+
+func (c *Connection) new(ctx context.Context, serialNo uint32) 
(*MultiplexedConnection, error) {
+       c.mu.Lock()
+       defer c.mu.Unlock()
+       if c.err != nil {
+               return nil, c.err
+       }
+
+       vc := &MultiplexedConnection{
+               serialNo: serialNo,
+               conn:     c,
+               done:     c.mDone,
+               reader: &recvReader{
+                       ctx:  ctx,
+                       recv: make(chan *codec.FrameResponse, 1),
+               },
+       }
+
+       if prevConn, ok := c.connections[serialNo]; ok {
+               close(prevConn.reader.recv)
+       }
+       c.connections[serialNo] = vc
+       return vc, nil
+}
+
+func (c *Connection) close(lastErr error, done chan struct{}) {
+       if lastErr == nil {
+               return
+       }
+       c.mu.Lock()
+       defer c.mu.Unlock()
+
+       if c.state == Closed {
+               return
+       }
+
+       select {
+       case <-done:
+               return
+       default:
+       }
+
+       c.state = Closing
+       c.err = lastErr
+       c.connections = make(map[uint32]*MultiplexedConnection)
+       close(c.done)
+       if c.conn != nil {
+               c.conn.Close()
+       }
+       err := c.reconnect()
+       if err != nil {
+               c.state = Closed
+               close(c.mDone)
+               c.multiplexed.connections.Delete(c)
+       }
+}
+
+func (c *Connection) reconnect() error {
+       conn, err := dialWithTimeout(c.dialOpts)
+       if err != nil {
+               return err
+       }
+       c.done = make(chan struct{})
+       c.conn = conn
+       c.framer = codec.New(conn)
+       c.buffer.done = c.done
+       c.state = Connected
+       c.err = nil
+       go c.reader()
+       go c.writer()
+       return nil
+}
+
+func (c *Connection) writer() {
+       var lastErr error
+       for {
+               select {
+               case <-c.done:
+                       return
+               default:
+               }
+               req, err := c.buffer.get()
+               if err != nil {
+                       lastErr = err
+                       break
+               }
+               if err := c.write(req); err != nil {
+                       lastErr = err
+                       break
+               }
+       }
+       c.close(lastErr, c.done)
+}
+
+func (c *Connection) send(b []byte) error {
+       if c.state == Closed {
+               return ErrConnClosed
+       }
+
+       select {
+       case c.buffer.buffer <- b:
+               return nil
+       case <-c.mDone:
+               return c.err
+       }
+}
+
+func (c *Connection) remove(id uint32) {
+       c.mu.Lock()
+       delete(c.connections, id)
+       c.mu.Unlock()
+}
+
+func (c *Connection) write(b []byte) error {
+       sent := 0
+       for sent < len(b) {
+               n, err := c.conn.Write(b[sent:])
+               if err != nil {
+                       return err
+               }
+               sent += n
+       }
+       return nil
+}
+
+func (c *Connection) reader() {
+       var lastErr error
+       for {
+               select {
+               case <-c.done:
+                       return
+               default:
+               }
+               rsp, err := c.framer.Decode()
+               if err != nil {
+                       lastErr = err
+                       break
+               }
+               serialNo := rsp.GetSerialNo()
+               c.mu.RLock()
+               mc, ok := c.connections[serialNo]
+               c.mu.RUnlock()
+               if !ok {
+                       continue
+               }
+               mc.reader.recv <- rsp
+               mc.conn.remove(rsp.GetSerialNo())
+       }
+       c.close(lastErr, c.done)
+}
+
+type Multiplexed struct {
+       connections *sync.Map
+}
+
+func (p *Multiplexed) Get(ctx context.Context, address string, serialNo 
uint32) (*MultiplexedConnection, error) {
+       select {
+       case <-ctx.Done():
+               return nil, ctx.Err()
+       default:
+       }
+
+       if v, ok := p.connections.Load(address); ok {
+               if c, ok := v.(*Connection); ok {
+                       return c.new(ctx, serialNo)
+               }
+               return nil, ErrAssertConnectionFail
+       }
+
+       c := &Connection{
+               address:     address,
+               connections: make(map[uint32]*MultiplexedConnection),
+               done:        make(chan struct{}),
+               mDone:       make(chan struct{}),
+               state:       Initial,
+       }
+       c.buffer = &writerBuffer{
+               buffer: make(chan []byte, queueSize),
+               done:   c.done,
+       }
+       p.connections.Store(address, c)
+
+       conn, dialOpts, err := dial(ctx, address)
+       c.dialOpts = dialOpts
+       if err != nil {
+               return nil, err
+       }
+       c.framer = codec.New(conn)
+       c.conn = conn
+       c.state = Connected
+       go c.reader()
+       go c.writer()
+       return c.new(ctx, serialNo)
+}
+
+func dial(ctx context.Context, address string) (net.Conn, *DialOptions, error) 
{
+       var timeout time.Duration
+       t, ok := ctx.Deadline()
+       if ok {
+               timeout = t.Sub(time.Now())
+       }
+       dialOpts := &DialOptions{
+               Network: "tcp",
+               Address: address,
+               Timeout: timeout,
+       }
+       select {
+       case <-ctx.Done():
+               return nil, dialOpts, ctx.Err()
+       default:
+       }
+       conn, err := dialWithTimeout(dialOpts)
+       return conn, dialOpts, err
+}
+
+func dialWithTimeout(opts *DialOptions) (net.Conn, error) {
+       if len(opts.CACertFile) == 0 {
+               return net.DialTimeout(opts.Network, opts.Address, opts.Timeout)
+       }
+
+       tlsConf := &tls.Config{}
+       if opts.CACertFile == "none" { // 不需要检验服务证书
+               tlsConf.InsecureSkipVerify = true
+       } else {
+               if len(opts.TLSServerName) == 0 {
+                       opts.TLSServerName = opts.Address
+               }
+               tlsConf.ServerName = opts.TLSServerName
+               certPool, err := getCertPool(opts.CACertFile)
+               if err != nil {
+                       return nil, err
+               }
+
+               tlsConf.RootCAs = certPool
+
+               if len(opts.TLSCertFile) != 0 {
+                       cert, err := tls.LoadX509KeyPair(opts.TLSCertFile, 
opts.TLSKeyFile)
+                       if err != nil {
+                               return nil, err
+                       }
+                       tlsConf.Certificates = []tls.Certificate{cert}
+               }
+       }
+       return tls.DialWithDialer(&net.Dialer{Timeout: opts.Timeout}, 
opts.Network, opts.Address, tlsConf)
+}
+
// getCertPool returns the trusted CA pool for caCertFile.
//
// For "root" it returns (nil, nil); crypto/tls then uses the system roots.
// Otherwise caCertFile is read as PEM. A file with no usable certificate is
// an error. Returning a nil pool there would silently fall back to the
// system roots, trusting CAs the caller never configured.
func getCertPool(caCertFile string) (*x509.CertPool, error) {
	if caCertFile == "root" {
		return nil, nil
	}
	ca, err := ioutil.ReadFile(caCertFile)
	if err != nil {
		return nil, err
	}
	certPool := x509.NewCertPool()
	if !certPool.AppendCertsFromPEM(ca) {
		return nil, errors.New("no valid PEM certificate found in CA file " + caCertFile)
	}
	return certPool, nil
}
diff --git a/tubemq-client-twins/tubemq-client-go/pool/multlplexed_test.go 
b/tubemq-client-twins/tubemq-client-go/pool/multlplexed_test.go
new file mode 100644
index 0000000..6377032
--- /dev/null
+++ b/tubemq-client-twins/tubemq-client-go/pool/multlplexed_test.go
@@ -0,0 +1,119 @@
+package pool
+
+import (
+       "bytes"
+       "context"
+       "encoding/binary"
+       "io"
+       "log"
+       "net"
+       "strconv"
+       "sync"
+       "sync/atomic"
+       "testing"
+       "time"
+
+       "github.com/stretchr/testify/assert"
+
+       
"github.com/apache/incubator-inlong/tubemq-client-twins/tubemq-client-go/codec"
+)
+
+var (
+       address         = "127.0.0.1:0"
+       ch              = make(chan struct{})
+       serialNo uint32 = 1
+)
+
+func init() {
+       go simpleForwardTCPServer(ch)
+       <-ch
+}
+
+func simpleForwardTCPServer(ch chan struct{}) {
+       l, err := net.Listen("tcp", address)
+       if err != nil {
+               log.Fatal(err)
+       }
+       defer l.Close()
+       address = l.Addr().String()
+
+       ch <- struct{}{}
+
+       for {
+               conn, err := l.Accept()
+               if err != nil {
+                       log.Fatal(err)
+               }
+
+               go func() {
+                       io.Copy(conn, conn)
+               }()
+       }
+}
+
+func Encode(serialNo uint32, body []byte) ([]byte, error) {
+       l := len(body)
+       buf := bytes.NewBuffer(make([]byte, 0, 16+l))
+       if err := binary.Write(buf, binary.BigEndian, 
codec.RPCProtocolBeginToken); err != nil {
+               return nil, err
+       }
+       if err := binary.Write(buf, binary.BigEndian, serialNo); err != nil {
+               return nil, err
+       }
+       if err := binary.Write(buf, binary.BigEndian, uint32(1)); err != nil {
+               return nil, err
+       }
+       if err := binary.Write(buf, binary.BigEndian, uint32(len(body))); err 
!= nil {
+               return nil, err
+       }
+       buf.Write(body)
+       return buf.Bytes(), nil
+}
+
+func TestBasicMultiplexed(t *testing.T) {
+       serialNo := atomic.AddUint32(&serialNo, 1)
+       ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+       defer cancel()
+
+       m := New()
+       mc, err := m.Get(ctx, address, serialNo)
+       body := []byte("hello world")
+
+       buf, err := Encode(serialNo, body)
+       assert.Nil(t, err)
+       assert.Nil(t, mc.Write(buf))
+
+       rsp, err := mc.Read()
+       assert.Nil(t, err)
+       assert.Equal(t, serialNo, rsp.GetSerialNo())
+       assert.Equal(t, body, rsp.GetResponseBuf())
+       assert.Equal(t, mc.Write(nil), nil)
+}
+
+func TestConcurrentMultiplexed(t *testing.T) {
+       count := 1000
+       m := New()
+       wg := sync.WaitGroup{}
+       wg.Add(count)
+       for i := 0; i < count; i++ {
+               go func(i int) {
+                       defer wg.Done()
+                       ctx, cancel := 
context.WithTimeout(context.Background(), time.Second)
+                       defer cancel()
+                       serialNo := atomic.AddUint32(&serialNo, 1)
+                       mc, err := m.Get(ctx, address, serialNo)
+                       assert.Nil(t, err)
+
+                       body := []byte("hello world" + strconv.Itoa(i))
+                       buf, err := Encode(serialNo, body)
+                       assert.Nil(t, err)
+                       assert.Nil(t, mc.Write(buf))
+
+                       rsp, err := mc.Read()
+                       assert.Nil(t, err)
+                       assert.Equal(t, serialNo, rsp.GetSerialNo())
+                       assert.Equal(t, body, rsp.GetResponseBuf())
+               }(i)
+       }
+       wg.Wait()
+}

Reply via email to