This is an automated email from the ASF dual-hosted git repository. baze pushed a commit to branch 1.4 in repository https://gitbox.apache.org/repos/asf/dubbo-go.git
commit bcb01004337eb1a3baf4390f905f649e4477ad85 Author: Patrick <[email protected]> AuthorDate: Wed Apr 1 19:05:58 2020 +0800 optimize header transmit in RestClient and RestServer --- protocol/rest/client/client_impl/resty_client.go | 11 +++--- protocol/rest/client/rest_client.go | 5 +-- protocol/rest/rest_invoker.go | 23 ++++++++--- protocol/rest/server/rest_server.go | 4 ++ .../rest/server/server_impl/go_restful_server.go | 44 +++++++++++++++++++++- 5 files changed, 72 insertions(+), 15 deletions(-) diff --git a/protocol/rest/client/client_impl/resty_client.go b/protocol/rest/client/client_impl/resty_client.go index af9637e..9e0c80c 100644 --- a/protocol/rest/client/client_impl/resty_client.go +++ b/protocol/rest/client/client_impl/resty_client.go @@ -66,20 +66,19 @@ func NewRestyClient(restOption *client.RestOptions) client.RestClient { } func (rc *RestyClient) Do(restRequest *client.RestClientRequest, res interface{}) error { - r, err := rc.client.R(). - SetHeader("Content-Type", restRequest.Consumes). - SetHeader("Accept", restRequest.Produces). + req := rc.client.R() + req.Header = restRequest.Header + resp, err := req. SetPathParams(restRequest.PathParams). SetQueryParams(restRequest.QueryParams). - SetHeaders(restRequest.Headers). SetBody(restRequest.Body). SetResult(res). Execute(restRequest.Method, "http://"+path.Join(restRequest.Location, restRequest.Path)) if err != nil { return perrors.WithStack(err) } - if r.IsError() { - return perrors.New(r.String()) + if resp.IsError() { + return perrors.New(resp.String()) } return nil } diff --git a/protocol/rest/client/rest_client.go b/protocol/rest/client/rest_client.go index 3acccb5..5be4bb3 100644 --- a/protocol/rest/client/rest_client.go +++ b/protocol/rest/client/rest_client.go @@ -18,6 +18,7 @@ package client import ( + "net/http" "time" ) @@ -27,15 +28,13 @@ type RestOptions struct { } type RestClientRequest struct { + Header http.Header Location string Path string - Produces string - Consumes string Method string PathParams map[string]string QueryParams map[string]string Body interface{} - Headers map[string]string } type RestClient interface { diff --git a/protocol/rest/rest_invoker.go b/protocol/rest/rest_invoker.go index c8e3fea..121d121 100644 --- a/protocol/rest/rest_invoker.go +++ b/protocol/rest/rest_invoker.go @@ -20,6 +20,7 @@ package rest import ( "context" "fmt" + "net/http" ) import ( @@ -56,7 +57,7 @@ func (ri *RestInvoker) Invoke(ctx context.Context, invocation protocol.Invocatio body interface{} pathParams map[string]string queryParams map[string]string - headers map[string]string + header http.Header err error ) if methodConfig == nil { @@ -71,7 +72,7 @@ func (ri *RestInvoker) Invoke(ctx context.Context, invocation protocol.Invocatio result.Err = err return &result } - if headers, err = restStringMapTransform(methodConfig.HeadersMap, inv.Arguments()); err != nil { + if header, err = getRestHttpHeader(methodConfig, inv.Arguments()); err != nil { result.Err = err return &result } @@ -80,14 +81,12 @@ func (ri *RestInvoker) Invoke(ctx context.Context, invocation protocol.Invocatio } req := &client.RestClientRequest{ Location: ri.GetUrl().Location, - Produces: methodConfig.Produces, - Consumes: methodConfig.Consumes, Method: methodConfig.MethodType, Path: methodConfig.Path, PathParams: pathParams, QueryParams: queryParams, Body: body, - Headers: headers, + Header: header, } result.Err = ri.client.Do(req, inv.Reply()) if result.Err == nil { @@ -106,3 +105,17 @@ func restStringMapTransform(paramsMap map[int]string, args []interface{}) (map[s } return resMap, nil } + +func getRestHttpHeader(methodConfig *config.RestMethodConfig, args []interface{}) (http.Header, error) { + header := http.Header{} + headersMap := methodConfig.HeadersMap + header.Set("Content-Type", methodConfig.Consumes) + header.Set("Accept", methodConfig.Produces) + for k, v := range headersMap { + if k >= len(args) || k < 0 { + return nil, perrors.Errorf("[Rest Invoke] Index %v is out of bundle", k) + } + header.Set(v, fmt.Sprint(args[k])) + } + return header, nil +} diff --git a/protocol/rest/server/rest_server.go b/protocol/rest/server/rest_server.go index b7eb555..7fb0560 100644 --- a/protocol/rest/server/rest_server.go +++ b/protocol/rest/server/rest_server.go @@ -46,6 +46,7 @@ type RestServer interface { // RestServerRequest interface type RestServerRequest interface { + RawRequest() *http.Request PathParameter(name string) string PathParameters() map[string]string QueryParameter(name string) string @@ -57,6 +58,9 @@ type RestServerRequest interface { // RestServerResponse interface type RestServerResponse interface { + Header() http.Header + Write([]byte) (int, error) + WriteHeader(statusCode int) WriteError(httpStatus int, err error) (writeErr error) WriteEntity(value interface{}) error } diff --git a/protocol/rest/server/server_impl/go_restful_server.go b/protocol/rest/server/server_impl/go_restful_server.go index 81043c8..9163d3a 100644 --- a/protocol/rest/server/server_impl/go_restful_server.go +++ b/protocol/rest/server/server_impl/go_restful_server.go @@ -79,7 +79,7 @@ func (grs *GoRestfulServer) Start(url common.URL) { func (grs *GoRestfulServer) Deploy(restMethodConfig *config.RestMethodConfig, routeFunc func(request server.RestServerRequest, response server.RestServerResponse)) { ws := new(restful.WebService) rf := func(req *restful.Request, resp *restful.Response) { - routeFunc(req, resp) + routeFunc(NewGoRestfulRequestAdapter(req), resp) } ws.Path(restMethodConfig.Path). Produces(strings.Split(restMethodConfig.Produces, ",")...). @@ -116,3 +116,45 @@ func GetNewGoRestfulServer() server.RestServer { func AddGoRestfulServerFilter(filterFuc restful.FilterFunction) { filterSlice = append(filterSlice, filterFuc) } + +// Adapter about RestServerRequest +type GoRestfulRequestAdapter struct { + server.RestServerRequest + request *restful.Request +} + +func NewGoRestfulRequestAdapter(request *restful.Request) *GoRestfulRequestAdapter { + return &GoRestfulRequestAdapter{request: request} +} + +func (grra *GoRestfulRequestAdapter) RawRequest() *http.Request { + return grra.request.Request +} + +func (grra *GoRestfulRequestAdapter) PathParameter(name string) string { + return grra.request.PathParameter(name) +} + +func (grra *GoRestfulRequestAdapter) PathParameters() map[string]string { + return grra.request.PathParameters() +} + +func (grra *GoRestfulRequestAdapter) QueryParameter(name string) string { + return grra.request.QueryParameter(name) +} + +func (grra *GoRestfulRequestAdapter) QueryParameters(name string) []string { + return grra.request.QueryParameters(name) +} + +func (grra *GoRestfulRequestAdapter) BodyParameter(name string) (string, error) { + return grra.request.BodyParameter(name) +} + +func (grra *GoRestfulRequestAdapter) HeaderParameter(name string) string { + return grra.request.HeaderParameter(name) +} + +func (grra *GoRestfulRequestAdapter) ReadEntity(entityPointer interface{}) error { + return grra.request.ReadEntity(entityPointer) +}
