This is an automated email from the ASF dual-hosted git repository.

liujun pushed a commit to branch feature-triple
in repository https://gitbox.apache.org/repos/asf/dubbo-go.git
The following commit(s) were added to refs/heads/feature-triple by this push: new 5e3115277 feat: sort out triple logic and fix comments (#2454) 5e3115277 is described below commit 5e311527786b88e72ae733cf12d993e29447db96 Author: Scout Wang <33331974+dmwangn...@users.noreply.github.com> AuthorDate: Thu Nov 9 21:43:12 2023 +0800 feat: sort out triple logic and fix comments (#2454) --- client/action.go | 3 + common/constant/key.go | 1 + common/url.go | 28 ++- common/url_test.go | 24 +- protocol/triple/client.go | 16 +- protocol/triple/dubbo3_invoker.go | 271 +++++++++++++++++++++ .../proto/triple_gen/greettriple/greet.triple.go | 13 +- protocol/triple/triple.go | 15 +- protocol/triple/triple_protocol/client.go | 40 ++- protocol/triple/triple_protocol/client_stream.go | 23 +- protocol/triple/triple_protocol/code.go | 13 +- .../triple/triple_protocol/duplex_http_call.go | 3 +- protocol/triple/triple_protocol/envelope.go | 20 +- protocol/triple/triple_protocol/error.go | 14 +- protocol/triple/triple_protocol/error_writer.go | 2 +- protocol/triple/triple_protocol/handler.go | 28 +-- protocol/triple/triple_protocol/handler_compat.go | 10 +- protocol/triple/triple_protocol/option.go | 91 ++----- protocol/triple/triple_protocol/protocol.go | 6 +- protocol/triple/triple_protocol/protocol_grpc.go | 14 +- .../triple/triple_protocol/protocol_grpc_test.go | 48 +++- protocol/triple/triple_protocol/protocol_triple.go | 12 +- protocol/triple/triple_protocol/triple.go | 32 ++- protocol/triple/triple_protocol/triple_ext_test.go | 202 ++++++++++++--- protocol/triple/triple_test.go | 220 +++++++++++++---- 25 files changed, 876 insertions(+), 273 deletions(-) diff --git a/client/action.go b/client/action.go index 3e5db00ea..83d72d5e0 100644 --- a/client/action.go +++ b/client/action.go @@ -125,6 +125,9 @@ func (opts *ClientOptions) refer(srv common.RPCService, info *ClientInfo) { common.WithParamsValue(constant.BeanNameKey, opts.id), common.WithParamsValue(constant.MetadataTypeKey, 
opts.metaDataType), ) + if info != nil { + cfgURL.SetAttribute(constant.ClientInfoKey, info) + } if ref.ForceTag { cfgURL.AddParam(constant.ForceUseTag, "true") diff --git a/common/constant/key.go b/common/constant/key.go index 63e626f73..8a32c50d6 100644 --- a/common/constant/key.go +++ b/common/constant/key.go @@ -145,6 +145,7 @@ const ( CallHTTP = "http" CallHTTP2 = "http2" ServiceInfoKey = "service-info" + ClientInfoKey = "client-info" ) const ( diff --git a/common/url.go b/common/url.go index 7ec0a108d..73e80e455 100644 --- a/common/url.go +++ b/common/url.go @@ -114,8 +114,10 @@ type URL struct { Username string Password string Methods []string - // Attributes should not be transported - Attributes map[string]interface{} `hessian:"-"` + + attributesLock sync.RWMutex + // attributes should not be transported + attributes map[string]interface{} `hessian:"-"` // special for registry SubURL *URL attributes sync.Map @@ -235,10 +237,10 @@ func WithToken(token string) Option { // WithAttribute sets attribute for URL func WithAttribute(key string, attribute interface{}) Option { return func(url *URL) { - if url.Attributes == nil { - url.Attributes = make(map[string]interface{}) + if url.attributes == nil { + url.attributes = make(map[string]interface{}) } - url.Attributes[key] = attribute + url.attributes[key] = attribute } } @@ -544,6 +546,22 @@ func (c *URL) SetParam(key string, value string) { c.params.Set(key, value) } +func (c *URL) SetAttribute(key string, value interface{}) { + c.attributesLock.Lock() + defer c.attributesLock.Unlock() + if c.attributes == nil { + c.attributes = make(map[string]interface{}) + } + c.attributes[key] = value +} + +func (c *URL) GetAttribute(key string) (interface{}, bool) { + c.attributesLock.RLock() + defer c.attributesLock.RUnlock() + r, ok := c.attributes[key] + return r, ok +} + // DelParam will delete the given key from the URL func (c *URL) DelParam(key string) { c.paramsLock.Lock() diff --git a/common/url_test.go 
b/common/url_test.go index 2971e6c45..7c9cdf875 100644 --- a/common/url_test.go +++ b/common/url_test.go @@ -49,7 +49,10 @@ func TestNewURLWithOptions(t *testing.T) { WithPort("8080"), WithMethods(methods), WithParams(params), - WithParamsValue("key2", "value2")) + WithParamsValue("key2", "value2"), + WithAttribute("key3", "value3"), + WithAttribute("key4", "value4"), + ) assert.Equal(t, "/com.test.Service", u.Path) assert.Equal(t, userName, u.Username) assert.Equal(t, password, u.Password) @@ -58,6 +61,7 @@ func TestNewURLWithOptions(t *testing.T) { assert.Equal(t, "8080", u.Port) assert.Equal(t, methods, u.Methods) assert.Equal(t, 2, len(u.params)) + assert.Equal(t, 2, len(u.attributes)) } func TestURL(t *testing.T) { @@ -304,6 +308,24 @@ func TestURLGetMethodParamBool(t *testing.T) { assert.Equal(t, false, v) } +func TestURLGetAttribute(t *testing.T) { + u := URL{} + key := "key" + notExistKey := "not-exist-key" + val := "value" + u.SetAttribute(key, val) + + rawVal, ok := u.GetAttribute(key) + assert.Equal(t, true, ok) + v, ok := rawVal.(string) + assert.Equal(t, true, ok) + assert.Equal(t, val, v) + + rawVal, ok = u.GetAttribute(notExistKey) + assert.Equal(t, false, ok) + assert.Nil(t, rawVal) +} + func TestMergeUrl(t *testing.T) { referenceUrlParams := url.Values{} referenceUrlParams.Set(constant.ClusterKey, "random") diff --git a/protocol/triple/client.go b/protocol/triple/client.go index 1b06a54d6..87799d291 100644 --- a/protocol/triple/client.go +++ b/protocol/triple/client.go @@ -143,14 +143,14 @@ func newClientManager(url *common.URL) (*clientManager, error) { panic(fmt.Sprintf("Unsupported serialization: %s", serialization)) } - // todo:// process timeout - // consumer config client connectTimeout - //connectTimeout := config.GetConsumerConfig().ConnectTimeout + // set timeout + timeout := url.GetParamDuration(constant.TimeoutKey, "") + triClientOpts = append(triClientOpts, tri.WithTimeout(timeout)) // dialOpts = append(dialOpts, // // grpc.WithBlock(), 
- // // todo config network timeout + // // todo config tracing // grpc.WithTimeout(time.Second*3), // grpc.WithUnaryInterceptor(otgrpc.OpenTracingClientInterceptor(tracer, otgrpc.LogPayloads())), // grpc.WithStreamInterceptor(otgrpc.OpenTracingStreamClientInterceptor(tracer, otgrpc.LogPayloads())), @@ -180,14 +180,6 @@ func newClientManager(url *common.URL) (*clientManager, error) { // tlsFlag = true //} - // todo(DMwangnima): this code fragment would be used to be compatible with old triple client - //key := url.GetParam(constant.InterfaceKey, "") - //conRefs := config.GetConsumerConfig().References - //ref, ok := conRefs[key] - //if !ok { - // panic("no reference") - //} - // todo: set timeout var transport http.RoundTripper callType := url.GetParam(constant.CallHTTPTypeKey, constant.CallHTTP2) switch callType { diff --git a/protocol/triple/dubbo3_invoker.go b/protocol/triple/dubbo3_invoker.go new file mode 100644 index 000000000..0aa0b4f71 --- /dev/null +++ b/protocol/triple/dubbo3_invoker.go @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package triple + +import ( + "context" + "reflect" + "strconv" + "strings" + "sync" + "time" +) + +import ( + "github.com/dubbogo/gost/log/logger" + + "github.com/dubbogo/grpc-go/metadata" + + tripleConstant "github.com/dubbogo/triple/pkg/common/constant" + triConfig "github.com/dubbogo/triple/pkg/config" + "github.com/dubbogo/triple/pkg/triple" + + "github.com/dustin/go-humanize" +) + +import ( + "dubbo.apache.org/dubbo-go/v3/common" + "dubbo.apache.org/dubbo-go/v3/common/constant" + "dubbo.apache.org/dubbo-go/v3/config" + "dubbo.apache.org/dubbo-go/v3/protocol" + invocation_impl "dubbo.apache.org/dubbo-go/v3/protocol/invocation" +) + +// same as dubbo_invoker.go attachmentKey +var attachmentKey = []string{ + constant.InterfaceKey, constant.GroupKey, constant.TokenKey, constant.TimeoutKey, + constant.VersionKey, tripleConstant.TripleServiceGroup, tripleConstant.TripleServiceVersion, +} + +// DubboInvoker is implement of protocol.Invoker, a dubboInvoker refer to one service and ip. +type DubboInvoker struct { + protocol.BaseInvoker + // the net layer client, it is focus on network communication. + client *triple.TripleClient + // quitOnce is used to make sure DubboInvoker is only destroyed once + quitOnce sync.Once + // timeout for service(interface) level. + timeout time.Duration + // clientGuard is the client lock of dubbo invoker + clientGuard *sync.RWMutex +} + +// NewDubbo3Invoker constructor +func NewDubbo3Invoker(url *common.URL) (*DubboInvoker, error) { + rt := config.GetConsumerConfig().RequestTimeout + + timeout := url.GetParamDuration(constant.TimeoutKey, rt) + // for triple pb serialization. The bean name from provider is the provider reference key, + // which can't locate the target consumer stub, so we use interface key.. 
+ interfaceKey := url.GetParam(constant.InterfaceKey, "") + consumerService := config.GetConsumerServiceByInterfaceName(interfaceKey) + + dubboSerializerType := url.GetParam(constant.SerializationKey, constant.ProtobufSerialization) + triCodecType := tripleConstant.CodecType(dubboSerializerType) + // new triple client + opts := []triConfig.OptionFunction{ + triConfig.WithClientTimeout(timeout), + triConfig.WithCodecType(triCodecType), + triConfig.WithLocation(url.Location), + triConfig.WithHeaderAppVersion(url.GetParam(constant.AppVersionKey, "")), + triConfig.WithHeaderGroup(url.GetParam(constant.GroupKey, "")), + triConfig.WithLogger(logger.GetLogger()), + } + maxCallRecvMsgSize := constant.DefaultMaxCallRecvMsgSize + if maxCall, err := humanize.ParseBytes(url.GetParam(constant.MaxCallRecvMsgSize, "")); err == nil && maxCall != 0 { + maxCallRecvMsgSize = int(maxCall) + } + maxCallSendMsgSize := constant.DefaultMaxCallSendMsgSize + if maxCall, err := humanize.ParseBytes(url.GetParam(constant.MaxCallSendMsgSize, "")); err == nil && maxCall != 0 { + maxCallSendMsgSize = int(maxCall) + } + opts = append(opts, triConfig.WithGRPCMaxCallRecvMessageSize(maxCallRecvMsgSize)) + opts = append(opts, triConfig.WithGRPCMaxCallSendMessageSize(maxCallSendMsgSize)) + + tracingKey := url.GetParam(constant.TracingConfigKey, "") + if tracingKey != "" { + tracingConfig := config.GetTracingConfig(tracingKey) + if tracingConfig != nil { + if tracingConfig.Name == "jaeger" { + if tracingConfig.ServiceName == "" { + tracingConfig.ServiceName = config.GetApplicationConfig().Name + } + opts = append(opts, triConfig.WithJaegerConfig( + tracingConfig.Address, + tracingConfig.ServiceName, + *tracingConfig.UseAgent, + )) + } else { + logger.Warnf("unsupported tracing name %s, now triple only support jaeger", tracingConfig.Name) + } + } + } + + triOption := triConfig.NewTripleOption(opts...) 
+ tlsConfig := config.GetRootConfig().TLSConfig + if tlsConfig != nil { + triOption.TLSCertFile = tlsConfig.TLSCertFile + triOption.TLSKeyFile = tlsConfig.TLSKeyFile + triOption.CACertFile = tlsConfig.CACertFile + triOption.TLSServerName = tlsConfig.TLSServerName + logger.Infof("Triple Client initialized the TLSConfig configuration") + } + client, err := triple.NewTripleClient(consumerService, triOption) + + if err != nil { + return nil, err + } + + return &DubboInvoker{ + BaseInvoker: *protocol.NewBaseInvoker(url), + client: client, + timeout: timeout, + clientGuard: &sync.RWMutex{}, + }, nil +} + +func (di *DubboInvoker) setClient(client *triple.TripleClient) { + di.clientGuard.Lock() + defer di.clientGuard.Unlock() + + di.client = client +} + +func (di *DubboInvoker) getClient() *triple.TripleClient { + di.clientGuard.RLock() + defer di.clientGuard.RUnlock() + + return di.client +} + +// Invoke call remoting. +func (di *DubboInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result { + var ( + result protocol.RPCResult + ) + + if !di.BaseInvoker.IsAvailable() { + // Generally, the case will not happen, because the invoker has been removed + // from the invoker list before destroy,so no new request will enter the destroyed invoker + logger.Warnf("this dubboInvoker is destroyed") + result.Err = protocol.ErrDestroyedInvoker + return &result + } + + di.clientGuard.RLock() + defer di.clientGuard.RUnlock() + + if di.client == nil { + result.Err = protocol.ErrClientClosed + return &result + } + + if !di.BaseInvoker.IsAvailable() { + // Generally, the case will not happen, because the invoker has been removed + // from the invoker list before destroy,so no new request will enter the destroyed invoker + logger.Warnf("this grpcInvoker is destroying") + result.Err = protocol.ErrDestroyedInvoker + return &result + } + + for _, k := range attachmentKey { + var paramKey string + switch k { + case tripleConstant.TripleServiceGroup: + paramKey = 
constant.GroupKey + case tripleConstant.TripleServiceVersion: + paramKey = constant.VersionKey + default: + paramKey = k + } + + if v := di.GetURL().GetParam(paramKey, ""); len(v) > 0 { + invocation.SetAttachment(k, v) + } + } + + // append interface id to ctx + gRPCMD := make(metadata.MD, 0) + // triple will convert attachment value to []string + for k, v := range invocation.Attachments() { + if str, ok := v.(string); ok { + gRPCMD.Set(k, str) + continue + } + if str, ok := v.([]string); ok { + gRPCMD.Set(k, str...) + continue + } + logger.Warnf("[Triple Protocol]Triple attachment value with key = %s is invalid, which should be string or []string", k) + } + ctx = metadata.NewOutgoingContext(ctx, gRPCMD) + ctx = context.WithValue(ctx, tripleConstant.InterfaceKey, di.BaseInvoker.GetURL().GetParam(constant.InterfaceKey, "")) + in := make([]reflect.Value, 0, 16) + in = append(in, reflect.ValueOf(ctx)) + + if len(invocation.ParameterValues()) > 0 { + in = append(in, invocation.ParameterValues()...) 
+ } + + methodName := invocation.MethodName() + triAttachmentWithErr := di.client.Invoke(methodName, in, invocation.Reply()) + result.Err = triAttachmentWithErr.GetError() + result.Attrs = make(map[string]interface{}) + for k, v := range triAttachmentWithErr.GetAttachments() { + result.Attrs[k] = v + } + result.Rest = invocation.Reply() + return &result +} + +// get timeout including methodConfig +func (di *DubboInvoker) getTimeout(invocation *invocation_impl.RPCInvocation) time.Duration { + timeout := di.GetURL().GetParam(strings.Join([]string{constant.MethodKeys, invocation.MethodName(), constant.TimeoutKey}, "."), "") + if len(timeout) != 0 { + if t, err := time.ParseDuration(timeout); err == nil { + // config timeout into attachment + invocation.SetAttachment(constant.TimeoutKey, strconv.Itoa(int(t.Milliseconds()))) + return t + } + } + // set timeout into invocation at method level + invocation.SetAttachment(constant.TimeoutKey, strconv.Itoa(int(di.timeout.Milliseconds()))) + return di.timeout +} + +// IsAvailable check if invoker is available, now it is useless +func (di *DubboInvoker) IsAvailable() bool { + client := di.getClient() + if client != nil { + // FIXME here can't check if tcp server is started now!!! + return client.IsAvailable() + } + return false +} + +// Destroy destroy dubbo3 client invoker. 
+func (di *DubboInvoker) Destroy() { + di.quitOnce.Do(func() { + di.BaseInvoker.Destroy() + client := di.getClient() + if client != nil { + di.setClient(nil) + client.Close() + } + }) +} diff --git a/protocol/triple/internal/proto/triple_gen/greettriple/greet.triple.go b/protocol/triple/internal/proto/triple_gen/greettriple/greet.triple.go index 6672c8909..3a9c685d2 100644 --- a/protocol/triple/internal/proto/triple_gen/greettriple/greet.triple.go +++ b/protocol/triple/internal/proto/triple_gen/greettriple/greet.triple.go @@ -28,9 +28,7 @@ import ( import ( client "dubbo.apache.org/dubbo-go/v3/client" - "dubbo.apache.org/dubbo-go/v3/common" "dubbo.apache.org/dubbo-go/v3/common/constant" - "dubbo.apache.org/dubbo-go/v3/config" "dubbo.apache.org/dubbo-go/v3/protocol/triple/internal/proto" triple_protocol "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol" "dubbo.apache.org/dubbo-go/v3/server" @@ -38,7 +36,7 @@ import ( // This is a compile-time assertion to ensure that this generated file and the Triple package // are compatible. If you get a compiler error that this constant is not defined, this code was -// generated with a version ofTtriple newer than the one compiled into your binary. You can fix the +// generated with a version of Triple newer than the one compiled into your binary. You can fix the // problem by either regenerating this code with an older version of Triple or updating the Triple // version compiled into your binary. const _ = triple_protocol.IsAtLeastVersion0_1_0 @@ -49,7 +47,7 @@ const ( ) // These constants are the fully-qualified names of the RPCs defined in this package. They're -// exposed at runtime as Spec.Procedure and as the final two segments of the HTTP route. +// exposed at runtime as procedure and as the final two segments of the HTTP route. // // Note that these are different from the fully-qualified method names used by // google.golang.org/protobuf/reflect/protoreflect. 
To convert from these constants to @@ -89,9 +87,6 @@ type GreetService interface { } // NewGreetService constructs a client for the greet.GreetService service. -// -// The URL supplied here should be the base URL for the Triple server (for example, -// http://api.acme.com or https://acme.com/grpc). func NewGreetService(cli *client.Client) (GreetService, error) { if err := cli.Init(&GreetService_ClientInfo); err != nil { return nil, err @@ -101,10 +96,6 @@ func NewGreetService(cli *client.Client) (GreetService, error) { }, nil } -func SetConsumerService(srv common.RPCService) { - config.SetClientInfoService(&GreetService_ClientInfo, srv) -} - // GreetServiceImpl implements GreetService. type GreetServiceImpl struct { cli *client.Client diff --git a/protocol/triple/triple.go b/protocol/triple/triple.go index c68ebd156..37217164d 100644 --- a/protocol/triple/triple.go +++ b/protocol/triple/triple.go @@ -57,9 +57,8 @@ type TripleProtocol struct { func (tp *TripleProtocol) Export(invoker protocol.Invoker) protocol.Exporter { url := invoker.GetURL() serviceKey := url.ServiceKey() - // todo: retrieve this info from url var info *server.ServiceInfo - infoRaw, ok := url.Attributes[constant.ServiceInfoKey] + infoRaw, ok := url.GetAttribute(constant.ServiceInfoKey) if ok { info = infoRaw.(*server.ServiceInfo) } @@ -104,7 +103,17 @@ func (tp *TripleProtocol) openServer(invoker protocol.Invoker, info *server.Serv // Refer a remote triple service func (tp *TripleProtocol) Refer(url *common.URL) protocol.Invoker { - invoker, err := NewTripleInvoker(url) + var invoker protocol.Invoker + var err error + // for now, we do not need to use this info + _, ok := url.GetAttribute(constant.ClientInfoKey) + if ok { + // stub code generated by new protoc-gen-go-triple + invoker, err = NewTripleInvoker(url) + } else { + // stub code generated by old protoc-gen-go-triple + invoker, err = NewDubbo3Invoker(url) + } if err != nil { logger.Warnf("can't dial the server: %s", url.Key()) return nil 
diff --git a/protocol/triple/triple_protocol/client.go b/protocol/triple/triple_protocol/client.go index 5ec4ac467..77ebd5bd6 100644 --- a/protocol/triple/triple_protocol/client.go +++ b/protocol/triple/triple_protocol/client.go @@ -22,16 +22,16 @@ import ( "net/http" "net/url" "strings" + "time" ) // Client is a reusable, concurrency-safe client for a single procedure. // Depending on the procedure's type, use the CallUnary, CallClientStream, // CallServerStream, or CallBidiStream method. // -// todo:// modify comment -// By default, clients use the Connect protocol with the binary Protobuf Codec, -// ask for gzipped responses, and send uncompressed requests. To use the gRPC -// or gRPC-Web protocols, use the [WithGRPC] or [WithGRPCWeb] options. +// By default, clients use the gRPC protocol with the binary Protobuf Codec, +// ask for gzipped responses, and send uncompressed requests. To use the Triple, +// use the [WithTriple] options. type Client struct { config *clientConfig callUnary func(context.Context, *Request, *Response) error @@ -63,9 +63,7 @@ func NewClient(httpClient HTTPClient, url string, options ...ClientOption) *Clie BufferPool: config.BufferPool, ReadMaxBytes: config.ReadMaxBytes, SendMaxBytes: config.SendMaxBytes, - EnableGet: config.EnableGet, GetURLMaxBytes: config.GetURLMaxBytes, - GetUseFallback: config.GetUseFallback, }, ) if protocolErr != nil { @@ -82,6 +80,8 @@ func NewClient(httpClient HTTPClient, url string, options ...ClientOption) *Clie // We want the user to continue to call Receive in those cases to get the // full error from the server-side. 
if err := conn.Send(request.Any()); err != nil && !errors.Is(err, io.EOF) { + // for HTTP/1.1 case, CloseRequest must happen before CloseResponse + // since HTTP/1.1 is of request-response type _ = conn.CloseRequest() _ = conn.CloseResponse() return err @@ -109,10 +109,7 @@ func NewClient(httpClient HTTPClient, url string, options ...ClientOption) *Clie if err := unaryFunc(ctx, request, response); err != nil { return err } - //typed, ok := response.(*Response[Res]) - //if !ok { - // return nil, errorf(CodeInternal, "unexpected client response type %T", response) - //} + return nil } return client @@ -123,6 +120,10 @@ func (c *Client) CallUnary(ctx context.Context, request *Request, response *Resp if c.err != nil { return c.err } + ctx, flag, cancel := applyDefaultTimeout(ctx, c.config.Timeout) + if flag { + defer cancel() + } return c.callUnary(ctx, request, response) } @@ -190,10 +191,10 @@ type clientConfig struct { BufferPool *bufferPool ReadMaxBytes int SendMaxBytes int - EnableGet bool GetURLMaxBytes int GetUseFallback bool IdempotencyLevel IdempotencyLevel + Timeout time.Duration } func newClientConfig(rawURL string, options []ClientOption) (*clientConfig, *Error) { @@ -203,13 +204,16 @@ func newClientConfig(rawURL string, options []ClientOption) (*clientConfig, *Err } protoPath := extractProtoPath(url.Path) config := clientConfig{ - URL: url, + URL: url, + // use gRPC by default Protocol: &protocolGRPC{}, Procedure: protoPath, CompressionPools: make(map[string]*compressionPool), BufferPool: newBufferPool(), } + // use proto binary by default withProtoBinaryCodec().applyToClient(&config) + // use gzip by default withGzip().applyToClient(&config) for _, opt := range options { opt.applyToClient(&config) @@ -263,3 +267,15 @@ func parseRequestURL(rawURL string) (*url.URL, *Error) { } return nil, NewError(CodeUnavailable, err) } + +func applyDefaultTimeout(ctx context.Context, timeout time.Duration) (context.Context, bool, context.CancelFunc) { + var cancel 
context.CancelFunc + var applyFlag bool + _, ok := ctx.Deadline() + if !ok && timeout != 0 { + ctx, cancel = context.WithTimeout(ctx, timeout) + applyFlag = true + } + + return ctx, applyFlag, cancel +} diff --git a/protocol/triple/triple_protocol/client_stream.go b/protocol/triple/triple_protocol/client_stream.go index 0f01f2960..c2a936b6b 100644 --- a/protocol/triple/triple_protocol/client_stream.go +++ b/protocol/triple/triple_protocol/client_stream.go @@ -47,6 +47,8 @@ func (c *ClientStreamForClient) Peer() Peer { // Triple and gRPC protocols. Applications shouldn't write them. func (c *ClientStreamForClient) RequestHeader() http.Header { if c.err != nil { + // todo(DMwangnima): since there is error in ClientStreamForClient, maybe we should tell user other than + // returning a empty Header return http.Header{} } return c.conn.RequestHeader() @@ -56,12 +58,13 @@ func (c *ClientStreamForClient) RequestHeader() http.Header { // headers. // // If the server returns an error, Send returns an error that wraps [io.EOF]. -// Clients should check for case using the standard library's [errors.Is] and -// unmarshal the error using CloseAndReceive. +// Clients should check for case using the standard library's [errors.Is] or +// [IsEnded] and unmarshal the error using CloseAndReceive. func (c *ClientStreamForClient) Send(request interface{}) error { if c.err != nil { return c.err } + // todo(DMwangnima): remove this redundant statement if request == nil { return c.conn.Send(nil) } @@ -109,6 +112,7 @@ type ServerStreamForClient struct { // either by reaching the end or by encountering an unexpected error. After // Receive returns false, the Err method will return any unexpected error // encountered. 
+// todo(DMwangnima): add classic usage func (s *ServerStreamForClient) Receive(msg interface{}) bool { if s.constructErr != nil || s.receiveErr != nil { return false @@ -120,7 +124,7 @@ func (s *ServerStreamForClient) Receive(msg interface{}) bool { // Msg returns the most recent message unmarshaled by a call to Receive. func (s *ServerStreamForClient) Msg() interface{} { - // todo:// processing nil pointer + // todo(DMwangnima): processing nil pointer //if s.msg == nil { // s.msg = new(Res) //} @@ -142,6 +146,8 @@ func (s *ServerStreamForClient) Err() error { // the first call to Receive returns. func (s *ServerStreamForClient) ResponseHeader() http.Header { if s.constructErr != nil { + // todo(DMwangnima): since there is error in ServerStreamForClient, maybe we should tell user other than + // returning an empty Header return http.Header{} } return s.conn.ResponseHeader() @@ -151,6 +157,8 @@ func (s *ServerStreamForClient) ResponseHeader() http.Header { // aren't fully populated until Receive() returns an error wrapping io.EOF. func (s *ServerStreamForClient) ResponseTrailer() http.Header { if s.constructErr != nil { + // todo(DMwangnima): since there is error in ServerStreamForClient, maybe we should tell user other than + // returning an empty Header return http.Header{} } return s.conn.ResponseTrailer() @@ -197,13 +205,15 @@ func (b *BidiStreamForClient) Peer() Peer { // Triple and gRPC protocols. Applications shouldn't write them. func (b *BidiStreamForClient) RequestHeader() http.Header { if b.err != nil { + // todo(DMwangnima): since there is error in BidiStreamForClient, maybe we should tell user other than + // returning an empty Header return http.Header{} } return b.conn.RequestHeader() } // Send a message to the server. The first call to Send also sends the request -// headers. To send just the request headers, without a body, call Send with a +// headers. To send just the request headers without a body, call Send with a // nil pointer. 
// // If the server returns an error, Send returns an error that wraps [io.EOF]. @@ -213,6 +223,7 @@ func (b *BidiStreamForClient) Send(msg interface{}) error { if b.err != nil { return b.err } + // todo(DMwangnima): remove this redundant statement if msg == nil { return b.conn.Send(nil) } @@ -251,6 +262,8 @@ func (b *BidiStreamForClient) CloseResponse() error { // the first call to Receive returns. func (b *BidiStreamForClient) ResponseHeader() http.Header { if b.err != nil { + // todo(DMwangnima): since there is error in BidiStreamForClient, maybe we should tell user other than + // returning an empty Header return http.Header{} } return b.conn.ResponseHeader() @@ -260,6 +273,8 @@ func (b *BidiStreamForClient) ResponseHeader() http.Header { // aren't fully populated until Receive() returns an error wrapping [io.EOF]. func (b *BidiStreamForClient) ResponseTrailer() http.Header { if b.err != nil { + // todo(DMwangnima): since there is error in BidiStreamForClient, maybe we should tell user other than + // returning an empty Header return http.Header{} } return b.conn.ResponseTrailer() diff --git a/protocol/triple/triple_protocol/code.go b/protocol/triple/triple_protocol/code.go index 95d634549..c66074a67 100644 --- a/protocol/triple/triple_protocol/code.go +++ b/protocol/triple/triple_protocol/code.go @@ -20,22 +20,23 @@ import ( "strings" ) -// A Code is one of the Connect protocol's error codes. There are no user-defined +// A Code is one of the Triple protocol's error codes. There are no user-defined // codes, so only the codes enumerated below are valid. In both name and // semantics, these codes match the gRPC status codes. // // The descriptions below are optimized for brevity rather than completeness. -// See the [Connect protocol specification] for detailed descriptions of each +// See the [Triple protocol specification] for detailed descriptions of each // code and example usage. 
// -// [Connect protocol specification]: https://connect.build/docs/protocol +// todo(DMwangnima): add specification to dubbo-go official site +// [Triple protocol specification]: https://connect.build/docs/protocol type Code uint32 const ( // The zero code in gRPC is OK, which indicates that the operation was a // success. We don't define a constant for it because it overlaps awkwardly // with Go's error semantics: what does it mean to have a non-nil error with - // an OK status? (Also, the Connect protocol doesn't use a code for + // an OK status? (Also, the Triple protocol doesn't use a code for // successes.) // CodeCanceled indicates that the operation was canceled, typically by the @@ -219,8 +220,8 @@ func (c *Code) UnmarshalText(data []byte) error { // CodeOf returns the error's status code if it is or wraps an [*Error] and // [CodeUnknown] otherwise. func CodeOf(err error) Code { - if connectErr, ok := asError(err); ok { - return connectErr.Code() + if tripleErr, ok := asError(err); ok { + return tripleErr.Code() } return CodeUnknown } diff --git a/protocol/triple/triple_protocol/duplex_http_call.go b/protocol/triple/triple_protocol/duplex_http_call.go index 9f5e6ed56..3865a56ab 100644 --- a/protocol/triple/triple_protocol/duplex_http_call.go +++ b/protocol/triple/triple_protocol/duplex_http_call.go @@ -63,6 +63,7 @@ func newDuplexHTTPCall( url = cloneURL(url) pipeReader, pipeWriter := io.Pipe() + // todo(DMwangnima): remove cloneURL logic in WithContext // This is mirroring what http.NewRequestContext did, but // using an already parsed url.URL object, rather than a string // and parsing it again. This is a bit funny with HTTP/1.1 @@ -112,7 +113,7 @@ func (d *duplexHTTPCall) Write(data []byte) (int, error) { return bytesWritten, err } -// Close the request body. Callers *must* call CloseWrite before Read when +// CloseWrite closes the request body. Callers *must* call CloseWrite before Read when // using HTTP/1.x. 
func (d *duplexHTTPCall) CloseWrite() error { // Even if Write was never called, we need to make an HTTP request. This diff --git a/protocol/triple/triple_protocol/envelope.go b/protocol/triple/triple_protocol/envelope.go index 845624e2a..8f46b7a54 100644 --- a/protocol/triple/triple_protocol/envelope.go +++ b/protocol/triple/triple_protocol/envelope.go @@ -22,22 +22,22 @@ import ( ) // flagEnvelopeCompressed indicates that the data is compressed. It has the -// same meaning in the gRPC-Web, gRPC-HTTP2, and Connect protocols. +// same meaning in the gRPC-Web, gRPC-HTTP2, and Triple protocols. const flagEnvelopeCompressed = 0b00000001 var errSpecialEnvelope = errorf( CodeUnknown, "final message has protocol-specific flags: %w", - // User code checks for end of stream with errors.Is(err, io.EOF). + // User code checks for end of stream with errors.Is(err, io.EOF) or triple.IsEnded(err). io.EOF, ) -// envelope is a block of arbitrary bytes wrapped in gRPC and Connect's framing +// envelope is a block of arbitrary bytes wrapped in gRPC and Triple's framing // protocol. // // Each message is preceded by a 5-byte prefix. The first byte is a uint8 used // as a set of bitwise flags, and the remainder is a uint32 indicating the -// message length. gRPC and Connect interpret the bitwise flags differently, so +// message length. gRPC and Triple interpret the bitwise flags differently, so // envelope leaves their interpretation up to the caller. 
type envelope struct { Data *bytes.Buffer @@ -61,8 +61,8 @@ type envelopeWriter struct { func (w *envelopeWriter) Marshal(message interface{}) *Error { if message == nil { if _, err := w.writer.Write(nil); err != nil { - if connectErr, ok := asError(err); ok { - return connectErr + if tripleErr, ok := asError(err); ok { + return tripleErr } return NewError(CodeUnknown, err) } @@ -113,8 +113,8 @@ func (w *envelopeWriter) write(env *envelope) *Error { prefix[0] = env.Flags binary.BigEndian.PutUint32(prefix[1:5], uint32(env.Data.Len())) if _, err := w.writer.Write(prefix[:]); err != nil { - if connectErr, ok := asError(err); ok { - return connectErr + if tripleErr, ok := asError(err); ok { + return tripleErr } return errorf(CodeUnknown, "write envelope: %w", err) } @@ -213,8 +213,8 @@ func (r *envelopeReader) Read(env *envelope) *Error { return NewError(CodeUnknown, err) case err != nil || prefixBytesRead < 5: // Something else has gone wrong - the stream didn't end cleanly. - if connectErr, ok := asError(err); ok { - return connectErr + if tripleErr, ok := asError(err); ok { + return tripleErr } if maxBytesErr := asMaxBytesError(err, "read 5 byte message prefix"); maxBytesErr != nil { // We're reading from an http.MaxBytesHandler, and we've exceeded the read limit. diff --git a/protocol/triple/triple_protocol/error.go b/protocol/triple/triple_protocol/error.go index f8b439fd3..609b4422e 100644 --- a/protocol/triple/triple_protocol/error.go +++ b/protocol/triple/triple_protocol/error.go @@ -30,6 +30,7 @@ import ( ) const ( + // todo(DMwangnima): add common errors documentation commonErrorsURL = "https://connect.build/docs/go/common-errors" defaultAnyResolverPrefix = "type.googleapis.com/" ) @@ -107,6 +108,7 @@ func (d *ErrorDetail) Value() (proto.Message, error) { // They're a clearer and more performant alternative to HTTP header // microformats. See [the documentation on errors] for more details. 
// +// todo(DMwangnima): add error documentation to dubbo-go official website // [the documentation on errors]: https://connect.build/docs/go/errors type Error struct { code Code @@ -225,9 +227,9 @@ func errorf(c Code, template string, args ...interface{}) *Error { // asError uses errors.As to unwrap any error and look for a triple *Error. func asError(err error) (*Error, bool) { - var connectErr *Error - ok := errors.As(err, &connectErr) - return connectErr, ok + var tripleErr *Error + ok := errors.As(err, &tripleErr) + return tripleErr, ok } // wrapIfUncoded ensures that all errors are wrapped. It leaves already-wrapped @@ -264,7 +266,7 @@ func wrapIfContextError(err error) error { return err } -// wrapIfLikelyWithGRPCNotUsedError adds a wrapping error that has a message +// wrapIfLikelyH2CNotConfiguredError adds a wrapping error that has a message // telling the caller that they likely need to use h2c but are using a raw http.Client{}. // // This happens when running a gRPC-only server. @@ -276,7 +278,7 @@ func wrapIfLikelyH2CNotConfiguredError(request *http.Request, err error) error { if _, ok := asError(err); ok { return err } - if url := request.URL; url != nil && url.Scheme != "http" { + if reqUrl := request.URL; reqUrl != nil && reqUrl.Scheme != "http" { // If the scheme is not http, we definitely do not have an h2c error, so just return. return err } @@ -292,7 +294,7 @@ func wrapIfLikelyH2CNotConfiguredError(request *http.Request, err error) error { } // wrapIfLikelyWithGRPCNotUsedError adds a wrapping error that has a message -// telling the caller that they likely forgot to use triple.WithGRPC(). +// telling the caller that the server side does not use gRPC. // // This happens when running a gRPC-only server. // This is fragile and may break over time, and this should be considered a best-effort. 
diff --git a/protocol/triple/triple_protocol/error_writer.go b/protocol/triple/triple_protocol/error_writer.go index b9aedef14..f227c8823 100644 --- a/protocol/triple/triple_protocol/error_writer.go +++ b/protocol/triple/triple_protocol/error_writer.go @@ -39,7 +39,7 @@ type ErrorWriter struct { // NewErrorWriter constructs an ErrorWriter. To properly recognize supported // RPC Content-Types in net/http middleware, you must pass the same -// HandlerOptions to NewErrorWriter and any wrapped Connect handlers. +// HandlerOptions to NewErrorWriter and any wrapped Triple handlers. func NewErrorWriter(opts ...HandlerOption) *ErrorWriter { config := newHandlerConfig("", opts) writer := &ErrorWriter{ diff --git a/protocol/triple/triple_protocol/handler.go b/protocol/triple/triple_protocol/handler.go index d4baea773..9bca672b5 100644 --- a/protocol/triple/triple_protocol/handler.go +++ b/protocol/triple/triple_protocol/handler.go @@ -23,7 +23,7 @@ import ( // A Handler is the server-side implementation of a single RPC defined by a // service schema. // -// By default, Handlers support the Connect, gRPC, and gRPC-Web protocols with +// By default, Handlers support the Triple, gRPC, and gRPC-Web protocols with // the binary Protobuf and JSON codecs. They support gzip compression using the // standard library's [compress/gzip]. 
type Handler struct { @@ -259,18 +259,18 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re } type handlerConfig struct { - CompressionPools map[string]*compressionPool - CompressionNames []string - Codecs map[string]Codec - CompressMinBytes int - Interceptor Interceptor - Procedure string - HandleGRPC bool - RequireConnectProtocolHeader bool - IdempotencyLevel IdempotencyLevel - BufferPool *bufferPool - ReadMaxBytes int - SendMaxBytes int + CompressionPools map[string]*compressionPool + CompressionNames []string + Codecs map[string]Codec + CompressMinBytes int + Interceptor Interceptor + Procedure string + HandleGRPC bool + RequireTripleProtocolHeader bool + IdempotencyLevel IdempotencyLevel + BufferPool *bufferPool + ReadMaxBytes int + SendMaxBytes int } func newHandlerConfig(procedure string, options []HandlerOption) *handlerConfig { @@ -326,7 +326,7 @@ func (c *handlerConfig) newProtocolHandlers(streamType StreamType) []protocolHan BufferPool: c.BufferPool, ReadMaxBytes: c.ReadMaxBytes, SendMaxBytes: c.SendMaxBytes, - RequireTripleProtocolHeader: c.RequireConnectProtocolHeader, + RequireTripleProtocolHeader: c.RequireTripleProtocolHeader, IdempotencyLevel: c.IdempotencyLevel, })) } diff --git a/protocol/triple/triple_protocol/handler_compat.go b/protocol/triple/triple_protocol/handler_compat.go index ccd272d6a..5730abce6 100644 --- a/protocol/triple/triple_protocol/handler_compat.go +++ b/protocol/triple/triple_protocol/handler_compat.go @@ -31,14 +31,8 @@ import ( dubbo_protocol "dubbo.apache.org/dubbo-go/v3/protocol" ) -type TripleCtxKey string - type MethodHandler func(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) -const ( - TripleGoInterfaceNameKey TripleCtxKey = "XXX_TRIPLE_GO_INTERFACE_NAME" -) - type tripleCompatInterceptor struct { spec Spec peer Peer @@ -105,7 +99,9 @@ func NewCompatUnaryHandler( } return nil } - ctx = 
context.WithValue(ctx, TripleGoInterfaceNameKey, config.Procedure) + // staticcheck error: SA1029. Stub code generated by protoc-gen-go-triple makes use of "XXX_TRIPLE_GO_INTERFACE_NAME" directly + //nolint:staticcheck + ctx = context.WithValue(ctx, "XXX_TRIPLE_GO_INTERFACE_NAME", config.Procedure) respRaw, err := unary(srv, ctx, decodeFunc, compatInterceptor.compatUnaryServerInterceptor) if err != nil { return err diff --git a/protocol/triple/triple_protocol/option.go b/protocol/triple/triple_protocol/option.go index 341ddc7f6..dab116074 100644 --- a/protocol/triple/triple_protocol/option.go +++ b/protocol/triple/triple_protocol/option.go @@ -19,6 +19,7 @@ import ( "context" "io" "net/http" + "time" ) // A ClientOption configures a [Client]. @@ -95,6 +96,13 @@ func WithSendGzip() ClientOption { return WithSendCompression(compressionGzip) } +// WithTimeout configures the default timeout of client call including unary +// and stream. If you want to specify the timeout of a specific request, please +// use context.WithTimeout, then default timeout would be overridden. +func WithTimeout(timeout time.Duration) ClientOption { + return &timeoutOption{Timeout: timeout} +} + // A HandlerOption configures a [Handler]. // // In addition to any options grouped in the documentation below, remember that @@ -106,7 +114,7 @@ type HandlerOption interface { // WithCompression configures handlers to support a compression algorithm. // Clients may send messages compressed with that algorithm and/or request // compressed responses. The [Compressor] and [Decompressor] produced by the -// supplied constructors must use the same algorithm. Internally, Connect pools +// supplied constructors must use the same algorithm. Internally, Triple pools // compressors and decompressors. 
// // By default, handlers support gzip using the standard library's @@ -149,16 +157,16 @@ func WithRecover(handle func(context.Context, Spec, http.Header, interface{}) er return WithInterceptors(&recoverHandlerInterceptor{handle: handle}) } -// WithRequireConnectProtocolHeader configures the Handler to require requests -// using the Connect RPC protocol to include the Connect-Protocol-Version +// WithRequireTripleProtocolHeader configures the Handler to require requests +// using the Triple RPC protocol to include the Triple-Protocol-Version // header. This ensures that HTTP proxies and net/http middleware can easily -// identify valid Connect requests, even if they use a common Content-Type like +// identify valid Triple requests, even if they use a common Content-Type like // application/json. However, it makes ad-hoc requests with tools like cURL // more laborious. // // This option has no effect if the client uses the gRPC or gRPC-Web protocols. -func WithRequireConnectProtocolHeader() HandlerOption { - return &requireConnectProtocolHeaderOption{} +func WithRequireTripleProtocolHeader() HandlerOption { + return &requireTripleProtocolHeaderOption{} } // Option implements both [ClientOption] and [HandlerOption], so it can be @@ -207,7 +215,7 @@ func WithCompressMinBytes(min int) Option { // handlers default to allowing any request size. // // Handlers may also use [http.MaxBytesHandler] to limit the total size of the -// HTTP request stream (rather than the per-message size). Connect handles +// HTTP request stream (rather than the per-message size). Triple handles // [http.MaxBytesError] specially, so clients still receive errors with the // appropriate error code and informative messages. func WithReadMaxBytes(max int) Option { @@ -226,6 +234,7 @@ func WithSendMaxBytes(max int) Option { return &sendMaxBytesOption{Max: max} } +// todo(DMwangnima): consider how to expose this functionality to users // WithIdempotency declares the idempotency of the procedure. 
This can determine // whether a procedure call can safely be retried, and may affect which request // modalities are allowed for a given procedure call. @@ -240,22 +249,6 @@ func WithIdempotency(idempotencyLevel IdempotencyLevel) Option { return &idempotencyOption{idempotencyLevel: idempotencyLevel} } -// WithHTTPGet allows Connect-protocol clients to use HTTP GET requests for -// side-effect free unary RPC calls. Typically, the service schema indicates -// which procedures are idempotent (see [WithIdempotency] for an example -// protobuf schema). The gRPC and gRPC-Web protocols are POST-only, so this -// option has no effect when combined with [WithGRPC] or [WithGRPCWeb]. -// -// Using HTTP GET requests makes it easier to take advantage of CDNs, caching -// reverse proxies, and browsers' built-in caching. Note, however, that servers -// don't automatically set any cache headers; you can set cache headers using -// interceptors or by adding headers in individual procedure implementations. -// -// By default, all requests are made as HTTP POSTs. -func WithHTTPGet() ClientOption { - return &enableGet{} -} - // WithInterceptors configures a client or handler's interceptor stack. 
Repeated // WithInterceptors options are applied in order, so // @@ -419,10 +412,10 @@ func (o *handlerOptionsOption) applyToHandler(config *handlerConfig) { } } -type requireConnectProtocolHeaderOption struct{} +type requireTripleProtocolHeaderOption struct{} -func (o *requireConnectProtocolHeaderOption) applyToHandler(config *handlerConfig) { - config.RequireConnectProtocolHeader = true +func (o *requireTripleProtocolHeaderOption) applyToHandler(config *handlerConfig) { + config.RequireTripleProtocolHeader = true } type idempotencyOption struct { @@ -443,44 +436,6 @@ func (o *tripleOption) applyToClient(config *clientConfig) { config.Protocol = &protocolTriple{} } -type enableGet struct{} - -func (o *enableGet) applyToClient(config *clientConfig) { - config.EnableGet = true -} - -// withHTTPGetMaxURLSize sets the maximum allowable URL length for GET requests -// made using the Connect protocol. It has no effect on gRPC or gRPC-Web -// clients, since those protocols are POST-only. -// -// Limiting the URL size is useful as most user agents, proxies, and servers -// have limits on the allowable length of a URL. For example, Apache and Nginx -// limit the size of a request line to around 8 KiB, meaning that maximum -// length of a URL is a bit smaller than this. If you run into URL size -// limitations imposed by your network infrastructure and don't know the -// maximum allowable size, or if you'd prefer to be cautious from the start, a -// 4096 byte (4 KiB) limit works with most common proxies and CDNs. -// -// If fallback is set to true and the URL would be longer than the configured -// maximum value, the request will be sent as an HTTP POST instead. If fallback -// is set to false, the request will fail with [CodeResourceExhausted]. -// -// By default, Connect-protocol clients with GET requests enabled may send a -// URL of any size. 
-func withHTTPGetMaxURLSize(bytes int, fallback bool) ClientOption { - return &getURLMaxBytes{Max: bytes, Fallback: fallback} -} - -type getURLMaxBytes struct { - Max int - Fallback bool -} - -func (o *getURLMaxBytes) applyToClient(config *clientConfig) { - config.GetURLMaxBytes = o.Max - config.GetUseFallback = o.Fallback -} - type interceptorsOption struct { Interceptors []Interceptor } @@ -530,6 +485,14 @@ func (o *sendCompressionOption) applyToClient(config *clientConfig) { config.RequestCompressionName = o.Name } +type timeoutOption struct { + Timeout time.Duration +} + +func (o *timeoutOption) applyToClient(config *clientConfig) { + config.Timeout = o.Timeout +} + func withGzip() Option { return &compressionOption{ Name: compressionGzip, diff --git a/protocol/triple/triple_protocol/protocol.go b/protocol/triple/triple_protocol/protocol.go index 3cc4813f2..aaffd1143 100644 --- a/protocol/triple/triple_protocol/protocol.go +++ b/protocol/triple/triple_protocol/protocol.go @@ -26,7 +26,7 @@ import ( "strings" ) -// The names of the Connect, gRPC, and gRPC-Web protocols (as exposed by +// The names of the Triple, gRPC, and gRPC-Web protocols (as exposed by // [Peer.Protocol]). Additional protocols may be added in the future. const ( ProtocolTriple = "triple" @@ -281,8 +281,8 @@ func negotiateCompression( //nolint:nonamedreturns } else { // To comply with // https://github.com/grpc/grpc/blob/master/doc/compression.md and the - // Connect protocol, we should return CodeUnimplemented and specify - // acceptable compression(s) (in addition to setting the a + // Triple protocol, we should return CodeUnimplemented and specify + // acceptable compression(s) (in addition to setting the // protocol-specific accept-encoding header). 
return "", "", errorf( CodeUnimplemented, diff --git a/protocol/triple/triple_protocol/protocol_grpc.go b/protocol/triple/triple_protocol/protocol_grpc.go index 555d29562..80dff3864 100644 --- a/protocol/triple/triple_protocol/protocol_grpc.go +++ b/protocol/triple/triple_protocol/protocol_grpc.go @@ -500,7 +500,7 @@ func (hc *grpcHandlerConn) Close(err error) (retErr error) { mergeHeaders(hc.responseWriter.Header(), hc.responseHeader) } // gRPC always sends the error's code, message, details, and metadata as - // trailing metadata. The Connect protocol doesn't do this, so we don't want + // trailing metadata. The Triple protocol doesn't do this, so we don't want // to mutate the trailers map that the user sees. mergedTrailers := make( http.Header, @@ -827,8 +827,8 @@ func grpcErrorToTrailer(bufferPool *bufferPool, trailer http.Header, protobuf Co ) return } - if connectErr, ok := asError(err); ok { - mergeHeaders(trailer, connectErr.meta) + if tripleErr, ok := asError(err); ok { + mergeHeaders(trailer, tripleErr.meta) } setHeaderCanonical(trailer, grpcHeaderStatus, code) setHeaderCanonical(trailer, grpcHeaderMessage, grpcPercentEncode(bufferPool, status.Message)) @@ -840,10 +840,10 @@ func grpcStatusFromError(err error) *statusv1.Status { Code: int32(CodeUnknown), Message: err.Error(), } - if connectErr, ok := asError(err); ok { - status.Code = int32(connectErr.Code()) - status.Message = connectErr.Message() - status.Details = connectErr.detailsAsAny() + if tripleErr, ok := asError(err); ok { + status.Code = int32(tripleErr.Code()) + status.Message = tripleErr.Message() + status.Details = tripleErr.detailsAsAny() } return status } diff --git a/protocol/triple/triple_protocol/protocol_grpc_test.go b/protocol/triple/triple_protocol/protocol_grpc_test.go index 200fcb5e9..6a454831a 100644 --- a/protocol/triple/triple_protocol/protocol_grpc_test.go +++ b/protocol/triple/triple_protocol/protocol_grpc_test.go @@ -15,7 +15,10 @@ package triple_protocol import ( + 
"compress/gzip" + "encoding/binary" "errors" + "io" "math" "net/http" "net/http/httptest" @@ -34,6 +37,46 @@ import ( "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol/internal/assert" ) +func TestGRPCClient_WriteRequestHeader(t *testing.T) { + tests := []struct { + desc string + params protocolClientParams + input http.Header + expect http.Header + }{ + { + params: protocolClientParams{ + Codec: &protoJSONCodec{codecNameJSON}, + CompressionName: compressionGzip, + CompressionPools: newReadOnlyCompressionPools(map[string]*compressionPool{ + compressionGzip: newCompressionPool( + func() Decompressor { return &gzip.Reader{} }, + func() Compressor { return gzip.NewWriter(io.Discard) }, + ), + }, []string{compressionGzip}), + }, + input: map[string][]string{}, + // todo(DMwangnima): add const for these literals + expect: map[string][]string{ + headerUserAgent: {defaultGrpcUserAgent}, + headerContentType: {grpcContentTypePrefix + codecNameJSON}, + "Accept-Encoding": {compressionIdentity}, + grpcHeaderCompression: {compressionGzip}, + grpcHeaderAcceptCompression: {compressionGzip}, + "Te": {"trailers"}, + }, + }, + } + + for _, test := range tests { + cli := &grpcClient{ + protocolClientParams: test.params, + } + cli.WriteRequestHeader(StreamType(4), test.input) + assert.Equal(t, test.input, test.expect) + } +} + func TestGRPCHandlerSender(t *testing.T) { t.Parallel() newConn := func(web bool) *grpcHandlerConn { @@ -116,7 +159,7 @@ func TestGRPCParseTimeout(t *testing.T) { assert.Nil(t, err) assert.Equal(t, duration, 45*time.Second) - const long = "99999999S" + var long = "99999999S" duration, err = grpcParseTimeout(long) // 8 digits, shouldn't overflow assert.Nil(t, err) assert.Equal(t, duration, 99999999*time.Second) @@ -195,7 +238,8 @@ func TestGRPCWebTrailerMarshalling(t *testing.T) { trailer.Add("User-Provided", "bar") err := marshaler.MarshalWebTrailers(trailer) assert.Nil(t, err) - responseWriter.Body.Next(5) // skip flags and message length + 
assert.Equal(t, responseWriter.Body.Next(1)[0], byte(grpcFlagEnvelopeTrailer)) + assert.Equal(t, binary.BigEndian.Uint32(responseWriter.Body.Next(4)), uint32(55)) marshaled := responseWriter.Body.String() assert.Equal(t, marshaled, "grpc-message: Foo\r\ngrpc-status: 0\r\nuser-provided: bar\r\n") } diff --git a/protocol/triple/triple_protocol/protocol_triple.go b/protocol/triple/triple_protocol/protocol_triple.go index cde9cf0f1..84a8b6668 100644 --- a/protocol/triple/triple_protocol/protocol_triple.go +++ b/protocol/triple/triple_protocol/protocol_triple.go @@ -40,7 +40,7 @@ const ( tripleUnaryTrailerPrefix = "Trailer-" tripleHeaderTimeout = "Triple-Timeout-Ms" tripleHeaderProtocolVersion = "Triple-Protocol-Version" - tripleProtocolVersion = "1" + tripleProtocolVersion = "0.1.0" tripleUnaryContentTypePrefix = "application/" tripleUnaryContentTypeJSON = tripleUnaryContentTypePrefix + "json" @@ -235,9 +235,8 @@ func (c *tripleClient) WriteRequestHeader(streamType StreamType, header http.Hea header[headerContentType] = []string{ tripleContentTypeFromCodecName(streamType, c.Codec.Name()), } - acceptCompressionHeader := tripleUnaryHeaderAcceptCompression if acceptCompression := c.CompressionPools.CommaSeparatedNames(); acceptCompression != "" { - header[acceptCompressionHeader] = []string{acceptCompression} + header[tripleUnaryHeaderAcceptCompression] = []string{acceptCompression} } } @@ -256,7 +255,6 @@ func (c *tripleClient) NewConn( } } duplexCall := newDuplexHTTPCall(ctx, c.HTTPClient, c.URL, spec, header) - var conn StreamingClientConn unaryConn := &tripleUnaryClientConn{ spec: spec, peer: c.Peer(), @@ -284,9 +282,8 @@ func (c *tripleClient) NewConn( responseHeader: make(http.Header), responseTrailer: make(http.Header), } - conn = unaryConn duplexCall.SetValidateResponse(unaryConn.validateResponse) - return wrapClientConnWithCodedErrors(conn) + return wrapClientConnWithCodedErrors(unaryConn) } type tripleUnaryClientConn struct { @@ -557,6 +554,7 @@ func (u 
*tripleUnaryUnmarshaler) UnmarshalFunc(message interface{}, unmarshal fu reader = io.LimitReader(u.reader, int64(u.readMaxBytes)+1) } // ReadFrom ignores io.EOF, so any error here is real. + // use io.LimitReader to prevent ReadFrom from panic bytesRead, err := data.ReadFrom(reader) if err != nil { if tripleErr, ok := asError(err); ok { @@ -568,7 +566,7 @@ func (u *tripleUnaryUnmarshaler) UnmarshalFunc(message interface{}, unmarshal fu return errorf(CodeUnknown, "read message: %w", err) } if u.readMaxBytes > 0 && bytesRead > int64(u.readMaxBytes) { - // Attempt to read to end in order to allow tripleion re-use + // Attempt to read to end in order to allow connection re-use discardedBytes, err := io.Copy(io.Discard, u.reader) if err != nil { return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", u.readMaxBytes, err) diff --git a/protocol/triple/triple_protocol/triple.go b/protocol/triple/triple_protocol/triple.go index 3fe0a7fb8..251ad8fa8 100644 --- a/protocol/triple/triple_protocol/triple.go +++ b/protocol/triple/triple_protocol/triple.go @@ -13,26 +13,27 @@ // limitations under the License. // Package triple is a slim RPC framework built on Protocol Buffers and -// [net/http]. In addition to supporting its own protocol, Connect handlers and -// clients are wire-compatible with gRPC and gRPC-Web, including streaming. +// [net/http]. In addition to supporting its own protocol, Triple handlers and +// clients are wire-compatible with gRPC, including streaming. // // This documentation is intended to explain each type and function in // isolation. Walkthroughs, FAQs, and other narrative docs are available on the -// [Connect website], and there's a working [demonstration service] on Github. +// [dubbo-go website], and there's a working [demonstration service] on Github. 
// -// [Connect website]: https://connect.build -// [demonstration service]: https://github.com/bufbuild/connect-demo +// [dubbo-go website]: https://cn.dubbo.apache.org/zh-cn/overview/mannual/golang-sdk/ +// [demonstration service]: https://github.com/apache/dubbo-go-samples package triple_protocol import ( "errors" + "fmt" "io" "net/http" "net/url" ) // Version is the semantic version of the triple module. -const Version = "1.7.0-dev" +const Version = "0.1.0" // These constants are used in compile-time handshakes with triple's generated // code. @@ -85,16 +86,16 @@ type StreamingHandlerConn interface { } // StreamingClientConn is the client's view of a bidirectional message exchange. -// Interceptors for streaming RPCs may wrap StreamingClientConns. +// Interceptors for streaming RPCs may wrap StreamingClientConn. // -// StreamingClientConns write request headers to the network with the first +// StreamingClientConn write request headers to the network with the first // call to Send. Any subsequent mutations are effectively no-ops. When the // server is done sending data, the StreamingClientConn's Receive method // returns an error wrapping [io.EOF]. Clients should check for this using the -// standard library's [errors.Is]. If the server encounters an error during -// processing, subsequent calls to the StreamingClientConn's Send method will -// return an error wrapping [io.EOF]; clients may then call Receive to unmarshal -// the error. +// standard library's [errors.Is] or [IsEnded]. If the server encounters an error +// during processing, subsequent calls to the StreamingClientConn's Send method +// will return an error wrapping [io.EOF]; clients may then call Receive to +// unmarshal the error. 
// // Headers and trailers beginning with "Triple-" and "Grpc-" are reserved for // use by the gRPC and Triple protocols: applications may read them but @@ -318,13 +319,12 @@ type handlerConnCloser interface { // receiveUnaryResponse unmarshals a message from a StreamingClientConn, then // envelopes the message and attaches headers and trailers. It attempts to -// consume the response stream and isn't appropriate when receiving multiple +// consume the response stream and is not appropriate when receiving multiple // messages. func receiveUnaryResponse(conn StreamingClientConn, response AnyResponse) error { resp, ok := response.(*Response) if !ok { - // todo: add a more reasonable sentence - panic("wrong type") + panic(fmt.Sprintf("response %T is not of Response type", response)) } if err := conn.Receive(resp.Msg); err != nil { return err @@ -332,8 +332,6 @@ func receiveUnaryResponse(conn StreamingClientConn, response AnyResponse) error // In a well-formed stream, the response message may be followed by a block // of in-stream trailers or HTTP trailers. To ensure that we receive the // trailers, try to read another message from the stream. - // if err := conn.Receive(new(T)); err == nil { - // todo:// maybe using copy method if err := conn.Receive(resp.Msg); err == nil { return NewError(CodeUnknown, errors.New("unary stream has multiple messages")) } else if err != nil && !errors.Is(err, io.EOF) { diff --git a/protocol/triple/triple_protocol/triple_ext_test.go b/protocol/triple/triple_protocol/triple_ext_test.go index 1fd7188d7..d4b82543d 100644 --- a/protocol/triple/triple_protocol/triple_ext_test.go +++ b/protocol/triple/triple_protocol/triple_ext_test.go @@ -22,7 +22,6 @@ import ( "errors" "fmt" "io" - "math" "math/rand" "net/http" "net/http/httptest" @@ -52,12 +51,18 @@ const errorMessage = "oh no" // client doesn't set a header, and the server sets headers and trailers on the // response. 
const ( - headerValue = "some header value" - trailerValue = "some trailer value" - clientHeader = "Connect-Client-Header" - handlerHeader = "Connect-Handler-Header" - handlerTrailer = "Connect-Handler-Trailer" - clientMiddlewareErrorHeader = "Connect-Trigger-HTTP-Error" + headerValue = "some header value" + trailerValue = "some trailer value" + clientHeader = "Triple-Client-Header" + // use this header to tell server to mock timeout scenario + clientTimeoutHeader = "Triple-Client-Timeout-Header" + handlerHeader = "Triple-Handler-Header" + handlerTrailer = "Triple-Handler-Trailer" + clientMiddlewareErrorHeader = "Triple-Trigger-HTTP-Error" + + // since there is no math.MaxInt for go1.16, we need to define it for compatibility + intSize = 32 << (^uint(0) >> 63) // 32 or 64 + maxInt = 1<<(intSize-1) - 1 ) func TestServer(t *testing.T) { @@ -107,6 +112,8 @@ func TestServer(t *testing.T) { assert.Equal(t, response.Trailer().Values(handlerTrailer), []string{trailerValue}) }) t.Run("ping_error", func(t *testing.T) { + // please see pingServer.Ping(). 
+ // if we do not send clientHeader: headerValue to pingServer.Ping(), it would return error err := client.Ping( context.Background(), triple.NewRequest(&pingv1.PingRequest{}), @@ -114,11 +121,23 @@ func TestServer(t *testing.T) { ) assert.Equal(t, triple.CodeOf(err), triple.CodeInvalidArgument) }) - t.Run("ping_timeout", func(t *testing.T) { + t.Run("ping_invalid_timeout", func(t *testing.T) { + // invalid Deadline ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) defer cancel() request := triple.NewRequest(&pingv1.PingRequest{}) request.Header().Set(clientHeader, headerValue) + // since we would inspect ctx error before sending request, this invocation would return DeadlineExceeded directly + err := client.Ping(ctx, request, triple.NewResponse(&pingv1.PingResponse{})) + assert.Equal(t, triple.CodeOf(err), triple.CodeDeadlineExceeded) + }) + t.Run("ping_timeout", func(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) + defer cancel() + request := triple.NewRequest(&pingv1.PingRequest{}) + request.Header().Set(clientHeader, headerValue) + // tell server to mock timeout + request.Header().Set(clientTimeoutHeader, (2 * time.Second).String()) err := client.Ping(ctx, request, triple.NewResponse(&pingv1.PingResponse{})) assert.Equal(t, triple.CodeOf(err), triple.CodeDeadlineExceeded) }) @@ -165,6 +184,31 @@ func TestServer(t *testing.T) { assert.Equal(t, msg, &pingv1.SumResponse{}) // receive header only stream assert.Equal(t, got.Header().Values(handlerHeader), []string{headerValue}) }) + t.Run("sum_invalid_timeout", func(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) + defer cancel() + stream, err := client.Sum(ctx) + assert.Nil(t, err) + stream.RequestHeader().Set(clientHeader, headerValue) + msg := &pingv1.SumResponse{} + got := triple.NewResponse(msg) + err = stream.CloseAndReceive(got) + // todo(DMwangnima): for 
now, invalid timeout would be encoded as "Grpc-Timeout: 0n". + // it would not inspect err like unary call. We should refer to grpc-go. + assert.Equal(t, triple.CodeOf(err), triple.CodeDeadlineExceeded) + }) + t.Run("sum_timeout", func(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) + defer cancel() + stream, err := client.Sum(ctx) + assert.Nil(t, err) + stream.RequestHeader().Set(clientHeader, headerValue) + stream.RequestHeader().Set(clientTimeoutHeader, (2 * time.Second).String()) + msg := &pingv1.SumResponse{} + got := triple.NewResponse(msg) + err = stream.CloseAndReceive(got) + assert.Equal(t, triple.CodeOf(err), triple.CodeDeadlineExceeded) + }) } testCountUp := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper t.Run("count_up", func(t *testing.T) { @@ -195,19 +239,38 @@ func TestServer(t *testing.T) { for stream.Receive(&pingv1.CountUpResponse{}) { t.Fatalf("expected error, shouldn't receive any messages") } - assert.Equal( - t, - triple.CodeOf(stream.Err()), - triple.CodeInvalidArgument, - ) + assert.Equal(t, triple.CodeOf(stream.Err()), triple.CodeInvalidArgument) }) - t.Run("count_up_timeout", func(t *testing.T) { + t.Run("count_up_invalid_argument", func(t *testing.T) { + request := triple.NewRequest(&pingv1.CountUpRequest{Number: -1}) + request.Header().Set(clientHeader, headerValue) + stream, err := client.CountUp(context.Background(), request) + assert.Nil(t, err) + for stream.Receive(&pingv1.CountUpResponse{}) { + t.Fatalf("expected error, shouldn't receive any messages") + } + assert.Equal(t, triple.CodeOf(stream.Err()), triple.CodeInvalidArgument) + }) + t.Run("count_up_invalid_timeout", func(t *testing.T) { ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) defer cancel() _, err := client.CountUp(ctx, triple.NewRequest(&pingv1.CountUpRequest{Number: 1})) assert.NotNil(t, err) assert.Equal(t, triple.CodeOf(err), 
triple.CodeDeadlineExceeded) }) + t.Run("count_up_timeout", func(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) + defer cancel() + request := triple.NewRequest(&pingv1.CountUpRequest{Number: 1}) + request.Header().Set(clientHeader, headerValue) + request.Header().Set(clientTimeoutHeader, (2 * time.Second).String()) + stream, err := client.CountUp(ctx, request) + assert.Nil(t, err) + for stream.Receive(&pingv1.CountUpResponse{}) { + t.Fatalf("expected error, shouldn't receive any messages") + } + assert.Equal(t, triple.CodeOf(stream.Err()), triple.CodeDeadlineExceeded) + }) } testCumSum := func(t *testing.T, client pingv1connect.PingServiceClient, expectSuccess bool) { //nolint:thelper t.Run("cumsum", func(t *testing.T) { @@ -1138,7 +1201,7 @@ func TestClientWithReadMaxBytes(t *testing.T) { if enableCompression { compressionOption = triple.WithCompressMinBytes(1) } else { - compressionOption = triple.WithCompressMinBytes(math.MaxInt) + compressionOption = triple.WithCompressMinBytes(maxInt) } mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, compressionOption)) server := httptest.NewUnstartedServer(mux) @@ -1273,7 +1336,7 @@ func TestHandlerWithSendMaxBytes(t *testing.T) { if compressed { options = append(options, triple.WithCompressMinBytes(1)) } else { - options = append(options, triple.WithCompressMinBytes(math.MaxInt)) + options = append(options, triple.WithCompressMinBytes(maxInt)) } mux.Handle(pingv1connect.NewPingServiceHandler( pingServer{}, @@ -1614,14 +1677,14 @@ func TestStreamForServer(t *testing.T) { }) } -func TestConnectHTTPErrorCodes(t *testing.T) { +func TestTripleHTTPErrorCodes(t *testing.T) { t.Parallel() - checkHTTPStatus := func(t *testing.T, connectCode triple.Code, wantHttpStatus int) { + checkHTTPStatus := func(t *testing.T, tripleCode triple.Code, wantHttpStatus int) { t.Helper() mux := http.NewServeMux() pluggableServer := &pluggablePingServer{ ping: func(_ 
context.Context, _ *triple.Request) (*triple.Response, error) { - return nil, triple.NewError(connectCode, errors.New("error")) + return nil, triple.NewError(tripleCode, errors.New("error")) }, } mux.Handle(pingv1connect.NewPingServiceHandler(pluggableServer)) @@ -1758,6 +1821,7 @@ func TestUnflushableResponseWriter(t *testing.T) { assert.Equal(t, triple.CodeOf(err), triple.CodeInternal, assert.Sprintf("got %v", err)) assert.True( t, + // please see checkServerStreamsCanFlush() for detail strings.HasSuffix(err.Error(), "unflushableWriter does not implement http.Flusher"), assert.Sprintf("error doesn't reference http.Flusher: %s", err.Error()), ) @@ -1779,8 +1843,8 @@ func TestUnflushableResponseWriter(t *testing.T) { }{ {"grpc", nil}, } - for _, tt := range tests { - tt := tt + for _, test := range tests { + tt := test t.Run(tt.name, func(t *testing.T) { t.Parallel() pingclient := pingv1connect.NewPingServiceClient(server.Client(), server.URL, tt.options...) @@ -1834,14 +1898,14 @@ func TestGRPCErrorMetadataIsTrailersOnly(t *testing.T) { _, err = io.Copy(io.Discard, res.Body) assert.Nil(t, err) assert.Nil(t, res.Body.Close()) - assert.NotZero(t, res.Trailer.Get(handlerHeader)) - assert.NotZero(t, res.Trailer.Get(handlerTrailer)) + assert.Equal(t, res.Trailer.Get(handlerHeader), headerValue) + assert.Equal(t, res.Trailer.Get(handlerTrailer), trailerValue) } -func TestConnectProtocolHeaderSentByDefault(t *testing.T) { +func TestTripleProtocolHeaderSentByDefault(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, triple.WithRequireConnectProtocolHeader())) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, triple.WithRequireTripleProtocolHeader())) server := httptest.NewUnstartedServer(mux) server.EnableHTTP2 = true server.StartTLS() @@ -1860,23 +1924,25 @@ func TestConnectProtocolHeaderSentByDefault(t *testing.T) { assert.Nil(t, stream.CloseResponse()) } -func 
TestConnectProtocolHeaderRequired(t *testing.T) { +// todo(DMwangnima): we need to expose this functionality as a configuration to dubbo-go +func TestTripleProtocolHeaderRequired(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler( pingServer{}, - triple.WithRequireConnectProtocolHeader(), + triple.WithRequireTripleProtocolHeader(), )) server := httptest.NewServer(mux) t.Cleanup(server.Close) tests := []struct { + desc string headers http.Header }{ - {http.Header{}}, - {http.Header{"Connect-Protocol-Version": []string{"0"}}}, + {"empty header", http.Header{}}, + {"invalid version", http.Header{"Triple-Protocol-Version": []string{"0"}}}, } - for _, tcase := range tests { + for _, test := range tests { req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, @@ -1885,7 +1951,7 @@ func TestConnectProtocolHeaderRequired(t *testing.T) { ) assert.Nil(t, err) req.Header.Set("Content-Type", "application/json") - for k, v := range tcase.headers { + for k, v := range test.headers { req.Header[k] = v } response, err := server.Client().Do(req) @@ -1916,11 +1982,11 @@ func TestAllowCustomUserAgent(t *testing.T) { protocol string opts []triple.ClientOption }{ - {"triple", nil}, + {"triple", []triple.ClientOption{triple.WithTriple()}}, {"grpc", nil}, } - for _, testCase := range tests { - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, testCase.opts...) + for _, test := range tests { + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, test.opts...) 
req := triple.NewRequest(&pingv1.PingRequest{Number: 42}) req.Header().Set("User-Agent", customAgent) err := client.Ping(context.Background(), req, triple.NewResponse(&pingv1.PingResponse{})) @@ -2006,6 +2072,59 @@ func TestBlankImportCodeGeneration(t *testing.T) { assert.NotNil(t, desc) } +func TestDefaultTimeout(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) + server := httptest.NewUnstartedServer(mux) + server.EnableHTTP2 = true + server.StartTLS() + t.Cleanup(server.Close) + + defaultTimeout := 3 * time.Second + serverTimeout := 2 * time.Second + tests := []struct { + desc string + cliOpts []triple.ClientOption + }{ + { + desc: "Triple protocol", + cliOpts: []triple.ClientOption{ + triple.WithTriple(), + triple.WithTimeout(defaultTimeout), + }, + }, + { + desc: "gRPC protocol", + cliOpts: []triple.ClientOption{ + triple.WithTimeout(defaultTimeout), + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, test.cliOpts...) 
+ request := triple.NewRequest(&pingv1.PingRequest{}) + request.Header().Set(clientHeader, headerValue) + // tell server to mock timeout + request.Header().Set(clientTimeoutHeader, (serverTimeout).String()) + err := client.Ping(context.Background(), request, triple.NewResponse(&pingv1.PingResponse{})) + assert.Nil(t, err) + + // specify timeout to override default timeout + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + newRequest := triple.NewRequest(&pingv1.PingRequest{}) + newRequest.Header().Set(clientHeader, headerValue) + // tell server to mock timeout + newRequest.Header().Set(clientTimeoutHeader, (serverTimeout).String()) + newErr := client.Ping(ctx, request, triple.NewResponse(&pingv1.PingResponse{})) + assert.Equal(t, triple.CodeOf(newErr), triple.CodeDeadlineExceeded) + }) + } +} + type unflushableWriter struct { w http.ResponseWriter } @@ -2132,6 +2251,11 @@ func (p pingServer) Ping(ctx context.Context, request *triple.Request) (*triple. 
if err := expectClientHeader(p.checkMetadata, request); err != nil { return nil, err } + if timeoutStr := request.Header().Get(clientTimeoutHeader); timeoutStr != "" { + // got timeout instruction + timeout, _ := time.ParseDuration(timeoutStr) + time.Sleep(timeout) + } if request.Peer().Addr == "" { return nil, triple.NewError(triple.CodeInternal, errors.New("no peer address")) } @@ -2176,6 +2300,11 @@ func (p pingServer) Sum( return nil, err } } + if timeoutStr := stream.RequestHeader().Get(clientTimeoutHeader); timeoutStr != "" { + // got timeout instruction + timeout, _ := time.ParseDuration(timeoutStr) + time.Sleep(timeout) + } if stream.Peer().Addr == "" { return nil, triple.NewError(triple.CodeInternal, errors.New("no peer address")) } @@ -2205,6 +2334,11 @@ func (p pingServer) CountUp( if err := expectClientHeader(p.checkMetadata, request); err != nil { return err } + if timeoutStr := request.Header().Get(clientTimeoutHeader); timeoutStr != "" { + // got timeout instruction + timeout, _ := time.ParseDuration(timeoutStr) + time.Sleep(timeout) + } if request.Peer().Addr == "" { return triple.NewError(triple.CodeInternal, errors.New("no peer address")) } diff --git a/protocol/triple/triple_test.go b/protocol/triple/triple_test.go index f96caf34d..8c2b4281b 100644 --- a/protocol/triple/triple_test.go +++ b/protocol/triple/triple_test.go @@ -23,6 +23,7 @@ import ( "fmt" "io" "os" + "reflect" "strings" "testing" "time" @@ -35,6 +36,7 @@ import ( ) import ( + "dubbo.apache.org/dubbo-go/v3/client" "dubbo.apache.org/dubbo-go/v3/common" "dubbo.apache.org/dubbo-go/v3/common/constant" "dubbo.apache.org/dubbo-go/v3/common/extension" @@ -52,9 +54,10 @@ import ( ) const ( - triplePort = "20000" - dubbo3Port = "20001" + triplePort = "21000" + dubbo3Port = "21001" listenAddr = "0.0.0.0" + localAddr = "127.0.0.1" name = "triple" ) @@ -153,47 +156,81 @@ func TestMain(m *testing.M) { } func TestInvoke(t *testing.T) { - invokeFunc := func(t *testing.T, port string, interfaceName 
string, methods []string) { - url := common.NewURLWithOptions( + tripleInvokerInit := func(location string, port string, interfaceName string, methods []string, info *client.ClientInfo) (protocol.Invoker, error) { + newURL := common.NewURLWithOptions( common.WithInterface(interfaceName), - common.WithLocation("127.0.0.1"), + common.WithLocation(location), common.WithPort(port), common.WithMethods(methods), - common.WithProtocol(TRIPLE), + common.WithAttribute(constant.ClientInfoKey, info), ) - invoker, err := NewTripleInvoker(url) - if err != nil { - t.Fatal(err) + return NewTripleInvoker(newURL) + } + dubbo3InvokerInit := func(location string, port string, interfaceName string, svc common.RPCService) (protocol.Invoker, error) { + newURL := common.NewURLWithOptions( + common.WithInterface(interfaceName), + common.WithLocation(location), + common.WithPort(port), + ) + // dubbo3 needs to retrieve ConsumerService directly + config.SetConsumerServiceByInterfaceName(interfaceName, svc) + return NewDubbo3Invoker(newURL) + } + tripleInvocationInit := func(methodName string, rawParams []interface{}, callType string) protocol.Invocation { + newInv := invocation_impl.NewRPCInvocationWithOptions( + invocation_impl.WithMethodName(methodName), + invocation_impl.WithParameterRawValues(rawParams), + ) + newInv.SetAttribute(constant.CallTypeKey, callType) + return newInv + } + dubbo3InvocationInit := func(methodName string, params []reflect.Value, reply interface{}) protocol.Invocation { + newInv := invocation_impl.NewRPCInvocationWithOptions( + invocation_impl.WithMethodName(methodName), + invocation_impl.WithParameterValues(params), + ) + newInv.SetReply(reply) + return newInv + } + dubbo3ReplyInit := func(fieldType reflect.Type) interface{} { + var reply reflect.Value + replyType := fieldType.Out(0) + if replyType.Kind() == reflect.Ptr { + reply = reflect.New(replyType.Elem()) + } else { + reply = reflect.New(replyType) } + return reply.Interface() + } + + invokeTripleCodeFunc 
:= func(t *testing.T, invoker protocol.Invoker) { tests := []struct { - desc string methodName string - params []interface{} - invoke func(t *testing.T, params []interface{}, res protocol.Result) callType string + rawParams []interface{} + validate func(t *testing.T, rawParams []interface{}, res protocol.Result) }{ { - desc: "Unary", methodName: "Greet", - params: []interface{}{ + callType: constant.CallUnary, + rawParams: []interface{}{ &greet.GreetRequest{ Name: name, }, &greet.GreetResponse{}, }, - invoke: func(t *testing.T, params []interface{}, res protocol.Result) { + validate: func(t *testing.T, params []interface{}, res protocol.Result) { assert.Nil(t, res.Result()) assert.Nil(t, res.Error()) req := params[0].(*greet.GreetRequest) resp := params[1].(*greet.GreetResponse) assert.Equal(t, req.Name, resp.Greeting) }, - callType: constant.CallUnary, }, { - desc: "ClientStream", methodName: "GreetClientStream", - invoke: func(t *testing.T, params []interface{}, res protocol.Result) { + callType: constant.CallClientStream, + validate: func(t *testing.T, params []interface{}, res protocol.Result) { assert.Nil(t, res.Error()) streamRaw, ok := res.Result().(*triple_protocol.ClientStreamForClient) assert.True(t, ok) @@ -211,17 +248,16 @@ func TestInvoke(t *testing.T) { assert.Nil(t, err) assert.Equal(t, expectStr, resp.Greeting) }, - callType: constant.CallClientStream, }, { - desc: "ServerStream", methodName: "GreetServerStream", - params: []interface{}{ + callType: constant.CallServerStream, + rawParams: []interface{}{ &greet.GreetServerStreamRequest{ Name: "dubbo", }, }, - invoke: func(t *testing.T, params []interface{}, res protocol.Result) { + validate: func(t *testing.T, params []interface{}, res protocol.Result) { assert.Nil(t, res.Error()) req := params[0].(*greet.GreetServerStreamRequest) streamRaw, ok := res.Result().(*triple_protocol.ServerStreamForClient) @@ -236,12 +272,11 @@ func TestInvoke(t *testing.T) { assert.True(t, true, errors.Is(stream.Err(), 
io.EOF)) } }, - callType: constant.CallServerStream, }, { - desc: "BidiStream", methodName: "GreetStream", - invoke: func(t *testing.T, params []interface{}, res protocol.Result) { + callType: constant.CallBidiStream, + validate: func(t *testing.T, params []interface{}, res protocol.Result) { assert.Nil(t, res.Error()) streamRaw, ok := res.Result().(*triple_protocol.BidiStreamForClient) assert.True(t, ok) @@ -256,37 +291,130 @@ func TestInvoke(t *testing.T) { assert.Nil(t, stream.CloseRequest()) assert.Nil(t, stream.CloseResponse()) }, - callType: constant.CallBidiStream, }, } for _, test := range tests { - t.Run(test.desc, func(t *testing.T) { - inv := invocation_impl.NewRPCInvocationWithOptions( - invocation_impl.WithMethodName(test.methodName), - // todo: process opts - invocation_impl.WithParameterRawValues(test.params), - ) - inv.SetAttribute(constant.CallTypeKey, test.callType) + t.Run(test.methodName, func(t *testing.T) { + inv := tripleInvocationInit(test.methodName, test.rawParams, test.callType) res := invoker.Invoke(context.Background(), inv) - test.invoke(t, test.params, res) + test.validate(t, test.rawParams, res) }) } } - t.Run("invoke server code generated by triple", func(t *testing.T) { - invokeFunc(t, triplePort, greettriple.GreetService_ClientInfo.InterfaceName, greettriple.GreetService_ClientInfo.MethodNames) - }) - t.Run("invoke server code generated by dubbo3", func(t *testing.T) { - desc := dubbo3_greet.GreetService_ServiceDesc - var methods []string - for _, method := range desc.Methods { - methods = append(methods, method.MethodName) + invokeDubbo3CodeFunc := func(t *testing.T, invoker protocol.Invoker, svc common.RPCService) { + tests := []struct { + methodName string + params []reflect.Value + validate func(t *testing.T, params []reflect.Value, res protocol.Result) + }{ + { + methodName: "Greet", + params: []reflect.Value{ + reflect.ValueOf(&greet.GreetRequest{ + Name: name, + }), + }, + validate: func(t *testing.T, Params 
[]reflect.Value, res protocol.Result) { + assert.Nil(t, res.Error()) + req := Params[0].Interface().(*greet.GreetRequest) + resp := res.Result().(*greet.GreetResponse) + assert.Equal(t, req.Name, resp.Greeting) + }, + }, + { + methodName: "GreetClientStream", + validate: func(t *testing.T, reflectParams []reflect.Value, res protocol.Result) { + assert.Nil(t, res.Error()) + stream, ok := res.Result().(*dubbo3_greet.GreetService_GreetClientStreamClient) + assert.True(t, ok) + + var expectRes []string + times := 5 + for i := 1; i <= times; i++ { + expectRes = append(expectRes, name) + err := (*stream).Send(&greet.GreetClientStreamRequest{Name: name}) + assert.Nil(t, err) + } + expectStr := strings.Join(expectRes, ",") + resp, err := (*stream).CloseAndRecv() + assert.Nil(t, err) + assert.Equal(t, expectStr, resp.Greeting) + }, + }, + { + methodName: "GreetServerStream", + params: []reflect.Value{ + reflect.ValueOf(&greet.GreetServerStreamRequest{ + Name: "dubbo", + }), + }, + validate: func(t *testing.T, params []reflect.Value, res protocol.Result) { + assert.Nil(t, res.Error()) + req := params[0].Interface().(*greet.GreetServerStreamRequest) + stream, ok := res.Result().(*dubbo3_greet.GreetService_GreetServerStreamClient) + assert.True(t, ok) + times := 5 + for i := 1; i <= times; i++ { + msg, err := (*stream).Recv() + assert.Nil(t, err) + assert.Equal(t, req.Name, msg.Greeting) + } + }, + }, + { + methodName: "GreetStream", + validate: func(t *testing.T, params []reflect.Value, res protocol.Result) { + assert.Nil(t, res.Error()) + stream, ok := res.Result().(*dubbo3_greet.GreetService_GreetStreamClient) + assert.True(t, ok) + for i := 1; i <= 5; i++ { + err := (*stream).Send(&greet.GreetStreamRequest{Name: name}) + assert.Nil(t, err) + resp, err := (*stream).Recv() + assert.Nil(t, err) + assert.Equal(t, name, resp.Greeting) + } + assert.Nil(t, (*stream).CloseSend()) + }, + }, } - for _, stream := range desc.Streams { - methods = append(methods, stream.StreamName) + + 
svcPtrVal := reflect.ValueOf(svc) + svcVal := svcPtrVal.Elem() + svcType := svcVal.Type() + for _, test := range tests { + t.Run(test.methodName, func(t *testing.T) { + funcField, ok := svcType.FieldByName(test.methodName) + assert.True(t, ok) + reply := dubbo3ReplyInit(funcField.Type) + inv := dubbo3InvocationInit(test.methodName, test.params, reply) + res := invoker.Invoke(context.Background(), inv) + test.validate(t, test.params, res) + }) } + } - invokeFunc(t, dubbo3Port, desc.ServiceName, methods) + t.Run("triple2triple", func(t *testing.T) { + invoker, err := tripleInvokerInit(localAddr, triplePort, greettriple.GreetService_ClientInfo.InterfaceName, greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo) + assert.Nil(t, err) + invokeTripleCodeFunc(t, invoker) + }) + t.Run("triple2dubbo3", func(t *testing.T) { + invoker, err := tripleInvokerInit(localAddr, dubbo3Port, greettriple.GreetService_ClientInfo.InterfaceName, greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo) + assert.Nil(t, err) + invokeTripleCodeFunc(t, invoker) + }) + t.Run("dubbo32triple", func(t *testing.T) { + svc := new(dubbo3_greet.GreetServiceClientImpl) + invoker, err := dubbo3InvokerInit(localAddr, triplePort, dubbo3_greet.GreetService_ServiceDesc.ServiceName, svc) + assert.Nil(t, err) + invokeDubbo3CodeFunc(t, invoker, svc) + }) + t.Run("dubbo32dubbo3", func(t *testing.T) { + svc := new(dubbo3_greet.GreetServiceClientImpl) + invoker, err := dubbo3InvokerInit(localAddr, dubbo3Port, dubbo3_greet.GreetService_ServiceDesc.ServiceName, svc) + assert.Nil(t, err) + invokeDubbo3CodeFunc(t, invoker, svc) }) - }