TszKitLo40 commented on a change in pull request #463:
URL: https://github.com/apache/incubator-inlong/pull/463#discussion_r623661866
##########
File path: tubemq-client-twins/tubemq-client-go/pool/multiplexed.go
##########
@@ -0,0 +1,386 @@
+package pool
+
+import (
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "errors"
+ "io/ioutil"
+ "net"
+ "sync"
+ "time"
+
+
"github.com/apache/incubator-inlong/tubemq-client-twins/tubemq-client-go/codec"
+)
+
+// DefaultMultiplexedPool is a package-level Multiplexed pool created with New.
+var DefaultMultiplexedPool = New()
+
+// Sentinel errors returned by the multiplexed connection pool.
+var (
+ // ErrConnClosed indicates that the connection is closed.
+ ErrConnClosed = errors.New("connection is closed")
+ // ErrChanClose indicates that the recv chan was closed while no
+ // connection error had been recorded.
+ ErrChanClose = errors.New("unexpected recv chan close")
+ // ErrWriteBufferDone indicates that the write buffer's done channel fired
+ // before a request could be dequeued.
+ ErrWriteBufferDone = errors.New("write buffer done")
+ // ErrAssertConnectionFail indicates that a value stored in the pool was
+ // not a *Connection.
+ ErrAssertConnectionFail = errors.New("assert connection slice fail")
+)
+
+// Connection states stored in Connection.state.
+const (
+ // Initial is the state of a Connection created by Get before it is dialed.
+ Initial int = iota
+ // Connected is set once a dial or reconnect has succeeded.
+ Connected
+ // Closing is set by close while the old socket is torn down and a
+ // reconnect is attempted.
+ Closing
+ // Closed is set by close when the reconnect attempt fails.
+ Closed
+)
+
+// queueSize is the capacity of each Connection's outgoing request channel.
+var queueSize = 10000
+
+// New returns an empty Multiplexed pool backed by a fresh sync.Map.
+func New() *Multiplexed {
+	return &Multiplexed{connections: &sync.Map{}}
+}
+
+// writerBuffer is the outgoing request queue drained by Connection.writer.
+type writerBuffer struct {
+ // buffer carries the requests queued by Connection.send.
+ buffer chan []byte
+ // done aborts a blocked get once it becomes readable.
+ done <-chan struct{}
+}
+
+// get blocks until a queued request is available or w.done fires.
+// It returns ErrWriteBufferDone in the latter case.
+func (w *writerBuffer) get() ([]byte, error) {
+ select {
+ case req := <-w.buffer:
+ return req, nil
+ case <-w.done:
+ return nil, ErrWriteBufferDone
+ }
+}
+
+// recvReader is the receiving side of one MultiplexedConnection.
+type recvReader struct {
+ // ctx is the context passed to Connection.new; Read stops waiting when it is done.
+ ctx context.Context
+ // recv delivers the response for this request; Connection.new creates it with capacity 1.
+ recv chan *codec.FrameResponse
+}
+
+// MultiplexedConnection is a virtual connection for one serial number that
+// shares the socket of its parent Connection.
+type MultiplexedConnection struct {
+ // serialNo is the key under which this connection is registered in conn.connections.
+ serialNo uint32
+ // conn is the shared physical connection.
+ conn *Connection
+ // reader holds the caller's context and the response channel.
+ reader *recvReader
+ // done is the parent's mDone channel, closed when a reconnect attempt fails.
+ done chan struct{}
+}
+
+// Write hands b to the parent connection's send queue. If queuing fails, the
+// virtual connection is unregistered from its parent and the error is returned.
+func (mc *MultiplexedConnection) Write(b []byte) error {
+	err := mc.conn.send(b)
+	if err == nil {
+		return nil
+	}
+	mc.conn.remove(mc.serialNo)
+	return err
+}
+
+// Read waits for the response addressed to this virtual connection.
+//
+// It returns the context error (after unregistering the connection) when the
+// caller's context is done, the parent's recorded error or ErrChanClose when
+// the recv channel has been closed, and the parent's recorded error when the
+// parent's mDone channel has been closed.
+// NOTE(review): mc.conn.err is read here without holding mc.conn.mu — confirm
+// that this is intended.
+func (mc *MultiplexedConnection) Read() (*codec.FrameResponse, error) {
+ select {
+ case <-mc.reader.ctx.Done():
+ mc.conn.remove(mc.serialNo)
+ return nil, mc.reader.ctx.Err()
+ case v, ok := <-mc.reader.recv:
+ if ok {
+ return v, nil
+ }
+ // recv is closed by Connection.new when the same serial number is registered again.
+ if mc.conn.err != nil {
+ return nil, mc.conn.err
+ }
+ return nil, ErrChanClose
+ case <-mc.done:
+ return nil, mc.conn.err
+ }
+}
+
+// recv delivers rsp to the pending Read and then unregisters the serial
+// number carried by rsp from the parent connection.
+func (mc *MultiplexedConnection) recv(rsp *codec.FrameResponse) {
+ mc.reader.recv <- rsp
+ mc.conn.remove(rsp.GetSerialNo())
+}
+
+// DialOptions holds the parameters used to dial, and later re-dial, a server.
+type DialOptions struct {
+ // Network and Address are passed to net.DialTimeout; dial sets Network to "tcp".
+ Network string
+ Address string
+ // Timeout is passed to net.DialTimeout; dial derives it from the context deadline.
+ Timeout time.Duration
+ // CACertFile selects TLS: when empty, dialWithTimeout makes a plain
+ // connection; the value "none" is handled as a special case.
+ CACertFile string
+ // TLSCertFile, TLSKeyFile and TLSServerName are presumably the client
+ // certificate, key and expected server name — TODO confirm against dialWithTimeout.
+ TLSCertFile string
+ TLSKeyFile string
+ TLSServerName string
+}
+
+// Connection is one physical connection to a server address, shared by many
+// MultiplexedConnections that are told apart by serial number.
+type Connection struct {
+ // err is the last error recorded by close; reconnect clears it on success.
+ err error
+ // address is the key under which the pool stores this connection.
+ address string
+ // mu is held while connections is read or modified and for the whole of close.
+ mu sync.RWMutex
+ // connections maps serial numbers to the virtual connections awaiting a response.
+ connections map[uint32]*MultiplexedConnection
+ // framer decodes responses read from conn.
+ framer *codec.Framer
+ // conn is the underlying socket.
+ conn net.Conn
+ // done is closed by close to stop the current reader and writer; reconnect replaces it.
+ done chan struct{}
+ // mDone is closed by close when the reconnect attempt fails.
+ mDone chan struct{}
+ // buffer queues outgoing requests for the writer goroutine.
+ buffer *writerBuffer
+ // dialOpts are the options of the initial dial, reused by reconnect.
+ dialOpts *DialOptions
+ // state is one of Initial, Connected, Closing or Closed.
+ state int
+ // multiplexed is the pool that close removes this connection from.
+ multiplexed *Multiplexed
+}
+
+// new registers a virtual connection for serialNo on c and returns it.
+//
+// ctx bounds how long the returned connection's Read will wait. If c has a
+// recorded error, that error is returned instead. If serialNo is already
+// registered, the previous virtual connection's recv channel is closed so its
+// pending Read returns, and the new connection takes its place.
+//
+// The signature is kept on a single line: Go does not allow the result list
+// to start on a new line after the parameter list.
+func (c *Connection) new(ctx context.Context, serialNo uint32) (*MultiplexedConnection, error) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if c.err != nil {
+		return nil, c.err
+	}
+
+	vc := &MultiplexedConnection{
+		serialNo: serialNo,
+		conn:     c,
+		done:     c.mDone,
+		reader: &recvReader{
+			ctx: ctx,
+			// Capacity 1 lets the reader goroutine hand over the single
+			// response without waiting for Read.
+			recv: make(chan *codec.FrameResponse, 1),
+		},
+	}
+
+	// NOTE(review): the reader goroutine looks the entry up under RLock and
+	// sends on recv after releasing it; confirm it cannot send on a channel
+	// closed here.
+	if prevConn, ok := c.connections[serialNo]; ok {
+		close(prevConn.reader.recv)
+	}
+	c.connections[serialNo] = vc
+	return vc, nil
+}
+
+// close handles a failure reported by the reader or writer goroutine.
+//
+// lastErr is the error that ended the goroutine; nil is ignored. done is the
+// done channel that goroutine was running with: if it is already closed,
+// another goroutine has handled the failure and close returns.
+//
+// Otherwise close records lastErr, drops all registered virtual connections,
+// closes the current done channel and socket, and tries to reconnect. If the
+// reconnect fails, the connection is marked Closed, mDone is closed and the
+// connection is removed from its pool.
+func (c *Connection) close(lastErr error, done chan struct{}) {
+	if lastErr == nil {
+		return
+	}
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if c.state == Closed {
+		return
+	}
+
+	select {
+	case <-done:
+		return
+	default:
+	}
+
+	c.state = Closing
+	c.err = lastErr
+	c.connections = make(map[uint32]*MultiplexedConnection)
+	close(c.done)
+	if c.conn != nil {
+		// The socket is being discarded; its close error carries no
+		// additional information beyond lastErr.
+		_ = c.conn.Close()
+	}
+	if err := c.reconnect(); err != nil {
+		c.state = Closed
+		close(c.mDone)
+		// The pool is keyed by address, not by *Connection. The owning pool
+		// may be unset, so guard against a nil dereference.
+		if c.multiplexed != nil {
+			c.multiplexed.connections.Delete(c.address)
+		}
+	}
+}
+
+// reconnect dials c.dialOpts again. On success it installs the new socket,
+// framer and done channel, marks the connection Connected, clears c.err and
+// starts fresh reader and writer goroutines. On failure it returns the dial
+// error and leaves c untouched.
+// The only caller visible here, close, invokes it while holding c.mu.
+func (c *Connection) reconnect() error {
+ conn, err := dialWithTimeout(c.dialOpts)
+ if err != nil {
+ return err
+ }
+ c.done = make(chan struct{})
+ c.conn = conn
+ c.framer = codec.New(conn)
+ // The shared write buffer must watch the new done channel.
+ c.buffer.done = c.done
+ c.state = Connected
+ c.err = nil
+ go c.reader()
+ go c.writer()
+ return nil
+}
+
+// writer drains the send queue and writes each request to the socket until
+// its done channel is closed or a write fails, then reports the failure to
+// close.
+//
+// The done channel is captured once at start-up. reconnect replaces c.done,
+// so re-reading the field would let a writer belonging to the old socket
+// pass the new channel to close and tear down the fresh connection.
+// NOTE(review): passing done as an argument from the spawning code would
+// remove the remaining window between `go c.writer()` and this snapshot.
+func (c *Connection) writer() {
+	c.mu.RLock()
+	done := c.done
+	c.mu.RUnlock()
+
+	var lastErr error
+	for {
+		select {
+		case <-done:
+			return
+		default:
+		}
+		req, err := c.buffer.get()
+		if err != nil {
+			lastErr = err
+			break
+		}
+		if err := c.write(req); err != nil {
+			lastErr = err
+			break
+		}
+	}
+	c.close(lastErr, done)
+}
+
+// send queues b for the writer goroutine.
+// It returns ErrConnClosed when the connection state is Closed, and c.err if
+// mDone fires before the request can be queued.
+// NOTE(review): c.state and c.err are read without holding c.mu — confirm
+// that this is intended.
+func (c *Connection) send(b []byte) error {
+ if c.state == Closed {
+ return ErrConnClosed
+ }
+
+ select {
+ case c.buffer.buffer <- b:
+ return nil
+ case <-c.mDone:
+ return c.err
+ }
+}
+
+// remove unregisters the virtual connection stored under id, if any.
+func (c *Connection) remove(id uint32) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	delete(c.connections, id)
+}
+
+// write pushes all of b onto the socket, retrying after short writes, and
+// returns the first error reported by the socket.
+func (c *Connection) write(b []byte) error {
+	for off := 0; off < len(b); {
+		n, err := c.conn.Write(b[off:])
+		if err != nil {
+			return err
+		}
+		off += n
+	}
+	return nil
+}
+
+// reader decodes responses from the socket and hands each one to the virtual
+// connection registered under its serial number. Responses with no
+// registered connection are dropped. It runs until its done channel is
+// closed or decoding fails, then reports the failure to close.
+//
+// The done channel is captured once at start-up. reconnect replaces c.done,
+// so re-reading the field would let a reader belonging to the old socket
+// pass the new channel to close and tear down the fresh connection.
+// NOTE(review): passing done as an argument from the spawning code would
+// remove the remaining window between `go c.reader()` and this snapshot.
+func (c *Connection) reader() {
+	c.mu.RLock()
+	done := c.done
+	c.mu.RUnlock()
+
+	var lastErr error
+	for {
+		select {
+		case <-done:
+			return
+		default:
+		}
+		rsp, err := c.framer.Decode()
+		if err != nil {
+			lastErr = err
+			break
+		}
+		serialNo := rsp.GetSerialNo()
+		c.mu.RLock()
+		mc, ok := c.connections[serialNo]
+		c.mu.RUnlock()
+		if !ok {
+			continue
+		}
+		// Deliver first, then unregister, so the entry exists for as long
+		// as a response may still arrive.
+		mc.reader.recv <- rsp
+		mc.conn.remove(rsp.GetSerialNo())
+	}
+	c.close(lastErr, done)
+}
+
+// Multiplexed is a pool of shared connections, one per server address.
+type Multiplexed struct {
+ // connections maps an address string to its *Connection.
+ connections *sync.Map
+}
+
+// Get returns a virtual connection for serialNo on the shared connection to
+// address, dialing that connection first if the pool does not hold one yet.
+//
+// It returns ctx.Err() if ctx is already done, ErrAssertConnectionFail if the
+// pool holds a value of an unexpected type, and the dial error if a new
+// connection cannot be established. A connection whose dial failed is marked
+// Closed and removed from the pool, so that later calls dial again instead of
+// reusing a connection that has no reader or writer.
+func (p *Multiplexed) Get(ctx context.Context, address string, serialNo uint32) (*MultiplexedConnection, error) {
+	select {
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	default:
+	}
+
+	// Fast path: the address already has a connection.
+	if v, ok := p.connections.Load(address); ok {
+		if c, ok := v.(*Connection); ok {
+			return c.new(ctx, serialNo)
+		}
+		return nil, ErrAssertConnectionFail
+	}
+
+	c := &Connection{
+		address:     address,
+		connections: make(map[uint32]*MultiplexedConnection),
+		done:        make(chan struct{}),
+		mDone:       make(chan struct{}),
+		state:       Initial,
+		// close needs the owning pool to remove the connection from it.
+		multiplexed: p,
+	}
+	c.buffer = &writerBuffer{
+		buffer: make(chan []byte, queueSize),
+		done:   c.done,
+	}
+	// LoadOrStore keeps a concurrent Get for the same address from replacing
+	// a connection that another caller has just registered.
+	if v, loaded := p.connections.LoadOrStore(address, c); loaded {
+		existing, ok := v.(*Connection)
+		if !ok {
+			return nil, ErrAssertConnectionFail
+		}
+		return existing.new(ctx, serialNo)
+	}
+
+	conn, dialOpts, err := dial(ctx, address)
+	c.dialOpts = dialOpts
+	if err != nil {
+		p.connections.Delete(address)
+		// Callers that obtained c while the dial was in flight are released
+		// through mDone and see the dial error.
+		c.mu.Lock()
+		c.err = err
+		c.state = Closed
+		close(c.mDone)
+		c.mu.Unlock()
+		return nil, err
+	}
+	c.framer = codec.New(conn)
+	c.conn = conn
+	c.state = Connected
+	go c.reader()
+	go c.writer()
+	return c.new(ctx, serialNo)
+}
+
+// dial opens a TCP connection to address, using the time left until the
+// deadline of ctx (if any) as the dial timeout.
+//
+// The DialOptions that were built are returned even on failure. The error is
+// ctx.Err() if ctx is already done, context.DeadlineExceeded if the deadline
+// has passed, and the error from dialWithTimeout otherwise.
+func dial(ctx context.Context, address string) (net.Conn, *DialOptions, error) {
+	var timeout time.Duration
+	deadline, hasDeadline := ctx.Deadline()
+	if hasDeadline {
+		timeout = time.Until(deadline)
+	}
+	dialOpts := &DialOptions{
+		Network: "tcp",
+		Address: address,
+		Timeout: timeout,
+	}
+	select {
+	case <-ctx.Done():
+		return nil, dialOpts, ctx.Err()
+	default:
+	}
+	// net.DialTimeout treats a zero timeout as "no timeout", so an expired
+	// deadline must not be passed through.
+	if hasDeadline && timeout <= 0 {
+		return nil, dialOpts, context.DeadlineExceeded
+	}
+	conn, err := dialWithTimeout(dialOpts)
+	return conn, dialOpts, err
+}
+
+func dialWithTimeout(opts *DialOptions) (net.Conn, error) {
+ if len(opts.CACertFile) == 0 {
+ return net.DialTimeout(opts.Network, opts.Address, opts.Timeout)
+ }
+
+ tlsConf := &tls.Config{}
+ if opts.CACertFile == "none" { // 不需要检验服务证书
Review comment:
I have removed this comment because it is obvious.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org