This is an automated email from the ASF dual-hosted git repository.

soulbird pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/apisix-go-plugin-runner.git


The following commit(s) were added to refs/heads/master by this push:
     new b941b73  fix(#114): fix transfering large body failed (#124)
b941b73 is described below

commit b941b73439d0aa98fc2d99c94a195034e616a562
Author: dongjunduo <[email protected]>
AuthorDate: Tue Dec 27 13:55:09 2022 +0800

    fix(#114): fix transfering large body failed (#124)
    
    * fix(#114): fix transfering large body failed
---
 internal/http/request.go              |  8 ++---
 internal/http/request_test.go         | 22 ++++++------
 internal/http/response.go             |  8 ++---
 internal/http/response_test.go        | 22 ++++++------
 internal/server/server.go             |  8 ++---
 internal/server/server_test.go        |  2 +-
 internal/util/msg.go                  | 25 +++++++++++++
 internal/util/{msg.go => msg_test.go} | 67 +++++++++++++++--------------------
 8 files changed, 88 insertions(+), 74 deletions(-)

diff --git a/internal/http/request.go b/internal/http/request.go
index 561e1b7..5c00172 100644
--- a/internal/http/request.go
+++ b/internal/http/request.go
@@ -366,19 +366,19 @@ func (r *Request) askExtraInfo(builder *flatbuffers.Builder,
        binary.BigEndian.PutUint32(header, uint32(size))
        header[0] = util.RPCExtraInfo
 
-       n, err := c.Write(header)
+       n, err := util.WriteBytes(c, header, len(header))
        if err != nil {
                util.WriteErr(n, err)
                return nil, common.ErrConnClosed
        }
 
-       n, err = c.Write(out)
+       n, err = util.WriteBytes(c, out, size)
        if err != nil {
                util.WriteErr(n, err)
                return nil, common.ErrConnClosed
        }
 
-       n, err = c.Read(header)
+       n, err = util.ReadBytes(c, header, util.HeaderLen)
        if util.ReadErr(n, err, util.HeaderLen) {
                return nil, common.ErrConnClosed
        }
@@ -390,7 +390,7 @@ func (r *Request) askExtraInfo(builder *flatbuffers.Builder,
        log.Infof("receive rpc type: %d data length: %d", ty, length)
 
        buf := make([]byte, length)
-       n, err = c.Read(buf)
+       n, err = util.ReadBytes(c, buf, int(length))
        if util.ReadErr(n, err, int(length)) {
                return nil, common.ErrConnClosed
        }
diff --git a/internal/http/request_test.go b/internal/http/request_test.go
index 51fe954..d6b58ed 100644
--- a/internal/http/request_test.go
+++ b/internal/http/request_test.go
@@ -305,7 +305,7 @@ func TestVar(t *testing.T) {
 
        go func() {
                header := make([]byte, util.HeaderLen)
-               n, err := sc.Read(header)
+               n, err := util.ReadBytes(sc, header, util.HeaderLen)
                if util.ReadErr(n, err, util.HeaderLen) {
                        return
                }
@@ -316,7 +316,7 @@ func TestVar(t *testing.T) {
                length := binary.BigEndian.Uint32(header)
 
                buf := make([]byte, length)
-               n, err = sc.Read(buf)
+               n, err = util.ReadBytes(sc, buf, int(length))
                if util.ReadErr(n, err, int(length)) {
                        return
                }
@@ -336,13 +336,13 @@ func TestVar(t *testing.T) {
                binary.BigEndian.PutUint32(header, uint32(size))
                header[0] = util.RPCExtraInfo
 
-               n, err = sc.Write(header)
+               n, err = util.WriteBytes(sc, header, len(header))
                if err != nil {
                        util.WriteErr(n, err)
                        return
                }
 
-               n, err = sc.Write(out)
+               n, err = util.WriteBytes(sc, out, size)
                if err != nil {
                        util.WriteErr(n, err)
                        return
@@ -365,7 +365,7 @@ func TestVar_FailedToSendExtraInfoReq(t *testing.T) {
 
        go func() {
                header := make([]byte, util.HeaderLen)
-               n, err := sc.Read(header)
+               n, err := util.ReadBytes(sc, header, util.HeaderLen)
                if util.ReadErr(n, err, util.HeaderLen) {
                        return
                }
@@ -385,7 +385,7 @@ func TestVar_FailedToReadExtraInfoResp(t *testing.T) {
 
        go func() {
                header := make([]byte, util.HeaderLen)
-               n, err := sc.Read(header)
+               n, err := util.ReadBytes(sc, header, util.HeaderLen)
                if util.ReadErr(n, err, util.HeaderLen) {
                        return
                }
@@ -396,7 +396,7 @@ func TestVar_FailedToReadExtraInfoResp(t *testing.T) {
                length := binary.BigEndian.Uint32(header)
 
                buf := make([]byte, length)
-               n, err = sc.Read(buf)
+               n, err = util.ReadBytes(sc, buf, int(length))
                if util.ReadErr(n, err, int(length)) {
                        return
                }
@@ -458,7 +458,7 @@ func TestBody(t *testing.T) {
 
        go func() {
                header := make([]byte, util.HeaderLen)
-               n, err := sc.Read(header)
+               n, err := util.ReadBytes(sc, header, util.HeaderLen)
                if util.ReadErr(n, err, util.HeaderLen) {
                        return
                }
@@ -469,7 +469,7 @@ func TestBody(t *testing.T) {
                length := binary.BigEndian.Uint32(header)
 
                buf := make([]byte, length)
-               n, err = sc.Read(buf)
+               n, err = util.ReadBytes(sc, buf, int(length))
                if util.ReadErr(n, err, int(length)) {
                        return
                }
@@ -488,13 +488,13 @@ func TestBody(t *testing.T) {
                binary.BigEndian.PutUint32(header, uint32(size))
                header[0] = util.RPCExtraInfo
 
-               n, err = sc.Write(header)
+               n, err = util.WriteBytes(sc, header, len(header))
                if err != nil {
                        util.WriteErr(n, err)
                        return
                }
 
-               n, err = sc.Write(out)
+               n, err = util.WriteBytes(sc, out, size)
                if err != nil {
                        util.WriteErr(n, err)
                        return
diff --git a/internal/http/response.go b/internal/http/response.go
index 7c97f42..7d7870e 100644
--- a/internal/http/response.go
+++ b/internal/http/response.go
@@ -73,19 +73,19 @@ func (r *Response) askExtraInfo(builder *flatbuffers.Builder,
        binary.BigEndian.PutUint32(header, uint32(size))
        header[0] = util.RPCExtraInfo
 
-       n, err := c.Write(header)
+       n, err := util.WriteBytes(c, header, len(header))
        if err != nil {
                util.WriteErr(n, err)
                return nil, common.ErrConnClosed
        }
 
-       n, err = c.Write(out)
+       n, err = util.WriteBytes(c, out, size)
        if err != nil {
                util.WriteErr(n, err)
                return nil, common.ErrConnClosed
        }
 
-       n, err = c.Read(header)
+       n, err = util.ReadBytes(c, header, util.HeaderLen)
        if util.ReadErr(n, err, util.HeaderLen) {
                return nil, common.ErrConnClosed
        }
@@ -97,7 +97,7 @@ func (r *Response) askExtraInfo(builder *flatbuffers.Builder,
        log.Infof("receive rpc type: %d data length: %d", ty, length)
 
        buf := make([]byte, length)
-       n, err = c.Read(buf)
+       n, err = util.ReadBytes(c, buf, int(length))
        if util.ReadErr(n, err, int(length)) {
                return nil, common.ErrConnClosed
        }
diff --git a/internal/http/response_test.go b/internal/http/response_test.go
index 128bf73..4062dd6 100644
--- a/internal/http/response_test.go
+++ b/internal/http/response_test.go
@@ -202,7 +202,7 @@ func TestResponse_Var(t *testing.T) {
 
        go func() {
                header := make([]byte, util.HeaderLen)
-               n, err := sc.Read(header)
+               n, err := util.ReadBytes(sc, header, util.HeaderLen)
                if util.ReadErr(n, err, util.HeaderLen) {
                        return
                }
@@ -213,7 +213,7 @@ func TestResponse_Var(t *testing.T) {
                length := binary.BigEndian.Uint32(header)
 
                buf := make([]byte, length)
-               n, err = sc.Read(buf)
+               n, err = util.ReadBytes(sc, buf, int(length))
                if util.ReadErr(n, err, int(length)) {
                        return
                }
@@ -233,13 +233,13 @@ func TestResponse_Var(t *testing.T) {
                binary.BigEndian.PutUint32(header, uint32(size))
                header[0] = util.RPCExtraInfo
 
-               n, err = sc.Write(header)
+               n, err = util.WriteBytes(sc, header, len(header))
                if err != nil {
                        util.WriteErr(n, err)
                        return
                }
 
-               n, err = sc.Write(out)
+               n, err = util.WriteBytes(sc, out, size)
                if err != nil {
                        util.WriteErr(n, err)
                        return
@@ -262,7 +262,7 @@ func TestResponse_Var_FailedToSendExtraInfoReq(t *testing.T) {
 
        go func() {
                header := make([]byte, util.HeaderLen)
-               n, err := sc.Read(header)
+               n, err := util.ReadBytes(sc, header, util.HeaderLen)
                if util.ReadErr(n, err, util.HeaderLen) {
                        return
                }
@@ -282,7 +282,7 @@ func TestResponse_FailedToReadExtraInfoResp(t *testing.T) {
 
        go func() {
                header := make([]byte, util.HeaderLen)
-               n, err := sc.Read(header)
+               n, err := util.ReadBytes(sc, header, util.HeaderLen)
                if util.ReadErr(n, err, util.HeaderLen) {
                        return
                }
@@ -293,7 +293,7 @@ func TestResponse_FailedToReadExtraInfoResp(t *testing.T) {
                length := binary.BigEndian.Uint32(header)
 
                buf := make([]byte, length)
-               n, err = sc.Read(buf)
+               n, err = util.ReadBytes(sc, buf, int(length))
                if util.ReadErr(n, err, int(length)) {
                        return
                }
@@ -314,7 +314,7 @@ func TestRead(t *testing.T) {
 
        go func() {
                header := make([]byte, util.HeaderLen)
-               n, err := sc.Read(header)
+               n, err := util.ReadBytes(sc, header, util.HeaderLen)
                if util.ReadErr(n, err, util.HeaderLen) {
                        return
                }
@@ -325,7 +325,7 @@ func TestRead(t *testing.T) {
                length := binary.BigEndian.Uint32(header)
 
                buf := make([]byte, length)
-               n, err = sc.Read(buf)
+               n, err = util.ReadBytes(sc, buf, int(length))
                if util.ReadErr(n, err, int(length)) {
                        return
                }
@@ -344,13 +344,13 @@ func TestRead(t *testing.T) {
                binary.BigEndian.PutUint32(header, uint32(size))
                header[0] = util.RPCExtraInfo
 
-               n, err = sc.Write(header)
+               n, err = util.WriteBytes(sc, header, len(header))
                if err != nil {
                        util.WriteErr(n, err)
                        return
                }
 
-               n, err = sc.Write(out)
+               n, err = util.WriteBytes(sc, out, size)
                if err != nil {
                        util.WriteErr(n, err)
                        return
diff --git a/internal/server/server.go b/internal/server/server.go
index 5ac291f..e8a0d8c 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -117,7 +117,7 @@ func handleConn(c net.Conn) {
 
        header := make([]byte, util.HeaderLen)
        for {
-               n, err := c.Read(header)
+               n, err := util.ReadBytes(c, header, util.HeaderLen)
                if util.ReadErr(n, err, util.HeaderLen) {
                        break
                }
@@ -131,7 +131,7 @@ func handleConn(c net.Conn) {
                log.Infof("receive rpc type: %d data length: %d", ty, length)
 
                buf := make([]byte, length)
-               n, err = c.Read(buf)
+               n, err = util.ReadBytes(c, buf, int(length))
                if util.ReadErr(n, err, int(length)) {
                        break
                }
@@ -142,13 +142,13 @@ func handleConn(c net.Conn) {
                binary.BigEndian.PutUint32(header, uint32(size))
                header[0] = respTy
 
-               n, err = c.Write(header)
+               n, err = util.WriteBytes(c, header, len(header))
                if err != nil {
                        util.WriteErr(n, err)
                        break
                }
 
-               n, err = c.Write(out)
+               n, err = util.WriteBytes(c, out, size)
                if err != nil {
                        util.WriteErr(n, err)
                        break
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
index b854973..a976b18 100644
--- a/internal/server/server_test.go
+++ b/internal/server/server_test.go
@@ -121,7 +121,7 @@ func TestRun(t *testing.T) {
                conn, err := net.DialTimeout("unix", addr[len("unix:"):], 1*time.Second)
                assert.NotNil(t, conn, err)
                defer conn.Close()
-               conn.Write(c.header)
+               util.WriteBytes(conn, c.header, len(c.header))
        }
 
        syscall.Kill(syscall.Getpid(), syscall.SIGINT)
diff --git a/internal/util/msg.go b/internal/util/msg.go
index f95e2d9..78121ff 100644
--- a/internal/util/msg.go
+++ b/internal/util/msg.go
@@ -20,6 +20,7 @@ package util
 import (
        "fmt"
        "io"
+       "net"
 
        flatbuffers "github.com/google/flatbuffers/go"
 
@@ -65,3 +66,27 @@ func WriteErr(n int, err error) {
                log.Errorf("write: %s", err)
        }
 }
+
+func ReadBytes(c net.Conn, b []byte, n int) (int, error) {
+       l := 0
+       for l < n {
+               tmp, err := c.Read(b[l:])
+               if err != nil {
+                       return l + tmp, err
+               }
+               l += tmp
+       }
+       return l, nil
+}
+
+func WriteBytes(c net.Conn, b []byte, n int) (int, error) {
+       l := 0
+       for l < n {
+               tmp, err := c.Write(b[l:])
+               if err != nil {
+                       return l + tmp, err
+               }
+               l += tmp
+       }
+       return l, nil
+}
diff --git a/internal/util/msg.go b/internal/util/msg_test.go
similarity index 53%
copy from internal/util/msg.go
copy to internal/util/msg_test.go
index f95e2d9..24a526e 100644
--- a/internal/util/msg.go
+++ b/internal/util/msg_test.go
@@ -18,50 +18,39 @@
 package util
 
 import (
-       "fmt"
-       "io"
+       "math/rand"
+       "net"
+       "testing"
+       "time"
 
-       flatbuffers "github.com/google/flatbuffers/go"
-
-       "github.com/apache/apisix-go-plugin-runner/pkg/log"
+       "github.com/stretchr/testify/assert"
 )
 
-const (
-       HeaderLen   = 4
-       MaxDataSize = 2<<24 - 1
-)
+func TestReadAndWriteBytes(t *testing.T) {
+       path := "/tmp/test.sock"
+       server, err := net.Listen("unix", path)
+       assert.NoError(t, err)
+       defer server.Close()
 
-const (
-       RPCError = iota
-       RPCPrepareConf
-       RPCHTTPReqCall
-       RPCExtraInfo
-       RPCHTTPRespCall
-)
-
-type RPCResult struct {
-       Err     error
-       Builder *flatbuffers.Builder
-}
+       // transfer large enough data
+       n := 10000000
 
-// Use struct if the result is not only []byte
-type ExtraInfoResult []byte
-
-func ReadErr(n int, err error, required int) bool {
-       if 0 < n && n < required {
-               err = fmt.Errorf("truncated, only get the first %d bytes", n)
-       }
-       if err != nil {
-               if err != io.EOF {
-                       log.Errorf("read: %s", err)
-               }
-               return true
+       const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+       in := make([]byte, n)
+       for i := range in {
+               in[i] = letterBytes[rand.Intn(len(letterBytes))]
        }
-       return false
-}
 
-func WriteErr(n int, err error) {
-       if err != nil {
-               log.Errorf("write: %s", err)
-       }
+       go func() {
+               client, err := net.DialTimeout("unix", path, 1*time.Second)
+               assert.NoError(t, err)
+               defer client.Close()
+               WriteBytes(client, in, len(in))
+       }()
+
+       fd, err := server.Accept()
+       assert.NoError(t, err)
+       out := make([]byte, n)
+       ReadBytes(fd, out, n)
+       assert.Equal(t, in, out)
 }

Reply via email to