This is an automated email from the ASF dual-hosted git repository. gosonzhang pushed a commit to branch INLONG-25 in repository https://gitbox.apache.org/repos/asf/incubator-inlong.git
commit 108fd1fbefedc73fd38aaf699a0cd1e33f3bb03a Author: Zijie Lu <[email protected]> AuthorDate: Fri Apr 30 14:37:41 2021 +0800 [INLONG-600]Multiplexed connection pool for Go sdk Signed-off-by: Zijie Lu <[email protected]> --- .../tubemq-client-go/codec/codec.go | 107 ++++++ tubemq-client-twins/tubemq-client-go/go.mod | 5 + .../tubemq-client-go/pool/multiplexed.go | 386 +++++++++++++++++++++ .../tubemq-client-go/pool/multlplexed_test.go | 119 +++++++ 4 files changed, 617 insertions(+) diff --git a/tubemq-client-twins/tubemq-client-go/codec/codec.go b/tubemq-client-twins/tubemq-client-go/codec/codec.go new file mode 100644 index 0000000..ee27e96 --- /dev/null +++ b/tubemq-client-twins/tubemq-client-go/codec/codec.go @@ -0,0 +1,107 @@ +package codec + +import ( + "bufio" + "encoding/binary" + "errors" + "io" +) + +const ( + RPCProtocolBeginToken uint32 = 0xFF7FF4FE + RPCMaxBufferSize uint32 = 8192 + frameHeadLen uint32 = 8 + maxBufferSize int = 128 * 1024 + defaultMsgSize int = 4096 + dataLen uint32 = 4 + listSizeLen uint32 = 4 + serialNoLen uint32 = 4 + beginTokenLen uint32 = 4 +) + +type Framer struct { + reader io.Reader + msg []byte +} + +func New(reader io.Reader) *Framer { + bufferReader := bufio.NewReaderSize(reader, maxBufferSize) + return &Framer{ + msg: make([]byte, defaultMsgSize), + reader: bufferReader, + } +} + +func (f *Framer) Decode() (*FrameResponse, error) { + num, err := io.ReadFull(f.reader, f.msg[:frameHeadLen]) + if err != nil { + return nil, err + } + if num != int(frameHeadLen) { + return nil, errors.New("framer: read frame header num invalid") + } + token := binary.BigEndian.Uint32(f.msg[:beginTokenLen]) + if token != RPCProtocolBeginToken { + return nil, errors.New("framer: read framer rpc protocol begin token not match") + } + num, err = io.ReadFull(f.reader, f.msg[frameHeadLen:frameHeadLen+listSizeLen]) + if num != int(listSizeLen) { + return nil, errors.New("framer: read invalid list size num") + } + listSize := binary.BigEndian.Uint32(f.msg[frameHeadLen : frameHeadLen+listSizeLen]) + totalLen := int(frameHeadLen) + size := make([]byte, 4) + for i := 0; i < int(listSize); i++ { + n, err := io.ReadFull(f.reader, size) + if err != nil { + return nil, err + } + if n != int(dataLen) { + return nil, errors.New("framer: read invalid size") + } + + s := int(binary.BigEndian.Uint32(size)) + if totalLen+s > len(f.msg) { + data := f.msg[:totalLen] + f.msg = make([]byte, totalLen+s) + copy(f.msg, data[:]) + } + + num, err = io.ReadFull(f.reader, f.msg[totalLen:totalLen+s]) + if err != nil { + return nil, err + } + if num != s { + return nil, errors.New("framer: read invalid data") + } + totalLen += s + } + + data := make([]byte, totalLen - int(frameHeadLen)) + copy(data, f.msg[frameHeadLen:totalLen]) + + return &FrameResponse{ + serialNo: binary.BigEndian.Uint32(f.msg[beginTokenLen : beginTokenLen+serialNoLen]), + responseBuf: data, + }, nil +} + +type FrameRequest struct { + requestID uint32 + req []byte +} + +type FrameResponse struct { + serialNo uint32 + responseBuf []byte +} + +func (f *FrameResponse) GetSerialNo() uint32 { + return f.serialNo +} + +func (f *FrameResponse) GetResponseBuf() []byte { + return f.responseBuf +} + +type Codec struct{} diff --git a/tubemq-client-twins/tubemq-client-go/go.mod b/tubemq-client-twins/tubemq-client-go/go.mod new file mode 100644 index 0000000..7c1a676 --- /dev/null +++ b/tubemq-client-twins/tubemq-client-go/go.mod @@ -0,0 +1,5 @@ +module github.com/apache/incubator-inlong/tubemq-client-twins/tubemq-client-go + +go 1.14 + +require github.com/stretchr/testify v1.7.0 diff --git a/tubemq-client-twins/tubemq-client-go/pool/multiplexed.go b/tubemq-client-twins/tubemq-client-go/pool/multiplexed.go new file mode 100644 index 0000000..5d38a14 --- /dev/null +++ b/tubemq-client-twins/tubemq-client-go/pool/multiplexed.go @@ -0,0 +1,386 @@ +package pool + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "io/ioutil" + "net" + "sync" + "time" + + "github.com/apache/incubator-inlong/tubemq-client-twins/tubemq-client-go/codec" +) + +var DefaultMultiplexedPool = New() + +var ( + // ErrConnClosed indicates that the connection is closed + ErrConnClosed = errors.New("connection is closed") + // ErrChanClose indicates the recv chan is closed + ErrChanClose = errors.New("unexpected recv chan close") + // ErrWriteBufferDone indicates write buffer done + ErrWriteBufferDone = errors.New("write buffer done") + // ErrAssertConnectionFail indicates connection assertion error + ErrAssertConnectionFail = errors.New("assert connection slice fail") +) + +const ( + Initial int = iota + Connected + Closing + Closed +) + +var queueSize = 10000 + +func New() *Multiplexed { + m := &Multiplexed{ + connections: new(sync.Map), + } + return m +} + +type writerBuffer struct { + buffer chan []byte + done <-chan struct{} +} + +func (w *writerBuffer) get() ([]byte, error) { + select { + case req := <-w.buffer: + return req, nil + case <-w.done: + return nil, ErrWriteBufferDone + } +} + +type recvReader struct { + ctx context.Context + recv chan *codec.FrameResponse +} + +type MultiplexedConnection struct { + serialNo uint32 + conn *Connection + reader *recvReader + done chan struct{} +} + +func (mc *MultiplexedConnection) Write(b []byte) error { + if err := mc.conn.send(b); err != nil { + mc.conn.remove(mc.serialNo) + return err + } + return nil +} + +func (mc *MultiplexedConnection) Read() (*codec.FrameResponse, error) { + select { + case <-mc.reader.ctx.Done(): + mc.conn.remove(mc.serialNo) + return nil, mc.reader.ctx.Err() + case v, ok := <-mc.reader.recv: + if ok { + return v, nil + } + if mc.conn.err != nil { + return nil, mc.conn.err + } + return nil, ErrChanClose + case <-mc.done: + return nil, mc.conn.err + } +} + +func (mc *MultiplexedConnection) recv(rsp *codec.FrameResponse) { + mc.reader.recv <- rsp + mc.conn.remove(rsp.GetSerialNo()) +} + +type DialOptions struct { + Network string + Address string + Timeout time.Duration + CACertFile string + TLSCertFile string + TLSKeyFile string + TLSServerName string +} + +type Connection struct { + err error + address string + mu sync.RWMutex + connections map[uint32]*MultiplexedConnection + framer *codec.Framer + conn net.Conn + done chan struct{} + mDone chan struct{} + buffer *writerBuffer + dialOpts *DialOptions + state int + multiplexed *Multiplexed +} + +func (c *Connection) new(ctx context.Context, serialNo uint32) (*MultiplexedConnection, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.err != nil { + return nil, c.err + } + + vc := &MultiplexedConnection{ + serialNo: serialNo, + conn: c, + done: c.mDone, + reader: &recvReader{ + ctx: ctx, + recv: make(chan *codec.FrameResponse, 1), + }, + } + + if prevConn, ok := c.connections[serialNo]; ok { + close(prevConn.reader.recv) + } + c.connections[serialNo] = vc + return vc, nil +} + +func (c *Connection) close(lastErr error, done chan struct{}) { + if lastErr == nil { + return + } + c.mu.Lock() + defer c.mu.Unlock() + + if c.state == Closed { + return + } + + select { + case <-done: + return + default: + } + + c.state = Closing + c.err = lastErr + c.connections = make(map[uint32]*MultiplexedConnection) + close(c.done) + if c.conn != nil { + c.conn.Close() + } + err := c.reconnect() + if err != nil { + c.state = Closed + close(c.mDone) + c.multiplexed.connections.Delete(c) + } +} + +func (c *Connection) reconnect() error { + conn, err := dialWithTimeout(c.dialOpts) + if err != nil { + return err + } + c.done = make(chan struct{}) + c.conn = conn + c.framer = codec.New(conn) + c.buffer.done = c.done + c.state = Connected + c.err = nil + go c.reader() + go c.writer() + return nil +} + +func (c *Connection) writer() { + var lastErr error + for { + select { + case <-c.done: + return + default: + } + req, err := c.buffer.get() + if err != nil { + lastErr = err + break + } + if err := c.write(req); err != nil { + lastErr = err + break + } + } + c.close(lastErr, c.done) +} + +func (c *Connection) send(b []byte) error { + if c.state == Closed { + return ErrConnClosed + } + + select { + case c.buffer.buffer <- b: + return nil + case <-c.mDone: + return c.err + } +} + +func (c *Connection) remove(id uint32) { + c.mu.Lock() + delete(c.connections, id) + c.mu.Unlock() +} + +func (c *Connection) write(b []byte) error { + sent := 0 + for sent < len(b) { + n, err := c.conn.Write(b[sent:]) + if err != nil { + return err + } + sent += n + } + return nil +} + +func (c *Connection) reader() { + var lastErr error + for { + select { + case <-c.done: + return + default: + } + rsp, err := c.framer.Decode() + if err != nil { + lastErr = err + break + } + serialNo := rsp.GetSerialNo() + c.mu.RLock() + mc, ok := c.connections[serialNo] + c.mu.RUnlock() + if !ok { + continue + } + mc.reader.recv <- rsp + mc.conn.remove(rsp.GetSerialNo()) + } + c.close(lastErr, c.done) +} + +type Multiplexed struct { + connections *sync.Map +} + +func (p *Multiplexed) Get(ctx context.Context, address string, serialNo uint32) (*MultiplexedConnection, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + if v, ok := p.connections.Load(address); ok { + if c, ok := v.(*Connection); ok { + return c.new(ctx, serialNo) + } + return nil, ErrAssertConnectionFail + } + + c := &Connection{ + address: address, + connections: make(map[uint32]*MultiplexedConnection), + done: make(chan struct{}), + mDone: make(chan struct{}), + state: Initial, + } + c.buffer = &writerBuffer{ + buffer: make(chan []byte, queueSize), + done: c.done, + } + p.connections.Store(address, c) + + conn, dialOpts, err := dial(ctx, address) + c.dialOpts = dialOpts + if err != nil { + return nil, err + } + c.framer = codec.New(conn) + c.conn = conn + c.state = Connected + go c.reader() + go c.writer() + return c.new(ctx, serialNo) +} + +func dial(ctx context.Context, address string) (net.Conn, *DialOptions, error) { + var timeout time.Duration + t, ok := ctx.Deadline() + if ok { + timeout = t.Sub(time.Now()) + } + dialOpts := &DialOptions{ + Network: "tcp", + Address: address, + Timeout: timeout, + } + select { + case <-ctx.Done(): + return nil, dialOpts, ctx.Err() + default: + } + conn, err := dialWithTimeout(dialOpts) + return conn, dialOpts, err +} + +func dialWithTimeout(opts *DialOptions) (net.Conn, error) { + if len(opts.CACertFile) == 0 { + return net.DialTimeout(opts.Network, opts.Address, opts.Timeout) + } + + tlsConf := &tls.Config{} + if opts.CACertFile == "none" { // 不需要检验服务证书 + tlsConf.InsecureSkipVerify = true + } else { + if len(opts.TLSServerName) == 0 { + opts.TLSServerName = opts.Address + } + tlsConf.ServerName = opts.TLSServerName + certPool, err := getCertPool(opts.CACertFile) + if err != nil { + return nil, err + } + + tlsConf.RootCAs = certPool + + if len(opts.TLSCertFile) != 0 { + cert, err := tls.LoadX509KeyPair(opts.TLSCertFile, opts.TLSKeyFile) + if err != nil { + return nil, err + } + tlsConf.Certificates = []tls.Certificate{cert} + } + } + return tls.DialWithDialer(&net.Dialer{Timeout: opts.Timeout}, opts.Network, opts.Address, tlsConf) +} + +func getCertPool(caCertFile string) (*x509.CertPool, error) { + if caCertFile != "root" { + ca, err := ioutil.ReadFile(caCertFile) + if err != nil { + return nil, err + } + certPool := x509.NewCertPool() + ok := certPool.AppendCertsFromPEM(ca) + if !ok { + return nil, err + } + return certPool, nil + } + return nil, nil +} diff --git a/tubemq-client-twins/tubemq-client-go/pool/multlplexed_test.go b/tubemq-client-twins/tubemq-client-go/pool/multlplexed_test.go new file mode 100644 index 0000000..6377032 --- /dev/null +++ b/tubemq-client-twins/tubemq-client-go/pool/multlplexed_test.go @@ -0,0 +1,119 @@ +package pool + +import ( + "bytes" + "context" + "encoding/binary" + "io" + "log" + "net" + "strconv" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/apache/incubator-inlong/tubemq-client-twins/tubemq-client-go/codec" +) + +var ( + address = "127.0.0.1:0" + ch = make(chan struct{}) + serialNo uint32 = 1 +) + +func init() { + go simpleForwardTCPServer(ch) + <-ch +} + +func simpleForwardTCPServer(ch chan struct{}) { + l, err := net.Listen("tcp", address) + if err != nil { + log.Fatal(err) + } + defer l.Close() + address = l.Addr().String() + + ch <- struct{}{} + + for { + conn, err := l.Accept() + if err != nil { + log.Fatal(err) + } + + go func() { + io.Copy(conn, conn) + }() + } +} + +func Encode(serialNo uint32, body []byte) ([]byte, error) { + l := len(body) + buf := bytes.NewBuffer(make([]byte, 0, 16+l)) + if err := binary.Write(buf, binary.BigEndian, codec.RPCProtocolBeginToken); err != nil { + return nil, err + } + if err := binary.Write(buf, binary.BigEndian, serialNo); err != nil { + return nil, err + } + if err := binary.Write(buf, binary.BigEndian, uint32(1)); err != nil { + return nil, err + } + if err := binary.Write(buf, binary.BigEndian, uint32(len(body))); err != nil { + return nil, err + } + buf.Write(body) + return buf.Bytes(), nil +} + +func TestBasicMultiplexed(t *testing.T) { + serialNo := atomic.AddUint32(&serialNo, 1) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + m := New() + mc, err := m.Get(ctx, address, serialNo) + body := []byte("hello world") + + buf, err := Encode(serialNo, body) + assert.Nil(t, err) + assert.Nil(t, mc.Write(buf)) + + rsp, err := mc.Read() + assert.Nil(t, err) + assert.Equal(t, serialNo, rsp.GetSerialNo()) + assert.Equal(t, body, rsp.GetResponseBuf()) + assert.Equal(t, mc.Write(nil), nil) +} + +func TestConcurrentMultiplexed(t *testing.T) { + count := 1000 + m := New() + wg := sync.WaitGroup{} + wg.Add(count) + for i := 0; i < count; i++ { + go func(i int) { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + serialNo := atomic.AddUint32(&serialNo, 1) + mc, err := m.Get(ctx, address, serialNo) + assert.Nil(t, err) + + body := []byte("hello world" + strconv.Itoa(i)) + buf, err := Encode(serialNo, body) + assert.Nil(t, err) + assert.Nil(t, mc.Write(buf)) + + rsp, err := mc.Read() + assert.Nil(t, err) + assert.Equal(t, serialNo, rsp.GetSerialNo()) + assert.Equal(t, body, rsp.GetResponseBuf()) + }(i) + } + wg.Wait() +}
