Similarityoung commented on code in PR #688: URL: https://github.com/apache/dubbo-go-pixiu/pull/688#discussion_r2174822752
########## pkg/common/codec/grpc/passthrough/codec.go: ########## @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package passthrough + +import ( + "fmt" +) + +import ( + "google.golang.org/grpc/encoding" + "google.golang.org/protobuf/proto" +) + +// Codec is a gRPC codec that passes through bytes as is. +// This is used for transparent proxying where the message types are unknown at compile time. +type Codec struct{} + +func init() { + encoding.RegisterCodec(Codec{}) +} + +// Marshal checks if the value is already bytes or a proto.Message and marshals accordingly. +func (c Codec) Marshal(v any) ([]byte, error) { + if p, ok := v.(proto.Message); ok { + return proto.Marshal(p) + } + if b, ok := v.([]byte); ok { + return b, nil + } + return nil, fmt.Errorf("passthrough codec: cannot marshal type %T, want proto.Message or []byte", v) +} + +// Unmarshal stores the raw data into the target, which must be a *[]byte or proto.Message. 
+func (c Codec) Unmarshal(data []byte, v any) error { + if vb, ok := v.(*[]byte); ok { + *vb = data + return nil + } + if p, ok := v.(proto.Message); ok { + return proto.Unmarshal(data, p) + } + return fmt.Errorf("passthrough codec: cannot unmarshal into type %T, want *[]byte or proto.Message", v) +} + +// Name returns the name of the codec. +func (c Codec) Name() string { + return "pass_through" +} Review Comment: to implement the codec interface ########## pkg/filter/network/grpcproxy/filter/proxy/grpc_proxy_filter.go: ########## @@ -0,0 +1,530 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package proxy + +import ( + "context" + "crypto/tls" + "io" + "sync" + "time" +) + +import ( + "github.com/pkg/errors" + "google.golang.org/grpc" + "google.golang.org/grpc/backoff" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/metadata" +) + +import ( + ptcodec "github.com/apache/dubbo-go-pixiu/pkg/common/codec/grpc/passthrough" + "github.com/apache/dubbo-go-pixiu/pkg/common/constant" + "github.com/apache/dubbo-go-pixiu/pkg/common/extension/filter" + grpcCtx "github.com/apache/dubbo-go-pixiu/pkg/context/grpc" + "github.com/apache/dubbo-go-pixiu/pkg/logger" + "github.com/apache/dubbo-go-pixiu/pkg/server" +) + +// Constants for gRPC proxy filter +const ( + Kind = constant.GRPCProxyFilter + defaultKeepAliveTime = 300 * time.Second + defaultKeepAliveTimeout = 5 * time.Second + defaultConnectTimeout = 5 * time.Second + defaultMaxRetryCount = 3 + defaultMaxMsgSize = 4 * 1024 * 1024 // 4MB +) + +func init() { + filter.RegisterGrpcFilterPlugin(&Plugin{}) +} + +type ( + // Plugin gRPC proxy plugin implementation + Plugin struct{} + + // Config defines the configuration options for the gRPC proxy filter + Config struct { + EnableTLS bool `yaml:"enable_tls" json:"enable_tls" mapstructure:"enable_tls"` + TLSCertFile string `yaml:"tls_cert_file" json:"tls_cert_file" mapstructure:"tls_cert_file"` + TLSKeyFile string `yaml:"tls_key_file" json:"tls_key_file" mapstructure:"tls_key_file"` + MaxConcurrentStreams uint32 `yaml:"max_concurrent_streams" json:"max_concurrent_streams" mapstructure:"max_concurrent_streams"` + KeepAliveTimeStr string `yaml:"keepalive_time" json:"keepalive_time" mapstructure:"keepalive_time"` + KeepAliveTimeoutStr string `yaml:"keepalive_timeout" json:"keepalive_timeout" mapstructure:"keepalive_timeout"` + ConnectTimeoutStr string `yaml:"connect_timeout" json:"connect_timeout" mapstructure:"connect_timeout"` + 
KeepAliveTime time.Duration `yaml:"-" json:"-"` + KeepAliveTimeout time.Duration `yaml:"-" json:"-"` + ConnectTimeout time.Duration `yaml:"-" json:"-"` + } + + // Filter implements the gRPC proxy filter + Filter struct { + Config *Config + clientConnPool sync.Map // address -> *grpc.ClientConn + mu sync.RWMutex // protects concurrent operations + } +) + +// Kind return plugin kind +func (p Plugin) Kind() string { + return Kind +} + +// CreateFilter create gRPC proxy filter +func (p Plugin) CreateFilter(config any) (filter.GrpcFilter, error) { + cfg, ok := config.(*Config) + if !ok { + return nil, errors.New("gRPC proxy filter config type error") + } + + // Parse time durations from strings, with defaults + cfg.KeepAliveTime = parseDurationWithDefault(cfg.KeepAliveTimeStr, defaultKeepAliveTime) + cfg.KeepAliveTimeout = parseDurationWithDefault(cfg.KeepAliveTimeoutStr, defaultKeepAliveTimeout) + cfg.ConnectTimeout = parseDurationWithDefault(cfg.ConnectTimeoutStr, defaultConnectTimeout) + + return &Filter{Config: cfg}, nil +} + +// Config Expose the config so that Filter Manger can inject it, so it must be a pointer +func (p Plugin) Config() any { + return &Config{} +} + +// Handle processes gRPC invocation by routing to the appropriate backend +func (f *Filter) Handle(ctx *grpcCtx.GrpcContext) filter.FilterStatus { + // Validate context + if ctx == nil { + logger.Error("gRPC proxy received nil context") + return filter.Stop + } + + // Get route information + if ctx.Route == nil { + ctx.SetError(errors.New("gRPC proxy missing route information")) + return filter.Stop + } + + clusterName := ctx.Route.Cluster + if clusterName == "" { + ctx.SetError(errors.New("gRPC proxy missing cluster name")) + return filter.Stop + } + + // Get cluster manager + clusterManager := server.GetClusterManager() + if clusterManager == nil { + ctx.SetError(errors.New("gRPC proxy cluster manager not initialized")) + return filter.Stop + } + + // Select endpoint from cluster + endpoint := 
clusterManager.PickEndpoint(clusterName, ctx) + if endpoint == nil { + ctx.SetError(errors.Errorf("gRPC proxy can't find endpoint in cluster: %s", clusterName)) + return filter.Stop + } + + // Get target address + address := endpoint.Address.GetAddress() + if address == "" { + ctx.SetError(errors.New("gRPC proxy got empty endpoint address")) + return filter.Stop + } + + logger.Debugf("Forwarding gRPC request %s to cluster %s, endpoint %s", + ctx.ServiceName+"/"+ctx.MethodName, ctx.Route.Cluster, address) + + return f.handleStream(ctx, address) +} + +// handleStream handles all types of gRPC calls by creating a full-duplex stream pipe. +func (f *Filter) handleStream(ctx *grpcCtx.GrpcContext, address string) filter.FilterStatus { + // Get or create connection + conn, err := f.getOrCreateConnection(address) + if err != nil { + ctx.SetError(errors.Errorf("gRPC proxy failed to get connection: %v", err)) + return filter.Stop + } + + // Set metadata for the outgoing context + md := make(metadata.MD) + for k, v := range ctx.Attachments { + if str, ok := v.(string); ok { + md.Set(k, str) + } + } + outCtx := metadata.NewOutgoingContext(ctx.Context, md) + + // Create the full method path for the gRPC call + fullMethod := ctx.ServiceName + "/" + ctx.MethodName + // logger.Debugf("[dubbo-go-pixiu] gRPC proxy bidirectional stream to %s", fullMethod) + + // Create a new client stream to the backend + clientStream, err := conn.NewStream(outCtx, &grpc.StreamDesc{ + StreamName: ctx.MethodName, + ServerStreams: true, + ClientStreams: true, + }, fullMethod, grpc.ForceCodec(ptcodec.Codec{})) + + if err != nil { + ctx.SetError(errors.Errorf("failed to create client stream: %v", err)) + return filter.Stop + } + + // Ensure there is a server stream to work with + if ctx.Stream == nil { + ctx.SetError(errors.New("no stream available in context")) + return filter.Stop + } + + // Use a WaitGroup to coordinate the two forwarding goroutines + var wg sync.WaitGroup + wg.Add(2) + + // Channels 
for error propagation and termination signaling + errChan := make(chan error, 2) + doneChan := make(chan struct{}) + + // Start forwarding data in both directions + go f.forwardClientToServer(ctx, clientStream, &wg, errChan, doneChan) + go f.forwardServerToClient(ctx, clientStream, &wg, errChan, doneChan) + + // Goroutine to wait for context cancellation or the first error + go func() { + select { + case <-ctx.Context.Done(): + // If the client context is canceled, signal the forwarding goroutines to stop + close(doneChan) + case err := <-errChan: + // If an error occurs, propagate it and signal termination + ctx.SetError(err) + close(doneChan) + } + }() + + // Wait for both forwarding goroutines to complete + wg.Wait() + close(errChan) // Close channel to allow the final error check to complete + + // Final check for any errors that might have occurred + for err := range errChan { + if err != nil && ctx.Error == nil { + // Set error if one hasn't been set already + ctx.SetError(err) + } + } + + if ctx.Error != nil { + logger.Debugf("gRPC stream for %s completed with error: %v", fullMethod, ctx.Error) + return filter.Stop + } + + // The listener already logs the successful completion with duration. + // logger.Debugf("gRPC stream for %s completed successfully", fullMethod) + return filter.Continue +} + +// forwardClientToServer forwards messages from the incoming client stream to the backend server stream. 
+func (f *Filter) forwardClientToServer(ctx *grpcCtx.GrpcContext, clientStream grpc.ClientStream, wg *sync.WaitGroup, errChan chan<- error, doneChan <-chan struct{}) { + defer wg.Done() + + // Send initial arguments if available (for unary and server-stream calls) + if len(ctx.Arguments) > 0 { + for _, arg := range ctx.Arguments { + if err := clientStream.SendMsg(arg); err != nil { + errChan <- errors.Wrap(err, "failed to send initial message") + return + } + } + } + + // Continuously forward messages from the client stream + for { + select { + case <-doneChan: + // Stop forwarding if the done signal is received + return + default: + var msg []byte + if err := ctx.Stream.RecvMsg(&msg); err != nil { + if err == io.EOF { + // Client has finished sending, so close the send direction of the backend stream + if err := clientStream.CloseSend(); err != nil { + logger.Errorf("Error closing send stream to backend: %v", err) + } + return + } + errChan <- errors.Wrap(err, "error receiving from client") + return + } + + if err := clientStream.SendMsg(msg); err != nil { + errChan <- errors.Wrap(err, "error forwarding to backend") + return + } + } + } +} + +// forwardServerToClient forwards messages from the backend server stream to the incoming client stream. 
+func (f *Filter) forwardServerToClient(ctx *grpcCtx.GrpcContext, clientStream grpc.ClientStream, wg *sync.WaitGroup, errChan chan<- error, doneChan <-chan struct{}) { + defer wg.Done() + + // Forward header metadata from backend to client + if header, err := clientStream.Header(); err == nil { + if s, ok := ctx.Stream.(grpc.ServerStream); ok { + s.SetHeader(header) + } + } + + for { + select { + case <-doneChan: + // Stop forwarding if the done signal is received + return + default: + var resp []byte + err := clientStream.RecvMsg(&resp) + if err != nil { + // Upon any error from the backend, including EOF, forward the trailer metadata + if s, ok := ctx.Stream.(grpc.ServerStream); ok { + s.SetTrailer(clientStream.Trailer()) + } + if err != io.EOF { + // Propagate the actual gRPC status error, but not EOF + errChan <- err + } + return + } + + if err := ctx.Stream.SendMsg(resp); err != nil { + errChan <- errors.Wrap(err, "failed to forward response to client") + return + } + } + } +} + +// parseDurationWithDefault parses a string duration and returns a default if empty or invalid. +func parseDurationWithDefault(durationStr string, defaultVal time.Duration) time.Duration { + if durationStr == "" { + return defaultVal + } + d, err := time.ParseDuration(durationStr) + if err != nil { + logger.Warnf("Invalid duration format: '%s', using default %s", durationStr, defaultVal) + return defaultVal + } + return d +} + +// getOrCreateConnection retrieves an existing connection or creates a new one for a given address. +func (f *Filter) getOrCreateConnection(address string) (*grpc.ClientConn, error) { + if address == "" { + return nil, errors.New("cannot create connection to empty address") + } + + // Optimistic check without a lock. `sync.Map` is safe for concurrent reads. 
+ if conn, ok := f.clientConnPool.Load(address); ok { + if grpcConn, ok := conn.(*grpc.ClientConn); ok { + state := grpcConn.GetState() + if state != connectivity.Shutdown && state != connectivity.TransientFailure { + logger.Debugf("Reusing existing connection to %s (state: %s)", address, state.String()) + return grpcConn, nil + } + // If the connection is stale, it will be handled by the write-lock path. + logger.Warnf("Found stale connection to %s in state %s, will create new one", address, state.String()) + } + } + + // If no valid connection is found, acquire a write lock to create one. + f.mu.Lock() + defer f.mu.Unlock() + + // Double-check if another goroutine created the connection while we were waiting for the lock. + if conn, ok := f.clientConnPool.Load(address); ok { + if grpcConn, ok := conn.(*grpc.ClientConn); ok { + state := grpcConn.GetState() + if state != connectivity.Shutdown && state != connectivity.TransientFailure { + logger.Debugf("Another goroutine created connection to %s, reusing it", address) + return grpcConn, nil + } + // The existing connection is bad, remove it before creating a new one. + f.clientConnPool.Delete(address) + } + } + + // Create a new connection. + logger.Infof("Creating new backend connection to %s", address) + conn, err := f.createConnection(address) + if err != nil { + return nil, errors.Wrapf(err, "failed to connect to %s", address) + } + + // Store the new connection in the pool. + f.clientConnPool.Store(address, conn) + + // Start a goroutine to monitor the connection's health. + go f.monitorConnection(address, conn) + + return conn, nil +} + +// monitorConnection periodically checks connection health and removes bad connections +func (f *Filter) monitorConnection(cacheKey string, conn *grpc.ClientConn) { + ticker := time.NewTicker(30 * time.Second) Review Comment: done -- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: notifications-unsubscr...@dubbo.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: notifications-unsubscr...@dubbo.apache.org For additional commands, e-mail: notifications-h...@dubbo.apache.org