This is an automated email from the ASF dual-hosted git repository.

xuetaoli pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/dubbo-go.git


The following commit(s) were added to refs/heads/develop by this push:
     new 93543d130 feat: support cors for triple (#3090)
93543d130 is described below

commit 93543d13018ac027656282d8b9b2d68ce88f3548
Author: zbchi <[email protected]>
AuthorDate: Mon Feb 9 11:27:19 2026 +0800

    feat: support cors for triple (#3090)
    
    * feat: support cors for triple
    
    * refactor cors.go
    
    * add unit test for cors
    
    * move options_cors into options
    
    * feat: improve CORS port matching logic
    
    * format code
    
    * improve cors config in protocol
    
    * fix SonarQube err
    
    * fix nil calling function
    
    * rm useless func
    
    * improve validation
    
    * simplify unit test
    
    * simplify the implementation
    
    * rm corsPolicy
    
    * fix: set "Vary: Origin" when the response depends on the Origin header
---
 global/config_test.go                        |   7 +
 global/cors_config.go                        |  75 +++++++
 global/triple_config.go                      |   5 +
 protocol/triple/options.go                   | 188 ++++++++++++++++
 protocol/triple/server.go                    |  12 +
 protocol/triple/triple_protocol/cors.go      | 316 +++++++++++++++++++++++++++
 protocol/triple/triple_protocol/cors_test.go | 201 +++++++++++++++++
 protocol/triple/triple_protocol/handler.go   |  44 +++-
 protocol/triple/triple_protocol/option.go    |  16 ++
 9 files changed, 862 insertions(+), 2 deletions(-)

diff --git a/global/config_test.go b/global/config_test.go
index 398e4e689..46f1293bb 100644
--- a/global/config_test.go
+++ b/global/config_test.go
@@ -187,6 +187,13 @@ func TestCloneConfig(t *testing.T) {
                CheckCompleteInequality(t, c, clone)
        })
 
+       t.Run("CorsConfig", func(t *testing.T) {
+               c := DefaultCorsConfig()
+               InitCheckCompleteInequality(t, c)
+               clone := c.Clone()
+               CheckCompleteInequality(t, c, clone)
+       })
+
        t.Run("Http3Config", func(t *testing.T) {
                c := DefaultHttp3Config()
                InitCheckCompleteInequality(t, c)
diff --git a/global/cors_config.go b/global/cors_config.go
new file mode 100644
index 000000000..7d55bcecf
--- /dev/null
+++ b/global/cors_config.go
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package global
+
+// CorsConfig represents the CORS configuration for Triple protocol handlers.
+type CorsConfig struct {
+       // AllowOrigins specifies the allowed origins for CORS requests.
+       // Supports exact match, wildcard "*", and subdomain wildcard like 
"https://*.example.com";.
+       // Empty slice means CORS is disabled.
+       AllowOrigins []string `yaml:"allow-origins" 
json:"allow-origins,omitempty"`
+
+       // AllowMethods specifies the allowed HTTP methods for CORS requests.
+       // If empty, methods will be automatically determined from handler 
registrations.
+       // OPTIONS method is always included automatically.
+       AllowMethods []string `yaml:"allow-methods" 
json:"allow-methods,omitempty"`
+
+       // AllowHeaders specifies the allowed request headers for CORS requests.
+       // If empty, the requested headers from Access-Control-Request-Headers 
will be echoed back.
+       AllowHeaders []string `yaml:"allow-headers" 
json:"allow-headers,omitempty"`
+
+       // ExposeHeaders specifies the headers that browsers are allowed to 
access.
+       ExposeHeaders []string `yaml:"expose-headers" 
json:"expose-headers,omitempty"`
+
+       // AllowCredentials indicates whether credentials are allowed in CORS 
requests.
+       // When true, AllowOrigins cannot contain "*".
+       AllowCredentials bool `yaml:"allow-credentials" 
json:"allow-credentials,omitempty"`
+
+       // MaxAge specifies the maximum age (in seconds) for preflight cache.
+       // Must be non-negative. If zero, disables caching (no 
Access-Control-Max-Age header will be sent).
+       // If positive, specifies the cache duration in seconds.
+       MaxAge int `yaml:"max-age" json:"max-age,omitempty"`
+}
+
+// DefaultCorsConfig returns a default CorsConfig instance.
+func DefaultCorsConfig() *CorsConfig {
+       return &CorsConfig{
+               AllowOrigins:     []string{},
+               AllowMethods:     []string{},
+               AllowHeaders:     []string{},
+               ExposeHeaders:    []string{},
+               AllowCredentials: false,
+               MaxAge:           0,
+       }
+}
+
+// Clone a new CorsConfig
+func (c *CorsConfig) Clone() *CorsConfig {
+       if c == nil {
+               return nil
+       }
+
+       return &CorsConfig{
+               AllowOrigins:     append([]string(nil), c.AllowOrigins...),
+               AllowMethods:     append([]string(nil), c.AllowMethods...),
+               AllowHeaders:     append([]string(nil), c.AllowHeaders...),
+               ExposeHeaders:    append([]string(nil), c.ExposeHeaders...),
+               AllowCredentials: c.AllowCredentials,
+               MaxAge:           c.MaxAge,
+       }
+}
diff --git a/global/triple_config.go b/global/triple_config.go
index 83b57c63d..004ff1073 100644
--- a/global/triple_config.go
+++ b/global/triple_config.go
@@ -34,6 +34,9 @@ type TripleConfig struct {
        // the config of http3 transport
        Http3 *Http3Config `yaml:"http3" json:"http3,omitempty"`
 
+       // Cors configures CORS for Triple protocol handlers
+       Cors *CorsConfig `yaml:"cors" json:"cors,omitempty"`
+
        //
        // for client
        //
@@ -46,6 +49,7 @@ type TripleConfig struct {
 func DefaultTripleConfig() *TripleConfig {
        return &TripleConfig{
                Http3: DefaultHttp3Config(),
+               Cors:  DefaultCorsConfig(),
        }
 }
 
@@ -59,6 +63,7 @@ func (t *TripleConfig) Clone() *TripleConfig {
                MaxServerSendMsgSize: t.MaxServerSendMsgSize,
                MaxServerRecvMsgSize: t.MaxServerRecvMsgSize,
                Http3:                t.Http3.Clone(),
+               Cors:                 t.Cors.Clone(),
 
                KeepAliveInterval: t.KeepAliveInterval,
                KeepAliveTimeout:  t.KeepAliveTimeout,
diff --git a/protocol/triple/options.go b/protocol/triple/options.go
index 5a6f37492..5ea92636f 100644
--- a/protocol/triple/options.go
+++ b/protocol/triple/options.go
@@ -18,9 +18,17 @@
 package triple
 
 import (
+       "errors"
+       "net/http"
+       "net/url"
+       "strings"
        "time"
 )
 
+import (
+       "github.com/dubbogo/gost/log/logger"
+)
+
 import (
        "dubbo.apache.org/dubbo-go/v3/global"
 )
@@ -92,6 +100,23 @@ func WithMaxServerRecvMsgSize(size string) Option {
        }
 }
 
+// WithCORS applies CORS configuration to triple options.
+// Invalid configs are logged as errors and ignored (no-op).
+func WithCORS(opts ...CORSOption) Option {
+       cors := global.DefaultCorsConfig()
+       for _, opt := range opts {
+               opt(cors)
+       }
+       if err := validateCorsConfig(cors); err != nil {
+               logger.Errorf("[TRIPLE] invalid CORS config: %v", err)
+               // Return a no-op function to ignore invalid CORS configuration
+               return func(*Options) {}
+       }
+       return func(opts *Options) {
+               opts.Triple.Cors = cors
+       }
+}
+
 // Http3Enable enables HTTP/3 support for the Triple protocol.
 // This option configures the server to start both HTTP/2 and HTTP/3 servers
 // simultaneously, providing modern HTTP/3 capabilities alongside traditional 
HTTP/2.
@@ -161,3 +186,166 @@ func Http3Negotiation(negotiation bool) Option {
                opts.Triple.Http3.Negotiation = negotiation
        }
 }
+
+// CORSOption configures a single aspect of CORS.
+type CORSOption func(*global.CorsConfig)
+
+// CORSAllowOrigins sets allowed origins for CORS requests.
+func CORSAllowOrigins(origins ...string) CORSOption {
+       return func(c *global.CorsConfig) {
+               c.AllowOrigins = append([]string(nil), origins...)
+       }
+}
+
+// CORSAllowMethods sets allowed HTTP methods for CORS requests.
+func CORSAllowMethods(methods ...string) CORSOption {
+       return func(c *global.CorsConfig) {
+               c.AllowMethods = append([]string(nil), methods...)
+       }
+}
+
+// CORSAllowHeaders sets allowed request headers for CORS requests.
+func CORSAllowHeaders(headers ...string) CORSOption {
+       return func(c *global.CorsConfig) {
+               c.AllowHeaders = append([]string(nil), headers...)
+       }
+}
+
+// CORSExposeHeaders sets headers exposed to the browser.
+func CORSExposeHeaders(headers ...string) CORSOption {
+       return func(c *global.CorsConfig) {
+               c.ExposeHeaders = append([]string(nil), headers...)
+       }
+}
+
+// CORSAllowCredentials toggles whether credentials are allowed.
+func CORSAllowCredentials(allow bool) CORSOption {
+       return func(c *global.CorsConfig) {
+               c.AllowCredentials = allow
+       }
+}
+
+// CORSMaxAge sets the max age for preflight cache.
+func CORSMaxAge(maxAge int) CORSOption {
+       return func(c *global.CorsConfig) {
+               c.MaxAge = maxAge
+       }
+}
+
// validHTTPMethods is the set of upper-case method names accepted in
// allow-methods. validateCorsConfig upper-cases each configured method
// before looking it up here.
var validHTTPMethods = map[string]bool{
	http.MethodConnect: true,
	http.MethodDelete:  true,
	http.MethodGet:     true,
	http.MethodHead:    true,
	http.MethodOptions: true,
	http.MethodPatch:   true,
	http.MethodPost:    true,
	http.MethodPut:     true,
	http.MethodTrace:   true,
}
+
+// validateCorsConfig validates CORS configuration.
+func validateCorsConfig(cors *global.CorsConfig) error {
+       if cors == nil {
+               return nil
+       }
+
+       // Validate origins
+       for _, origin := range cors.AllowOrigins {
+               if origin == "" {
+                       return errors.New("allow-origins cannot contain empty 
string")
+               }
+               if err := validateOrigin(origin); err != nil {
+                       return err
+               }
+               if cors.AllowCredentials && origin == "*" {
+                       return errors.New("allowCredentials cannot be true when 
allow-origins contains \"*\"")
+               }
+       }
+
+       // Validate methods
+       for _, method := range cors.AllowMethods {
+               if method == "" || !validHTTPMethods[strings.ToUpper(method)] {
+                       return errors.New("allow-methods contains invalid HTTP 
method")
+               }
+       }
+
+       // Validate headers (both allow and expose)
+       if err := validateHeaders(cors.AllowHeaders, "allow-headers"); err != 
nil {
+               return err
+       }
+       if err := validateHeaders(cors.ExposeHeaders, "expose-headers"); err != 
nil {
+               return err
+       }
+
+       if cors.MaxAge < 0 {
+               return errors.New("max-age cannot be negative")
+       }
+
+       return nil
+}
+
// validateHeaders rejects a header list containing a blank entry (empty or
// whitespace-only). fieldName prefixes the returned error message.
func validateHeaders(headers []string, fieldName string) error {
	for i := range headers {
		if len(strings.TrimSpace(headers[i])) > 0 {
			continue
		}
		return errors.New(fieldName + " cannot contain empty string")
	}
	return nil
}
+
+func validateOrigin(origin string) error {
+       // Allow wildcard
+       if origin == "*" {
+               return nil
+       }
+
+       // Check for whitespace
+       if strings.ContainsAny(origin, " \t\n\r") {
+               return errors.New("origin contains whitespace")
+       }
+
+       // Handle subdomain wildcard (*.example.com or https://*.example.com)
+       if strings.Contains(origin, "*") {
+               return validateWildcardOrigin(origin)
+       }
+
+       // Validate URL format
+       if strings.Contains(origin, "://") {
+               u, err := url.Parse(origin)
+               if err != nil || u.Scheme == "" || u.Host == "" {
+                       return errors.New("invalid URL format")
+               }
+       }
+
+       return nil
+}
+
// validateWildcardOrigin checks an origin pattern that contains "*".
// Accepted forms are "*.domain" and "scheme://*.domain". The domain must be
// non-empty and must not contain another "*". In the URL form the scheme
// must be non-empty and wildcard-free. A pattern such as "://*.example.com"
// would otherwise pass validation and then never match any origin, because
// it does not parse as a URL.
func validateWildcardOrigin(origin string) error {
	var domain string
	if scheme, rest, found := strings.Cut(origin, "://*."); found {
		if scheme == "" || strings.Contains(scheme, "*") {
			return errors.New("invalid subdomain wildcard format")
		}
		domain = rest
	} else if strings.HasPrefix(origin, "*.") {
		domain = origin[len("*."):]
	} else {
		return errors.New("wildcard must be at start: '*.domain' or 'scheme://*.domain'")
	}

	if domain == "" {
		return errors.New("invalid subdomain wildcard format")
	}
	// Only single wildcard allowed
	if strings.Contains(domain, "*") {
		return errors.New("only single wildcard at subdomain level is allowed")
	}
	return nil
}
diff --git a/protocol/triple/server.go b/protocol/triple/server.go
index a8e63e98a..dce734a8f 100644
--- a/protocol/triple/server.go
+++ b/protocol/triple/server.go
@@ -233,6 +233,18 @@ func getHanOpts(url *common.URL, tripleConf 
*global.TripleConfig) (hanOpts []tri
 
        // todo:// open tracing
 
+       // CORS configuration
+       if tripleConf.Cors != nil && len(tripleConf.Cors.AllowOrigins) > 0 {
+               hanOpts = append(hanOpts, tri.WithCORS(&tri.CorsConfig{
+                       AllowOrigins:     tripleConf.Cors.AllowOrigins,
+                       AllowMethods:     tripleConf.Cors.AllowMethods,
+                       AllowHeaders:     tripleConf.Cors.AllowHeaders,
+                       ExposeHeaders:    tripleConf.Cors.ExposeHeaders,
+                       AllowCredentials: tripleConf.Cors.AllowCredentials,
+                       MaxAge:           tripleConf.Cors.MaxAge,
+               }))
+       }
+
        return hanOpts
 }
 
diff --git a/protocol/triple/triple_protocol/cors.go 
b/protocol/triple/triple_protocol/cors.go
new file mode 100644
index 000000000..c3c031955
--- /dev/null
+++ b/protocol/triple/triple_protocol/cors.go
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package triple_protocol
+
+import (
+       "net/http"
+       "net/url"
+       "sort"
+       "strconv"
+       "strings"
+)
+
+import (
+       "github.com/dubbogo/gost/log/logger"
+)
+
+import (
+       "dubbo.apache.org/dubbo-go/v3/common/constant"
+)
+
// CorsConfig is the CORS policy handed to a Handler through handler options.
// It carries the same fields as the user-facing global.CorsConfig, without
// yaml/json tags.
type CorsConfig struct {
	// Allowed origins/methods/request headers, and response headers exposed
	// to the browser.
	AllowOrigins, AllowMethods, AllowHeaders, ExposeHeaders []string

	AllowCredentials bool // reflect the origin and send Access-Control-Allow-Credentials
	MaxAge           int  // preflight cache seconds; only sent when > 0
}
+
// Request headers read by the CORS handling.
const (
	corsOrigin         = "Origin"
	corsRequestMethod  = "Access-Control-Request-Method"
	corsRequestHeaders = "Access-Control-Request-Headers"
)

// Response headers written by the CORS handling.
const (
	corsVary             = "Vary"
	corsAllowOrigin      = "Access-Control-Allow-Origin"
	corsAllowMethods     = "Access-Control-Allow-Methods"
	corsAllowHeaders     = "Access-Control-Allow-Headers"
	corsExposeHeaders    = "Access-Control-Expose-Headers"
	corsAllowCredentials = "Access-Control-Allow-Credentials"
	corsMaxAge           = "Access-Control-Max-Age"
)
+
// defaultCorsMethods is the fallback method list normalizeMethods uses when
// neither the configuration nor the protocol handlers supply any methods.
var defaultCorsMethods = []string{
	http.MethodGet,
	http.MethodPost,
	http.MethodPut,
	http.MethodDelete,
}
+
+// buildCorsPolicy processes the CorsConfig with handlers and returns a 
configured CorsConfig.
+func buildCorsPolicy(cfg *CorsConfig, handlers []protocolHandler) *CorsConfig {
+       if cfg == nil || len(cfg.AllowOrigins) == 0 {
+               return nil
+       }
+
+       built := &CorsConfig{
+               AllowOrigins:     append([]string(nil), cfg.AllowOrigins...),
+               AllowMethods:     normalizeMethods(cfg.AllowMethods, handlers),
+               AllowHeaders:     append([]string(nil), cfg.AllowHeaders...),
+               ExposeHeaders:    append([]string(nil), cfg.ExposeHeaders...),
+               AllowCredentials: cfg.AllowCredentials,
+               MaxAge:           cfg.MaxAge,
+       }
+
+       if built.hasWildcard() && !cfg.AllowCredentials && 
len(cfg.AllowOrigins) > 1 {
+               logger.Warnf("[TRIPLE] CORS: wildcard \"*\" will override other 
origins when allowCredentials=false")
+       }
+
+       return built
+}
+
+// hasWildcard checks if "*" is present in allowOrigins.
+func (c *CorsConfig) hasWildcard() bool {
+       if c == nil {
+               return false
+       }
+       for _, origin := range c.AllowOrigins {
+               if origin == constant.AnyValue {
+                       return true
+               }
+       }
+       return false
+}
+
+// normalizeMethods normalizes and deduplicates CORS methods.
+func normalizeMethods(configMethods []string, handlers []protocolHandler) 
[]string {
+       methodSet := make(map[string]struct{})
+
+       // Priority 1: explicit configuration
+       if len(configMethods) > 0 {
+               for _, m := range configMethods {
+                       if m != "" {
+                               methodSet[strings.ToUpper(m)] = struct{}{}
+                       }
+               }
+       } else {
+               // Priority 2: extract from handlers
+               if len(handlers) > 0 {
+                       for _, hdl := range handlers {
+                               for m := range hdl.Methods() {
+                                       methodSet[strings.ToUpper(m)] = 
struct{}{}
+                               }
+                       }
+               }
+               // Priority 3: use defaults
+               if len(methodSet) == 0 {
+                       for _, m := range defaultCorsMethods {
+                               methodSet[m] = struct{}{}
+                       }
+               }
+       }
+
+       // Always include OPTIONS for preflight
+       methodSet[http.MethodOptions] = struct{}{}
+
+       methods := make([]string, 0, len(methodSet))
+       for m := range methodSet {
+               methods = append(methods, m)
+       }
+       sort.Strings(methods)
+       return methods
+}
+
+// matchOrigin checks if the request origin matches any allowed pattern.
+func (c *CorsConfig) matchOrigin(origin string) bool {
+       if origin == "" || c == nil || len(c.AllowOrigins) == 0 {
+               return false
+       }
+
+       originURL, err := url.Parse(origin)
+       if err != nil {
+               return false
+       }
+
+       originScheme := strings.ToLower(originURL.Scheme)
+       originHost := originURL.Hostname()
+       originPort := originURL.Port()
+       if originPort == "" {
+               originPort = defaultPort(originScheme)
+       }
+
+       for _, pattern := range c.AllowOrigins {
+               if pattern == constant.AnyValue {
+                       return true
+               }
+
+               // Try parsing pattern as URL
+               patternURL, err := url.Parse(pattern)
+               if err == nil && patternURL.Host != "" {
+                       if matchPattern(originScheme, originHost, originPort, 
patternURL) {
+                               return true
+                       }
+                       continue
+               }
+
+               // Try as hostname pattern (*.example.com or example.com)
+               if matchHostnamePattern(originHost, pattern) {
+                       return true
+               }
+       }
+
+       return false
+}
+
+// matchPattern matches origin against a URL pattern.
+func matchPattern(originScheme, originHost, originPort string, patternURL 
*url.URL) bool {
+       patternScheme := strings.ToLower(patternURL.Scheme)
+       patternHost := patternURL.Hostname()
+       patternPort := patternURL.Port()
+       if patternPort == "" {
+               patternPort = defaultPort(patternScheme)
+       }
+
+       // Scheme must match
+       if patternScheme != "" && patternScheme != originScheme {
+               return false
+       }
+
+       // Check host (supports *.example.com)
+       if strings.HasPrefix(patternHost, "*.") {
+               base := patternHost[2:]
+               if !strings.HasSuffix(originHost, "."+base) || originHost == 
base {
+                       return false
+               }
+       } else if patternHost != originHost {
+               return false
+       }
+
+       // Check port
+       return patternPort == originPort
+}
+
// matchHostnamePattern reports whether originHost matches a bare hostname
// allow pattern. The pattern is either an exact host ("example.com") or a
// subdomain wildcard ("*.example.com"). A wildcard matches strict subdomains
// only, never "example.com" itself. The comparison is case-insensitive,
// because host names are.
func matchHostnamePattern(originHost, pattern string) bool {
	host := strings.ToLower(originHost)
	pattern = strings.ToLower(pattern)
	if strings.HasPrefix(pattern, "*.") {
		// pattern[1:] is ".example.com"; a suffix match implies host is a
		// strict subdomain.
		return strings.HasSuffix(host, pattern[1:])
	}
	return host == pattern
}
+
// defaultPort maps a lower-case scheme to its well-known port ("80" for
// http, "443" for https). Any other scheme, including upper-case spellings,
// yields "".
func defaultPort(scheme string) string {
	if scheme == "http" {
		return "80"
	}
	if scheme == "https" {
		return "443"
	}
	return ""
}
+
+// handlePreflight handles CORS preflight requests.
+func (c *CorsConfig) handlePreflight(w http.ResponseWriter, r *http.Request) 
bool {
+       if c == nil {
+               return false
+       }
+
+       origin := r.Header.Get(corsOrigin)
+       if origin == "" || !c.matchOrigin(origin) {
+               if origin != "" {
+                       logger.Debugf("[TRIPLE] CORS forbidden origin: %s", 
origin)
+               }
+               w.Header().Add(corsVary, corsOrigin)
+               w.WriteHeader(http.StatusForbidden)
+               return true
+       }
+
+       requestedMethod := r.Header.Get(corsRequestMethod)
+       if requestedMethod != "" && !c.containsMethod(requestedMethod) {
+               logger.Debugf("[TRIPLE] CORS forbidden method: %s (origin: 
%s)", requestedMethod, origin)
+               w.Header().Add(corsVary, corsOrigin)
+               w.WriteHeader(http.StatusForbidden)
+               return true
+       }
+
+       c.setCORSOrigin(w, origin)
+       c.setAllowMethods(w)
+       c.setAllowHeaders(w, r)
+       if c.MaxAge > 0 {
+               w.Header().Set(corsMaxAge, strconv.Itoa(c.MaxAge))
+       }
+       w.WriteHeader(http.StatusNoContent)
+       return true
+}
+
+// addCORSHeaders adds CORS headers to the response.
+func (c *CorsConfig) addCORSHeaders(w http.ResponseWriter, r *http.Request) {
+       if c == nil {
+               return
+       }
+
+       origin := r.Header.Get(corsOrigin)
+       if origin == "" || !c.matchOrigin(origin) {
+               return
+       }
+
+       c.setCORSOrigin(w, origin)
+       if len(c.ExposeHeaders) > 0 {
+               w.Header().Set(corsExposeHeaders, strings.Join(c.ExposeHeaders, 
", "))
+       }
+}
+
+// containsMethod checks if the method is allowed.
+func (c *CorsConfig) containsMethod(target string) bool {
+       if c == nil {
+               return false
+       }
+       targetUpper := strings.ToUpper(target)
+       for _, method := range c.AllowMethods {
+               if method == targetUpper {
+                       return true
+               }
+       }
+       return false
+}
+
+// setAllowMethods sets the Access-Control-Allow-Methods header.
+func (c *CorsConfig) setAllowMethods(w http.ResponseWriter) {
+       w.Header().Set(corsAllowMethods, strings.Join(c.AllowMethods, ", "))
+}
+
+// setAllowHeaders sets the Access-Control-Allow-Headers header.
+func (c *CorsConfig) setAllowHeaders(w http.ResponseWriter, r *http.Request) {
+       if len(c.AllowHeaders) > 0 {
+               w.Header().Set(corsAllowHeaders, strings.Join(c.AllowHeaders, 
", "))
+       } else if requestedHeaders := r.Header.Get(corsRequestHeaders); 
requestedHeaders != "" {
+               w.Header().Set(corsAllowHeaders, requestedHeaders)
+       }
+}
+
+// setCORSOrigin sets the Access-Control-Allow-Origin header.
+func (c *CorsConfig) setCORSOrigin(w http.ResponseWriter, origin string) {
+       if c.AllowCredentials {
+               w.Header().Set(corsAllowOrigin, origin)
+               w.Header().Add(corsVary, corsOrigin)
+               w.Header().Set(corsAllowCredentials, "true")
+       } else if c.hasWildcard() {
+               w.Header().Set(corsAllowOrigin, constant.AnyValue)
+       } else {
+               w.Header().Set(corsAllowOrigin, origin)
+               w.Header().Add(corsVary, corsOrigin)
+       }
+}
diff --git a/protocol/triple/triple_protocol/cors_test.go 
b/protocol/triple/triple_protocol/cors_test.go
new file mode 100644
index 000000000..f585b7d55
--- /dev/null
+++ b/protocol/triple/triple_protocol/cors_test.go
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package triple_protocol
+
+import (
+       "context"
+       "net/http"
+       "net/http/httptest"
+       "testing"
+)
+
+import (
+       "dubbo.apache.org/dubbo-go/v3/global"
+       
"dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol/internal/assert"
+)
+
+func convertCorsConfigForTest(cfg *global.CorsConfig) *CorsConfig {
+       if cfg == nil {
+               return nil
+       }
+       return &CorsConfig{
+               AllowOrigins:     append([]string(nil), cfg.AllowOrigins...),
+               AllowMethods:     append([]string(nil), cfg.AllowMethods...),
+               AllowHeaders:     append([]string(nil), cfg.AllowHeaders...),
+               ExposeHeaders:    append([]string(nil), cfg.ExposeHeaders...),
+               AllowCredentials: cfg.AllowCredentials,
+               MaxAge:           cfg.MaxAge,
+       }
+}
+
+func TestMatchOrigin(t *testing.T) {
+       t.Parallel()
+       tests := []struct {
+               name    string
+               origin  string
+               allowed []string
+               want    bool
+       }{
+               {"exact match", "https://api.example.com";, 
[]string{"https://api.example.com"}, true},
+               {"wildcard any", "https://foo.bar";, []string{"*"}, true},
+               {"subdomain wildcard", "https://a.example.com";, 
[]string{"https://*.example.com"}, true},
+               {"subdomain no scheme", "https://b.example.com";, 
[]string{"*.example.com"}, true},
+               {"scheme mismatch", "http://a.example.com";, 
[]string{"https://*.example.com"}, false},
+               {"not matched", "https://other.com";, 
[]string{"https://example.com"}, false},
+               {"empty origin", "", []string{"https://api.example.com"}, 
false},
+               {"default port https", "https://api.example.com:443";, 
[]string{"https://api.example.com"}, true},
+               {"default port http", "http://api.example.com:80";, 
[]string{"http://api.example.com"}, true},
+       }
+
+       for _, tt := range tests {
+               tt := tt
+               t.Run(tt.name, func(t *testing.T) {
+                       t.Parallel()
+                       c := &CorsConfig{AllowOrigins: tt.allowed}
+                       assert.Equal(t, c.matchOrigin(tt.origin), tt.want)
+               })
+       }
+}
+
+func TestAddCORSHeaders(t *testing.T) {
+       t.Parallel()
+       tests := []struct {
+               name       string
+               policy     *CorsConfig
+               origin     string
+               wantOrigin string
+               wantCreds  bool
+       }{
+               {"allowed with creds", 
buildCorsPolicy(convertCorsConfigForTest(&global.CorsConfig{
+                       AllowOrigins:     []string{"https://a.com"},
+                       AllowCredentials: true,
+               }), nil), "https://a.com";, "https://a.com";, true},
+               {"nil policy", nil, "https://a.com";, "", false},
+               {"empty origin", 
buildCorsPolicy(convertCorsConfigForTest(&global.CorsConfig{
+                       AllowOrigins: []string{"https://a.com"},
+               }), nil), "", "", false},
+               {"not allowed", 
buildCorsPolicy(convertCorsConfigForTest(&global.CorsConfig{
+                       AllowOrigins: []string{"https://a.com"},
+               }), nil), "https://b.com";, "", false},
+       }
+
+       for _, tt := range tests {
+               tt := tt
+               t.Run(tt.name, func(t *testing.T) {
+                       t.Parallel()
+                       req := httptest.NewRequest(http.MethodGet, "/", nil)
+                       if tt.origin != "" {
+                               req.Header.Set("Origin", tt.origin)
+                       }
+                       rr := httptest.NewRecorder()
+
+                       tt.policy.addCORSHeaders(rr, req)
+                       assert.Equal(t, 
rr.Header().Get("Access-Control-Allow-Origin"), tt.wantOrigin)
+                       if tt.wantCreds {
+                               assert.Equal(t, 
rr.Header().Get("Access-Control-Allow-Credentials"), "true")
+                       }
+               })
+       }
+}
+
+func TestBuildCorsPolicy(t *testing.T) {
+       t.Parallel()
+       assert.Nil(t, buildCorsPolicy(nil, nil))
+       assert.Nil(t, 
buildCorsPolicy(convertCorsConfigForTest(&global.CorsConfig{}), nil))
+
+       p := buildCorsPolicy(convertCorsConfigForTest(&global.CorsConfig{
+               AllowOrigins: []string{"https://a.com"},
+               MaxAge:       123,
+       }), nil)
+       assert.NotNil(t, p)
+       assert.Equal(t, p.MaxAge, 123)
+}
+
+func TestServeHTTPCORS(t *testing.T) {
+       t.Parallel()
+       // Allowed origin
+       h1 := NewUnaryHandler("dummy", func() any { return nil }, func(ctx 
context.Context, req *Request) (*Response, error) {
+               return &Response{Msg: struct{}{}, header: http.Header{}, 
trailer: http.Header{}}, nil
+       }, WithCORS(&CorsConfig{AllowOrigins: []string{"https://a.com"}, 
AllowCredentials: true}))
+
+       req1 := httptest.NewRequest(http.MethodPost, "/", nil)
+       req1.Header.Set("Origin", "https://a.com";)
+       req1.Header.Set(tripleServiceGroup, "")
+       req1.Header.Set(tripleServiceVersion, "")
+       rr1 := httptest.NewRecorder()
+       h1.ServeHTTP(rr1, req1)
+       assert.Equal(t, rr1.Header().Get("Access-Control-Allow-Origin"), 
"https://a.com";)
+       assert.Equal(t, rr1.Header().Get("Access-Control-Allow-Credentials"), 
"true")
+
+       // Forbidden origin
+       h2 := NewUnaryHandler("dummy", func() any { return nil }, func(ctx 
context.Context, req *Request) (*Response, error) {
+               return &Response{Msg: struct{}{}}, nil
+       }, WithCORS(&CorsConfig{AllowOrigins: []string{"https://a.com"}}))
+
+       req2 := httptest.NewRequest(http.MethodPost, "/", nil)
+       req2.Header.Set("Origin", "https://b.com";)
+       req2.Header.Set(tripleServiceGroup, "")
+       req2.Header.Set(tripleServiceVersion, "")
+       rr2 := httptest.NewRecorder()
+       h2.ServeHTTP(rr2, req2)
+       assert.Equal(t, rr2.Code, http.StatusForbidden)
+}
+
+func TestHandleCORS(t *testing.T) {
+       t.Parallel()
+       cors := buildCorsPolicy(convertCorsConfigForTest(&global.CorsConfig{
+               AllowOrigins: []string{"https://a.com"},
+       }), nil)
+
+       tests := []struct {
+               name            string
+               cors            *CorsConfig
+               method          string
+               origin          string
+               preflightHeader string
+               wantHandled     bool
+               wantCode        int
+       }{
+               {"nil config", nil, http.MethodGet, "https://a.com";, "", false, 
0},
+               {"preflight success", cors, http.MethodOptions, 
"https://a.com";, http.MethodPost, true, http.StatusNoContent},
+               {"forbidden origin", cors, http.MethodGet, "https://b.com";, "", 
true, http.StatusForbidden},
+               {"allowed origin", cors, http.MethodGet, "https://a.com";, "", 
false, 0},
+       }
+
+       for _, tt := range tests {
+               tt := tt
+               t.Run(tt.name, func(t *testing.T) {
+                       t.Parallel()
+                       h := &Handler{cors: tt.cors}
+                       req := httptest.NewRequest(tt.method, "/", nil)
+                       if tt.origin != "" {
+                               req.Header.Set("Origin", tt.origin)
+                       }
+                       if tt.preflightHeader != "" {
+                               req.Header.Set("Access-Control-Request-Method", 
tt.preflightHeader)
+                       }
+                       rr := httptest.NewRecorder()
+
+                       handled := h.handleCORS(rr, req)
+                       assert.Equal(t, handled, tt.wantHandled)
+                       if tt.wantCode != 0 {
+                               assert.Equal(t, rr.Code, tt.wantCode)
+                       }
+               })
+       }
+}
diff --git a/protocol/triple/triple_protocol/handler.go 
b/protocol/triple/triple_protocol/handler.go
index 256a7dcbb..ff12d9bb1 100644
--- a/protocol/triple/triple_protocol/handler.go
+++ b/protocol/triple/triple_protocol/handler.go
@@ -35,8 +35,9 @@ type Handler struct {
        // key is group/version
        implementations  map[string]StreamingHandlerFunc
        protocolHandlers []protocolHandler
-       allowMethod      string // Allow header
-       acceptPost       string // Accept-Post header
+       allowMethod      string      // Allow header
+       acceptPost       string      // Accept-Post header
+       cors             *CorsConfig // CORS policy
 }
 
 // NewUnaryHandler constructs a [Handler] for a request-response procedure.
@@ -56,6 +57,7 @@ func NewUnaryHandler(
                protocolHandlers: protocolHandlers,
                allowMethod:      sortedAllowMethodValue(protocolHandlers),
                acceptPost:       sortedAcceptPostValue(protocolHandlers),
+               cors:             buildCorsPolicy(config.Cors, 
protocolHandlers),
        }
        hdl.processImplementation(getIdentifier(config.Group, config.Version), 
implementation)
        return hdl
@@ -141,6 +143,7 @@ func NewClientStreamHandler(
                protocolHandlers: protocolHandlers,
                allowMethod:      sortedAllowMethodValue(protocolHandlers),
                acceptPost:       sortedAcceptPostValue(protocolHandlers),
+               cors:             buildCorsPolicy(config.Cors, 
protocolHandlers),
        }
        hdl.processImplementation(getIdentifier(config.Group, config.Version), 
implementation)
 
@@ -199,6 +202,7 @@ func NewServerStreamHandler(
                protocolHandlers: protocolHandlers,
                allowMethod:      sortedAllowMethodValue(protocolHandlers),
                acceptPost:       sortedAcceptPostValue(protocolHandlers),
+               cors:             buildCorsPolicy(config.Cors, 
protocolHandlers),
        }
        hdl.processImplementation(getIdentifier(config.Group, config.Version), 
implementation)
 
@@ -259,6 +263,7 @@ func NewBidiStreamHandler(
                protocolHandlers: protocolHandlers,
                allowMethod:      sortedAllowMethodValue(protocolHandlers),
                acceptPost:       sortedAcceptPostValue(protocolHandlers),
+               cors:             buildCorsPolicy(config.Cors, 
protocolHandlers),
        }
        hdl.processImplementation(getIdentifier(config.Group, config.Version), 
implementation)
 
@@ -309,6 +314,13 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re
                return
        }
 
+       // CORS handling
+       if h.cors != nil {
+               if h.handleCORS(responseWriter, request) {
+                       return
+               }
+       }
+
        // inspect headers
        var protocolHandlers []protocolHandler
        for _, handler := range h.protocolHandlers {
@@ -393,6 +405,7 @@ type handlerConfig struct {
        SendMaxBytes                int
        Group                       string
        Version                     string
+       Cors                        *CorsConfig
 }
 
 func newHandlerConfig(procedure string, options []HandlerOption) 
*handlerConfig {
@@ -462,3 +475,30 @@ func (c *handlerConfig) newProtocolHandlers(streamType StreamType) []protocolHan
 // getIdentifier builds the "group/version" key under which handler
 // implementations are registered (see Handler.implementations). Empty
 // group and version values are kept as-is, e.g. ("", "") yields "/".
 func getIdentifier(group, version string) string {
        return group + "/" + version
 }
+
+// handleCORS processes CORS requests. Returns true if the request was handled and processing should stop.
+//
+// Outcomes, in evaluation order:
+//   - no CORS policy (h.cors == nil): nothing is written, returns false;
+//   - OPTIONS carrying an Access-Control-Request-Method header: treated as a
+//     preflight, and the result of h.cors.handlePreflight is returned as-is;
+//   - no Origin header: the request passes through untouched (returns false);
+//   - Origin rejected by h.cors.matchOrigin: replies 403 Forbidden with
+//     "Vary: Origin" and returns true;
+//   - Origin accepted: h.cors.addCORSHeaders decorates the response and the
+//     request continues down the normal handler path (returns false).
+//
+// NOTE(review): preflight detection does not look at Origin, although the
+// Fetch spec defines a CORS preflight as an OPTIONS request carrying both
+// Origin and Access-Control-Request-Method. Confirm that handlePreflight
+// copes with a missing Origin.
+// NOTE(review): browsers also send Origin on same-origin non-GET requests.
+// Such requests are rejected here unless the server's own origin is among
+// the allowed origins. Confirm this is intended.
+func (h *Handler) handleCORS(w http.ResponseWriter, r *http.Request) bool {
+       // ServeHTTP already checks this; keep the guard so direct calls with no
+       // CORS policy are a safe no-op.
+       if h.cors == nil {
+               return false
+       }
+
+       // Handle preflight requests
+       if r.Method == http.MethodOptions && r.Header.Get(corsRequestMethod) != "" {
+               return h.cors.handlePreflight(w, r)
+       }
+
+       // Handle normal requests with Origin header
+       origin := r.Header.Get(corsOrigin)
+       if origin == "" {
+               return false
+       }
+
+       if !h.cors.matchOrigin(origin) {
+               // This response depends on the Origin request header, so advertise
+               // that in Vary for shared caches. Headers must be set before
+               // WriteHeader, otherwise they are not sent.
+               w.Header().Add(corsVary, corsOrigin)
+               w.WriteHeader(http.StatusForbidden)
+               return true
+       }
+
+       h.cors.addCORSHeaders(w, r)
+       return false
+}
diff --git a/protocol/triple/triple_protocol/option.go b/protocol/triple/triple_protocol/option.go
index 13830e0e5..197df4d95 100644
--- a/protocol/triple/triple_protocol/option.go
+++ b/protocol/triple/triple_protocol/option.go
@@ -178,6 +178,11 @@ func WithRequireTripleProtocolHeader() HandlerOption {
        return &requireTripleProtocolHeaderOption{}
 }
 
+// WithCORS configures CORS for the handler.
+//
+// The pointer is stored on the handler configuration as given (it is not
+// copied by the option). Passing nil is a no-op: it does not clear a CORS
+// configuration applied by an earlier option.
+func WithCORS(cors *CorsConfig) HandlerOption {
+       return &corsOption{cors: cors}
+}
+
 // WithGroup returns an Option wrapping the given group name in a groupOption.
 // NOTE(review): handlers key their implementations by
 // getIdentifier(config.Group, config.Version); presumably this option feeds
 // config.Group. Confirm against groupOption's apply methods.
 func WithGroup(group string) Option {
        return &groupOption{group}
 }
@@ -443,6 +448,17 @@ func (o *requireTripleProtocolHeaderOption) applyToHandler(config *handlerConfig
        config.RequireTripleProtocolHeader = true
 }
 
+// corsOption is the HandlerOption produced by WithCORS.
+type corsOption struct {
+       cors *CorsConfig // CORS settings to install; nil means "leave unchanged"
+}
+
+// applyToHandler installs the CORS configuration on the handler config.
+// A nil value is ignored, so it never overwrites an existing config.Cors.
+func (o *corsOption) applyToHandler(config *handlerConfig) {
+       if o.cors == nil {
+               return
+       }
+       config.Cors = o.cors
+}
+
 // groupOption holds the group name supplied to WithGroup.
 type groupOption struct {
        Group string
 }

Reply via email to