This is an automated email from the ASF dual-hosted git repository. gosonzhang pushed a commit to branch INLONG-25 in repository https://gitbox.apache.org/repos/asf/incubator-inlong.git
commit 058ad9504843b49ce2852f420ad890e1aad8b78f Author: Zijie Lu <[email protected]> AuthorDate: Fri Apr 30 17:00:52 2021 +0800 Address review comments Signed-off-by: Zijie Lu <[email protected]> --- .../tubemq-client-go/codec/codec.go | 63 ++++++----- .../{pool => multiplexed}/multiplexed.go | 119 +++++++++++---------- .../{pool => multiplexed}/multlplexed_test.go | 6 +- 3 files changed, 99 insertions(+), 89 deletions(-) diff --git a/tubemq-client-twins/tubemq-client-go/codec/codec.go b/tubemq-client-twins/tubemq-client-go/codec/codec.go index f66fac9..a6f3fee 100644 --- a/tubemq-client-twins/tubemq-client-go/codec/codec.go +++ b/tubemq-client-twins/tubemq-client-go/codec/codec.go @@ -36,40 +36,49 @@ const ( beginTokenLen uint32 = 4 ) -type Framer struct { +type TransportResponse interface { + GetSerialNo() uint32 + GetResponseBuf() []byte +} + +type Decoder interface { + Decode() (TransportResponse, error) +} + +type TubeMQDecoder struct { reader io.Reader msg []byte } -func New(reader io.Reader) *Framer { +func New(reader io.Reader) *TubeMQDecoder { bufferReader := bufio.NewReaderSize(reader, maxBufferSize) - return &Framer{ + return &TubeMQDecoder{ msg: make([]byte, defaultMsgSize), reader: bufferReader, } } -func (f *Framer) Decode() (*FrameResponse, error) { - num, err := io.ReadFull(f.reader, f.msg[:frameHeadLen]) +func (t *TubeMQDecoder) Decode() (TransportResponse, error) { + num, err := io.ReadFull(t.reader, t.msg[:frameHeadLen]) if err != nil { return nil, err } if num != int(frameHeadLen) { return nil, errors.New("framer: read frame header num invalid") } - token := binary.BigEndian.Uint32(f.msg[:beginTokenLen]) + token := binary.BigEndian.Uint32(t.msg[:beginTokenLen]) if token != RPCProtocolBeginToken { return nil, errors.New("framer: read framer rpc protocol begin token not match") } - num, err = io.ReadFull(f.reader, f.msg[frameHeadLen:frameHeadLen+listSizeLen]) + num, err = io.ReadFull(t.reader, t.msg[frameHeadLen:frameHeadLen+listSizeLen]) if num != int(listSizeLen) { return nil, errors.New("framer: read invalid list size num") } - listSize := binary.BigEndian.Uint32(f.msg[frameHeadLen : frameHeadLen+listSizeLen]) + listSize := binary.BigEndian.Uint32(t.msg[frameHeadLen : frameHeadLen+listSizeLen]) totalLen := int(frameHeadLen) size := make([]byte, 4) for i := 0; i < int(listSize); i++ { - n, err := io.ReadFull(f.reader, size) + n, err := io.ReadFull(t.reader, size) if err != nil { return nil, err } @@ -78,13 +87,13 @@ func (f *Framer) Decode() (*FrameResponse, error) { } s := int(binary.BigEndian.Uint32(size)) - if totalLen+s > len(f.msg) { - data := f.msg[:totalLen] - f.msg = make([]byte, totalLen+s) - copy(f.msg, data[:]) + if totalLen+s > len(t.msg) { + data := t.msg[:totalLen] + t.msg = make([]byte, totalLen+s) + copy(t.msg, data[:]) } - num, err = io.ReadFull(f.reader, f.msg[totalLen:totalLen+s]) + num, err = io.ReadFull(t.reader, t.msg[totalLen:totalLen+s]) if err != nil { return nil, err } @@ -94,31 +103,29 @@ func (f *Framer) Decode() (*FrameResponse, error) { totalLen += s } - data := make([]byte, totalLen - int(frameHeadLen)) - copy(data, f.msg[frameHeadLen:totalLen]) + data := make([]byte, totalLen-int(frameHeadLen)) + copy(data, t.msg[frameHeadLen:totalLen]) - return &FrameResponse{ - serialNo: binary.BigEndian.Uint32(f.msg[beginTokenLen : beginTokenLen+serialNoLen]), + return TubeMQResponse{ + serialNo: binary.BigEndian.Uint32(t.msg[beginTokenLen : beginTokenLen+serialNoLen]), responseBuf: data, }, nil } -type FrameRequest struct { - requestID uint32 - req []byte +type TubeMQRequest struct { + serialNo uint32 + req []byte } -type FrameResponse struct { +type TubeMQResponse struct { serialNo uint32 responseBuf []byte } -func (f *FrameResponse) GetSerialNo() uint32 { - return f.serialNo +func (t TubeMQResponse) GetSerialNo() uint32 { + return t.serialNo } -func (f *FrameResponse) GetResponseBuf() []byte { - return f.responseBuf +func (t TubeMQResponse) GetResponseBuf() []byte { + return t.responseBuf } - -type Codec struct{} diff --git a/tubemq-client-twins/tubemq-client-go/pool/multiplexed.go b/tubemq-client-twins/tubemq-client-go/multiplexed/multiplexed.go similarity index 90% rename from tubemq-client-twins/tubemq-client-go/pool/multiplexed.go rename to tubemq-client-twins/tubemq-client-go/multiplexed/multiplexed.go index d5a60b1..050b608 100644 --- a/tubemq-client-twins/tubemq-client-go/pool/multiplexed.go +++ b/tubemq-client-twins/tubemq-client-go/multiplexed/multiplexed.go @@ -15,7 +15,11 @@ * limitations under the License. */ -package pool +// Package multiplexed defines the multiplexed connection pool for sending +// request and receiving response. After receiving the response, it will +// be decoded and returned to the client. It is used for the underlying communication +// with TubeMQ. +package multiplexed import ( "context" @@ -30,8 +34,6 @@ import ( "github.com/apache/incubator-inlong/tubemq-client-twins/tubemq-client-go/codec" ) -var DefaultMultiplexedPool = New() - var ( // ErrConnClosed indicates that the connection is closed ErrConnClosed = errors.New("connection is closed") @@ -50,15 +52,60 @@ const ( Closed ) -var queueSize = 10000 +const queueSize = 10000 + +type Pool struct { + connections *sync.Map +} -func New() *Multiplexed { - m := &Multiplexed{ +func NewPool() *Pool { + m := &Pool{ connections: new(sync.Map), } return m } + +func (p *Pool) Get(ctx context.Context, address string, serialNo uint32) (*MultiplexedConnection, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + if v, ok := p.connections.Load(address); ok { + if c, ok := v.(*Connection); ok { + return c.new(ctx, serialNo) + } + return nil, ErrAssertConnectionFail + } + + c := &Connection{ + address: address, + connections: make(map[uint32]*MultiplexedConnection), + done: make(chan struct{}), + mDone: make(chan struct{}), + state: Initial, + } + c.buffer = &writerBuffer{ + buffer: make(chan []byte, queueSize), + done: c.done, + } + p.connections.Store(address, c) + + conn, dialOpts, err := dial(ctx, address) + c.dialOpts = dialOpts + if err != nil { + return nil, err + } + c.decoder = codec.New(conn) + c.conn = conn + c.state = Connected + go c.reader() + go c.writer() + return c.new(ctx, serialNo) +} + type writerBuffer struct { buffer chan []byte done <-chan struct{} @@ -75,7 +122,7 @@ func (w *writerBuffer) get() ([]byte, error) { type recvReader struct { ctx context.Context - recv chan *codec.FrameResponse + recv chan codec.TransportResponse } type MultiplexedConnection struct { @@ -93,7 +140,7 @@ func (mc *MultiplexedConnection) Write(b []byte) error { return nil } -func (mc *MultiplexedConnection) Read() (*codec.FrameResponse, error) { +func (mc *MultiplexedConnection) Read() (codec.TransportResponse, error) { select { case <-mc.reader.ctx.Done(): mc.conn.remove(mc.serialNo) @@ -111,7 +158,7 @@ func (mc *MultiplexedConnection) Read() (*codec.FrameResponse, error) { } } -func (mc *MultiplexedConnection) recv(rsp *codec.FrameResponse) { +func (mc *MultiplexedConnection) recv(rsp *codec.TubeMQResponse) { mc.reader.recv <- rsp mc.conn.remove(rsp.GetSerialNo()) } @@ -131,14 +178,14 @@ type Connection struct { address string mu sync.RWMutex connections map[uint32]*MultiplexedConnection - framer *codec.Framer + decoder codec.Decoder conn net.Conn done chan struct{} mDone chan struct{} buffer *writerBuffer dialOpts *DialOptions state int - multiplexed *Multiplexed + multiplexed *Pool } func (c *Connection) new(ctx context.Context, serialNo uint32) (*MultiplexedConnection, error) { @@ -154,7 +201,7 @@ func (c *Connection) new(ctx context.Context, serialNo uint32) (*MultiplexedConn done: c.mDone, reader: &recvReader{ ctx: ctx, - recv: make(chan *codec.FrameResponse, 1), + recv: make(chan codec.TransportResponse, 1), }, } @@ -204,7 +251,7 @@ func (c *Connection) reconnect() error { } c.done = make(chan struct{}) c.conn = conn - c.framer = codec.New(conn) + c.decoder = codec.New(conn) c.buffer.done = c.done c.state = Connected c.err = nil @@ -273,7 +320,7 @@ func (c *Connection) reader() { return default: } - rsp, err := c.framer.Decode() + rsp, err := c.decoder.Decode() if err != nil { lastErr = err break @@ -291,50 +338,6 @@ func (c *Connection) reader() { c.close(lastErr, c.done) } -type Multiplexed struct { - connections *sync.Map -} - -func (p *Multiplexed) Get(ctx context.Context, address string, serialNo uint32) (*MultiplexedConnection, error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - if v, ok := p.connections.Load(address); ok { - if c, ok := v.(*Connection); ok { - return c.new(ctx, serialNo) - } - return nil, ErrAssertConnectionFail - } - - c := &Connection{ - address: address, - connections: make(map[uint32]*MultiplexedConnection), - done: make(chan struct{}), - mDone: make(chan struct{}), - state: Initial, - } - c.buffer = &writerBuffer{ - buffer: make(chan []byte, queueSize), - done: c.done, - } - p.connections.Store(address, c) - - conn, dialOpts, err := dial(ctx, address) - c.dialOpts = dialOpts - if err != nil { - return nil, err - } - c.framer = codec.New(conn) - c.conn = conn - c.state = Connected - go c.reader() - go c.writer() - return c.new(ctx, serialNo) -} - func dial(ctx context.Context, address string) (net.Conn, *DialOptions, error) { var timeout time.Duration t, ok := ctx.Deadline() diff --git a/tubemq-client-twins/tubemq-client-go/pool/multlplexed_test.go b/tubemq-client-twins/tubemq-client-go/multiplexed/multlplexed_test.go similarity index 98% rename from tubemq-client-twins/tubemq-client-go/pool/multlplexed_test.go rename to tubemq-client-twins/tubemq-client-go/multiplexed/multlplexed_test.go index e136ebd..4584607 100644 --- a/tubemq-client-twins/tubemq-client-go/pool/multlplexed_test.go +++ b/tubemq-client-twins/tubemq-client-go/multiplexed/multlplexed_test.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package pool +package multiplexed import ( "bytes" @@ -92,7 +92,7 @@ func TestBasicMultiplexed(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - m := New() + m := NewPool() mc, err := m.Get(ctx, address, serialNo) body := []byte("hello world") @@ -109,7 +109,7 @@ func TestBasicMultiplexed(t *testing.T) { func TestConcurrentMultiplexed(t *testing.T) { count := 1000 - m := New() + m := NewPool() wg := sync.WaitGroup{} wg.Add(count) for i := 0; i < count; i++ {
