This is an automated email from the ASF dual-hosted git repository.
xuetaoli pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/dubbo-go.git
The following commit(s) were added to refs/heads/develop by this push:
new 93543d130 feat: support cors for triple (#3090)
93543d130 is described below
commit 93543d13018ac027656282d8b9b2d68ce88f3548
Author: zbchi <[email protected]>
AuthorDate: Mon Feb 9 11:27:19 2026 +0800
feat: support cors for triple (#3090)
* feat: support cors for triple
* refactor cors.go
* add unit test for cors
* move options_cors into options
* feat: improve CORS port matching logic
* format code
* improve cors config in protocol
* fix SonarQube err
* fix nil calling function
* rm useless func
* improve validation
* simplify unit test
* simplify the implementation
* rm corsPolicy
* fix: set vary with origin when depending on origin header
---
global/config_test.go | 7 +
global/cors_config.go | 75 +++++++
global/triple_config.go | 5 +
protocol/triple/options.go | 188 ++++++++++++++++
protocol/triple/server.go | 12 +
protocol/triple/triple_protocol/cors.go | 316 +++++++++++++++++++++++++++
protocol/triple/triple_protocol/cors_test.go | 201 +++++++++++++++++
protocol/triple/triple_protocol/handler.go | 44 +++-
protocol/triple/triple_protocol/option.go | 16 ++
9 files changed, 862 insertions(+), 2 deletions(-)
diff --git a/global/config_test.go b/global/config_test.go
index 398e4e689..46f1293bb 100644
--- a/global/config_test.go
+++ b/global/config_test.go
@@ -187,6 +187,13 @@ func TestCloneConfig(t *testing.T) {
CheckCompleteInequality(t, c, clone)
})
+ t.Run("CorsConfig", func(t *testing.T) {
+ c := DefaultCorsConfig()
+ InitCheckCompleteInequality(t, c)
+ clone := c.Clone()
+ CheckCompleteInequality(t, c, clone)
+ })
+
t.Run("Http3Config", func(t *testing.T) {
c := DefaultHttp3Config()
InitCheckCompleteInequality(t, c)
diff --git a/global/cors_config.go b/global/cors_config.go
new file mode 100644
index 000000000..7d55bcecf
--- /dev/null
+++ b/global/cors_config.go
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package global
+
// CorsConfig represents the CORS configuration for Triple protocol handlers.
type CorsConfig struct {
	// AllowOrigins specifies the allowed origins for CORS requests.
	// Supports exact match, wildcard "*", and subdomain wildcard like "https://*.example.com".
	// Empty slice means CORS is disabled.
	AllowOrigins []string `yaml:"allow-origins" json:"allow-origins,omitempty"`

	// AllowMethods specifies the allowed HTTP methods for CORS requests.
	// If empty, methods will be automatically determined from handler registrations.
	// OPTIONS method is always included automatically.
	AllowMethods []string `yaml:"allow-methods" json:"allow-methods,omitempty"`

	// AllowHeaders specifies the allowed request headers for CORS requests.
	// If empty, the requested headers from Access-Control-Request-Headers will be echoed back.
	AllowHeaders []string `yaml:"allow-headers" json:"allow-headers,omitempty"`

	// ExposeHeaders specifies the response headers that browsers are allowed to access.
	ExposeHeaders []string `yaml:"expose-headers" json:"expose-headers,omitempty"`

	// AllowCredentials indicates whether credentials are allowed in CORS requests.
	// When true, AllowOrigins cannot contain "*".
	AllowCredentials bool `yaml:"allow-credentials" json:"allow-credentials,omitempty"`

	// MaxAge specifies how long (in seconds) a browser may cache a preflight result.
	// Must be non-negative. If zero, no Access-Control-Max-Age header is sent, so the
	// browser applies its own default (5 seconds per the Fetch standard); zero does
	// NOT disable preflight caching.
	// If positive, the value is sent as Access-Control-Max-Age.
	MaxAge int `yaml:"max-age" json:"max-age,omitempty"`
}

// DefaultCorsConfig returns a CorsConfig with CORS disabled: every list is
// empty (but non-nil), credentials are not allowed and MaxAge is 0.
func DefaultCorsConfig() *CorsConfig {
	return &CorsConfig{
		AllowOrigins:     []string{},
		AllowMethods:     []string{},
		AllowHeaders:     []string{},
		ExposeHeaders:    []string{},
		AllowCredentials: false,
		MaxAge:           0,
	}
}

// Clone returns a deep copy of c, or nil when c is nil.
//
// Slices are copied with append(s[:0:0], s...): the clone never shares a
// backing array with c, and nil / empty-but-non-nil slices keep their
// distinction, so a clone compares equal (reflect.DeepEqual) to its source.
func (c *CorsConfig) Clone() *CorsConfig {
	if c == nil {
		return nil
	}

	return &CorsConfig{
		AllowOrigins:     append(c.AllowOrigins[:0:0], c.AllowOrigins...),
		AllowMethods:     append(c.AllowMethods[:0:0], c.AllowMethods...),
		AllowHeaders:     append(c.AllowHeaders[:0:0], c.AllowHeaders...),
		ExposeHeaders:    append(c.ExposeHeaders[:0:0], c.ExposeHeaders...),
		AllowCredentials: c.AllowCredentials,
		MaxAge:           c.MaxAge,
	}
}
diff --git a/global/triple_config.go b/global/triple_config.go
index 83b57c63d..004ff1073 100644
--- a/global/triple_config.go
+++ b/global/triple_config.go
@@ -34,6 +34,9 @@ type TripleConfig struct {
// the config of http3 transport
Http3 *Http3Config `yaml:"http3" json:"http3,omitempty"`
+ // Cors configures CORS for Triple protocol handlers
+ Cors *CorsConfig `yaml:"cors" json:"cors,omitempty"`
+
//
// for client
//
@@ -46,6 +49,7 @@ type TripleConfig struct {
func DefaultTripleConfig() *TripleConfig {
return &TripleConfig{
Http3: DefaultHttp3Config(),
+ Cors: DefaultCorsConfig(),
}
}
@@ -59,6 +63,7 @@ func (t *TripleConfig) Clone() *TripleConfig {
MaxServerSendMsgSize: t.MaxServerSendMsgSize,
MaxServerRecvMsgSize: t.MaxServerRecvMsgSize,
Http3: t.Http3.Clone(),
+ Cors: t.Cors.Clone(),
KeepAliveInterval: t.KeepAliveInterval,
KeepAliveTimeout: t.KeepAliveTimeout,
diff --git a/protocol/triple/options.go b/protocol/triple/options.go
index 5a6f37492..5ea92636f 100644
--- a/protocol/triple/options.go
+++ b/protocol/triple/options.go
@@ -18,9 +18,17 @@
package triple
import (
+ "errors"
+ "net/http"
+ "net/url"
+ "strings"
"time"
)
+import (
+ "github.com/dubbogo/gost/log/logger"
+)
+
import (
"dubbo.apache.org/dubbo-go/v3/global"
)
@@ -92,6 +100,23 @@ func WithMaxServerRecvMsgSize(size string) Option {
}
}
+// WithCORS applies CORS configuration to triple options.
+// Invalid configs are logged as errors and ignored (no-op).
+func WithCORS(opts ...CORSOption) Option {
+ cors := global.DefaultCorsConfig()
+ for _, opt := range opts {
+ opt(cors)
+ }
+ if err := validateCorsConfig(cors); err != nil {
+ logger.Errorf("[TRIPLE] invalid CORS config: %v", err)
+ // Return a no-op function to ignore invalid CORS configuration
+ return func(*Options) {}
+ }
+ return func(opts *Options) {
+ opts.Triple.Cors = cors
+ }
+}
+
// Http3Enable enables HTTP/3 support for the Triple protocol.
// This option configures the server to start both HTTP/2 and HTTP/3 servers
// simultaneously, providing modern HTTP/3 capabilities alongside traditional
HTTP/2.
@@ -161,3 +186,166 @@ func Http3Negotiation(negotiation bool) Option {
opts.Triple.Http3.Negotiation = negotiation
}
}
+
// CORSOption configures a single aspect of CORS.
// Each CORSOption mutates the *global.CorsConfig it receives. WithCORS applies
// the options in order to a fresh global.DefaultCorsConfig(), so when the same
// field is set more than once, the last option wins.
type CORSOption func(*global.CorsConfig)
+
+// CORSAllowOrigins sets allowed origins for CORS requests.
+func CORSAllowOrigins(origins ...string) CORSOption {
+ return func(c *global.CorsConfig) {
+ c.AllowOrigins = append([]string(nil), origins...)
+ }
+}
+
+// CORSAllowMethods sets allowed HTTP methods for CORS requests.
+func CORSAllowMethods(methods ...string) CORSOption {
+ return func(c *global.CorsConfig) {
+ c.AllowMethods = append([]string(nil), methods...)
+ }
+}
+
+// CORSAllowHeaders sets allowed request headers for CORS requests.
+func CORSAllowHeaders(headers ...string) CORSOption {
+ return func(c *global.CorsConfig) {
+ c.AllowHeaders = append([]string(nil), headers...)
+ }
+}
+
+// CORSExposeHeaders sets headers exposed to the browser.
+func CORSExposeHeaders(headers ...string) CORSOption {
+ return func(c *global.CorsConfig) {
+ c.ExposeHeaders = append([]string(nil), headers...)
+ }
+}
+
+// CORSAllowCredentials toggles whether credentials are allowed.
+func CORSAllowCredentials(allow bool) CORSOption {
+ return func(c *global.CorsConfig) {
+ c.AllowCredentials = allow
+ }
+}
+
+// CORSMaxAge sets the max age for preflight cache.
+func CORSMaxAge(maxAge int) CORSOption {
+ return func(c *global.CorsConfig) {
+ c.MaxAge = maxAge
+ }
+}
+
// validHTTPMethods is the set of standard HTTP method tokens (the net/http
// Method* constants) accepted in the allow-methods CORS setting. Keys are
// upper-case; callers upper-case the configured value before looking it up.
var validHTTPMethods = map[string]bool{
	http.MethodConnect: true,
	http.MethodDelete:  true,
	http.MethodGet:     true,
	http.MethodHead:    true,
	http.MethodOptions: true,
	http.MethodPatch:   true,
	http.MethodPost:    true,
	http.MethodPut:     true,
	http.MethodTrace:   true,
}
+
+// validateCorsConfig validates CORS configuration.
+func validateCorsConfig(cors *global.CorsConfig) error {
+ if cors == nil {
+ return nil
+ }
+
+ // Validate origins
+ for _, origin := range cors.AllowOrigins {
+ if origin == "" {
+ return errors.New("allow-origins cannot contain empty
string")
+ }
+ if err := validateOrigin(origin); err != nil {
+ return err
+ }
+ if cors.AllowCredentials && origin == "*" {
+ return errors.New("allowCredentials cannot be true when
allow-origins contains \"*\"")
+ }
+ }
+
+ // Validate methods
+ for _, method := range cors.AllowMethods {
+ if method == "" || !validHTTPMethods[strings.ToUpper(method)] {
+ return errors.New("allow-methods contains invalid HTTP
method")
+ }
+ }
+
+ // Validate headers (both allow and expose)
+ if err := validateHeaders(cors.AllowHeaders, "allow-headers"); err !=
nil {
+ return err
+ }
+ if err := validateHeaders(cors.ExposeHeaders, "expose-headers"); err !=
nil {
+ return err
+ }
+
+ if cors.MaxAge < 0 {
+ return errors.New("max-age cannot be negative")
+ }
+
+ return nil
+}
+
// validateHeaders rejects blank (empty or whitespace-only) entries in a
// header list. fieldName is used as the prefix of the error message.
func validateHeaders(headers []string, fieldName string) error {
	for i := range headers {
		if len(strings.TrimSpace(headers[i])) == 0 {
			return errors.New(fieldName + " cannot contain empty string")
		}
	}
	return nil
}
+
+func validateOrigin(origin string) error {
+ // Allow wildcard
+ if origin == "*" {
+ return nil
+ }
+
+ // Check for whitespace
+ if strings.ContainsAny(origin, " \t\n\r") {
+ return errors.New("origin contains whitespace")
+ }
+
+ // Handle subdomain wildcard (*.example.com or https://*.example.com)
+ if strings.Contains(origin, "*") {
+ return validateWildcardOrigin(origin)
+ }
+
+ // Validate URL format
+ if strings.Contains(origin, "://") {
+ u, err := url.Parse(origin)
+ if err != nil || u.Scheme == "" || u.Host == "" {
+ return errors.New("invalid URL format")
+ }
+ }
+
+ return nil
+}
+
// validateWildcardOrigin checks an origin containing "*" other than the bare
// "*". It must be "*.domain" or "scheme://*.domain", with a non-empty domain
// that contains no further wildcard.
func validateWildcardOrigin(origin string) error {
	const schemeWildcard = "://*."

	var domain string
	if _, after, found := strings.Cut(origin, schemeWildcard); found {
		domain = after
	} else if strings.HasPrefix(origin, "*.") {
		domain = origin[len("*."):]
	} else {
		return errors.New("wildcard must be at start: '*.domain' or 'scheme://*.domain'")
	}

	if domain == "" {
		return errors.New("invalid subdomain wildcard format")
	}
	// Only one wildcard, at the subdomain level, is supported.
	if strings.Contains(domain, "*") {
		return errors.New("only single wildcard at subdomain level is allowed")
	}
	return nil
}
diff --git a/protocol/triple/server.go b/protocol/triple/server.go
index a8e63e98a..dce734a8f 100644
--- a/protocol/triple/server.go
+++ b/protocol/triple/server.go
@@ -233,6 +233,18 @@ func getHanOpts(url *common.URL, tripleConf
*global.TripleConfig) (hanOpts []tri
// todo:// open tracing
+ // CORS configuration
+ if tripleConf.Cors != nil && len(tripleConf.Cors.AllowOrigins) > 0 {
+ hanOpts = append(hanOpts, tri.WithCORS(&tri.CorsConfig{
+ AllowOrigins: tripleConf.Cors.AllowOrigins,
+ AllowMethods: tripleConf.Cors.AllowMethods,
+ AllowHeaders: tripleConf.Cors.AllowHeaders,
+ ExposeHeaders: tripleConf.Cors.ExposeHeaders,
+ AllowCredentials: tripleConf.Cors.AllowCredentials,
+ MaxAge: tripleConf.Cors.MaxAge,
+ }))
+ }
+
return hanOpts
}
diff --git a/protocol/triple/triple_protocol/cors.go
b/protocol/triple/triple_protocol/cors.go
new file mode 100644
index 000000000..c3c031955
--- /dev/null
+++ b/protocol/triple/triple_protocol/cors.go
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package triple_protocol
+
+import (
+ "net/http"
+ "net/url"
+ "sort"
+ "strconv"
+ "strings"
+)
+
+import (
+ "github.com/dubbogo/gost/log/logger"
+)
+
+import (
+ "dubbo.apache.org/dubbo-go/v3/common/constant"
+)
+
// CorsConfig is the handler-level CORS policy passed via WithCORS.
// Handlers store a normalized copy produced by buildCorsPolicy.
type CorsConfig struct {
	// AllowOrigins lists accepted origins. Supported forms: "*", URL-shaped
	// entries such as "https://a.com" or "https://*.a.com" (scheme, host and
	// port compared), and bare hostname patterns such as "a.com" or "*.a.com".
	AllowOrigins []string
	// AllowMethods is advertised in preflight responses. buildCorsPolicy
	// upper-cases, de-duplicates and sorts it, and always adds OPTIONS.
	AllowMethods []string
	// AllowHeaders is sent as Access-Control-Allow-Headers on preflight.
	// When empty, the request's Access-Control-Request-Headers is echoed.
	AllowHeaders []string
	// ExposeHeaders is sent as Access-Control-Expose-Headers on actual responses.
	ExposeHeaders []string
	// AllowCredentials makes responses echo the request origin and set
	// Access-Control-Allow-Credentials: true.
	AllowCredentials bool
	// MaxAge (seconds) is sent as Access-Control-Max-Age only when positive.
	MaxAge int
}
+
// Request headers inspected by the CORS implementation.
const (
	corsOrigin         = "Origin"
	corsRequestMethod  = "Access-Control-Request-Method"
	corsRequestHeaders = "Access-Control-Request-Headers"
)

// Response headers written by the CORS implementation.
const (
	corsVary             = "Vary"
	corsAllowOrigin      = "Access-Control-Allow-Origin"
	corsAllowMethods     = "Access-Control-Allow-Methods"
	corsAllowHeaders     = "Access-Control-Allow-Headers"
	corsExposeHeaders    = "Access-Control-Expose-Headers"
	corsAllowCredentials = "Access-Control-Allow-Credentials"
	corsMaxAge           = "Access-Control-Max-Age"
)
+
// defaultCorsMethods is the fallback method list used by normalizeMethods
// when neither the configuration nor the protocol handlers provide any
// method. OPTIONS is added separately by normalizeMethods.
var defaultCorsMethods = []string{
	http.MethodGet,
	http.MethodPost,
	http.MethodPut,
	http.MethodDelete,
}
+
+// buildCorsPolicy processes the CorsConfig with handlers and returns a
configured CorsConfig.
+func buildCorsPolicy(cfg *CorsConfig, handlers []protocolHandler) *CorsConfig {
+ if cfg == nil || len(cfg.AllowOrigins) == 0 {
+ return nil
+ }
+
+ built := &CorsConfig{
+ AllowOrigins: append([]string(nil), cfg.AllowOrigins...),
+ AllowMethods: normalizeMethods(cfg.AllowMethods, handlers),
+ AllowHeaders: append([]string(nil), cfg.AllowHeaders...),
+ ExposeHeaders: append([]string(nil), cfg.ExposeHeaders...),
+ AllowCredentials: cfg.AllowCredentials,
+ MaxAge: cfg.MaxAge,
+ }
+
+ if built.hasWildcard() && !cfg.AllowCredentials &&
len(cfg.AllowOrigins) > 1 {
+ logger.Warnf("[TRIPLE] CORS: wildcard \"*\" will override other
origins when allowCredentials=false")
+ }
+
+ return built
+}
+
+// hasWildcard checks if "*" is present in allowOrigins.
+func (c *CorsConfig) hasWildcard() bool {
+ if c == nil {
+ return false
+ }
+ for _, origin := range c.AllowOrigins {
+ if origin == constant.AnyValue {
+ return true
+ }
+ }
+ return false
+}
+
+// normalizeMethods normalizes and deduplicates CORS methods.
+func normalizeMethods(configMethods []string, handlers []protocolHandler)
[]string {
+ methodSet := make(map[string]struct{})
+
+ // Priority 1: explicit configuration
+ if len(configMethods) > 0 {
+ for _, m := range configMethods {
+ if m != "" {
+ methodSet[strings.ToUpper(m)] = struct{}{}
+ }
+ }
+ } else {
+ // Priority 2: extract from handlers
+ if len(handlers) > 0 {
+ for _, hdl := range handlers {
+ for m := range hdl.Methods() {
+ methodSet[strings.ToUpper(m)] =
struct{}{}
+ }
+ }
+ }
+ // Priority 3: use defaults
+ if len(methodSet) == 0 {
+ for _, m := range defaultCorsMethods {
+ methodSet[m] = struct{}{}
+ }
+ }
+ }
+
+ // Always include OPTIONS for preflight
+ methodSet[http.MethodOptions] = struct{}{}
+
+ methods := make([]string, 0, len(methodSet))
+ for m := range methodSet {
+ methods = append(methods, m)
+ }
+ sort.Strings(methods)
+ return methods
+}
+
+// matchOrigin checks if the request origin matches any allowed pattern.
+func (c *CorsConfig) matchOrigin(origin string) bool {
+ if origin == "" || c == nil || len(c.AllowOrigins) == 0 {
+ return false
+ }
+
+ originURL, err := url.Parse(origin)
+ if err != nil {
+ return false
+ }
+
+ originScheme := strings.ToLower(originURL.Scheme)
+ originHost := originURL.Hostname()
+ originPort := originURL.Port()
+ if originPort == "" {
+ originPort = defaultPort(originScheme)
+ }
+
+ for _, pattern := range c.AllowOrigins {
+ if pattern == constant.AnyValue {
+ return true
+ }
+
+ // Try parsing pattern as URL
+ patternURL, err := url.Parse(pattern)
+ if err == nil && patternURL.Host != "" {
+ if matchPattern(originScheme, originHost, originPort,
patternURL) {
+ return true
+ }
+ continue
+ }
+
+ // Try as hostname pattern (*.example.com or example.com)
+ if matchHostnamePattern(originHost, pattern) {
+ return true
+ }
+ }
+
+ return false
+}
+
+// matchPattern matches origin against a URL pattern.
+func matchPattern(originScheme, originHost, originPort string, patternURL
*url.URL) bool {
+ patternScheme := strings.ToLower(patternURL.Scheme)
+ patternHost := patternURL.Hostname()
+ patternPort := patternURL.Port()
+ if patternPort == "" {
+ patternPort = defaultPort(patternScheme)
+ }
+
+ // Scheme must match
+ if patternScheme != "" && patternScheme != originScheme {
+ return false
+ }
+
+ // Check host (supports *.example.com)
+ if strings.HasPrefix(patternHost, "*.") {
+ base := patternHost[2:]
+ if !strings.HasSuffix(originHost, "."+base) || originHost ==
base {
+ return false
+ }
+ } else if patternHost != originHost {
+ return false
+ }
+
+ // Check port
+ return patternPort == originPort
+}
+
// matchHostnamePattern matches a request origin host against a scheme-less
// allow-list entry. The entry is either an exact hostname ("example.com") or
// a subdomain wildcard ("*.example.com"); the wildcard does not match
// "example.com" itself.
//
// The comparison is case-insensitive because DNS hostnames are. originHost is
// expected without a port (matchOrigin passes url.URL.Hostname()).
func matchHostnamePattern(originHost, pattern string) bool {
	host := strings.ToLower(originHost)
	pattern = strings.ToLower(pattern)
	if strings.HasPrefix(pattern, "*.") {
		// pattern[1:] is ".base"; requiring it as a suffix also excludes host == base.
		return strings.HasSuffix(host, pattern[1:])
	}
	return host == pattern
}
+
// defaultPort maps a lower-case scheme to its well-known port ("443" for
// https, "80" for http). Any other scheme yields "".
func defaultPort(scheme string) string {
	if scheme == "https" {
		return "443"
	}
	if scheme == "http" {
		return "80"
	}
	return ""
}
+
+// handlePreflight handles CORS preflight requests.
+func (c *CorsConfig) handlePreflight(w http.ResponseWriter, r *http.Request)
bool {
+ if c == nil {
+ return false
+ }
+
+ origin := r.Header.Get(corsOrigin)
+ if origin == "" || !c.matchOrigin(origin) {
+ if origin != "" {
+ logger.Debugf("[TRIPLE] CORS forbidden origin: %s",
origin)
+ }
+ w.Header().Add(corsVary, corsOrigin)
+ w.WriteHeader(http.StatusForbidden)
+ return true
+ }
+
+ requestedMethod := r.Header.Get(corsRequestMethod)
+ if requestedMethod != "" && !c.containsMethod(requestedMethod) {
+ logger.Debugf("[TRIPLE] CORS forbidden method: %s (origin:
%s)", requestedMethod, origin)
+ w.Header().Add(corsVary, corsOrigin)
+ w.WriteHeader(http.StatusForbidden)
+ return true
+ }
+
+ c.setCORSOrigin(w, origin)
+ c.setAllowMethods(w)
+ c.setAllowHeaders(w, r)
+ if c.MaxAge > 0 {
+ w.Header().Set(corsMaxAge, strconv.Itoa(c.MaxAge))
+ }
+ w.WriteHeader(http.StatusNoContent)
+ return true
+}
+
+// addCORSHeaders adds CORS headers to the response.
+func (c *CorsConfig) addCORSHeaders(w http.ResponseWriter, r *http.Request) {
+ if c == nil {
+ return
+ }
+
+ origin := r.Header.Get(corsOrigin)
+ if origin == "" || !c.matchOrigin(origin) {
+ return
+ }
+
+ c.setCORSOrigin(w, origin)
+ if len(c.ExposeHeaders) > 0 {
+ w.Header().Set(corsExposeHeaders, strings.Join(c.ExposeHeaders,
", "))
+ }
+}
+
+// containsMethod checks if the method is allowed.
+func (c *CorsConfig) containsMethod(target string) bool {
+ if c == nil {
+ return false
+ }
+ targetUpper := strings.ToUpper(target)
+ for _, method := range c.AllowMethods {
+ if method == targetUpper {
+ return true
+ }
+ }
+ return false
+}
+
+// setAllowMethods sets the Access-Control-Allow-Methods header.
+func (c *CorsConfig) setAllowMethods(w http.ResponseWriter) {
+ w.Header().Set(corsAllowMethods, strings.Join(c.AllowMethods, ", "))
+}
+
+// setAllowHeaders sets the Access-Control-Allow-Headers header.
+func (c *CorsConfig) setAllowHeaders(w http.ResponseWriter, r *http.Request) {
+ if len(c.AllowHeaders) > 0 {
+ w.Header().Set(corsAllowHeaders, strings.Join(c.AllowHeaders,
", "))
+ } else if requestedHeaders := r.Header.Get(corsRequestHeaders);
requestedHeaders != "" {
+ w.Header().Set(corsAllowHeaders, requestedHeaders)
+ }
+}
+
+// setCORSOrigin sets the Access-Control-Allow-Origin header.
+func (c *CorsConfig) setCORSOrigin(w http.ResponseWriter, origin string) {
+ if c.AllowCredentials {
+ w.Header().Set(corsAllowOrigin, origin)
+ w.Header().Add(corsVary, corsOrigin)
+ w.Header().Set(corsAllowCredentials, "true")
+ } else if c.hasWildcard() {
+ w.Header().Set(corsAllowOrigin, constant.AnyValue)
+ } else {
+ w.Header().Set(corsAllowOrigin, origin)
+ w.Header().Add(corsVary, corsOrigin)
+ }
+}
diff --git a/protocol/triple/triple_protocol/cors_test.go
b/protocol/triple/triple_protocol/cors_test.go
new file mode 100644
index 000000000..f585b7d55
--- /dev/null
+++ b/protocol/triple/triple_protocol/cors_test.go
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package triple_protocol
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+import (
+ "dubbo.apache.org/dubbo-go/v3/global"
+
"dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol/internal/assert"
+)
+
+func convertCorsConfigForTest(cfg *global.CorsConfig) *CorsConfig {
+ if cfg == nil {
+ return nil
+ }
+ return &CorsConfig{
+ AllowOrigins: append([]string(nil), cfg.AllowOrigins...),
+ AllowMethods: append([]string(nil), cfg.AllowMethods...),
+ AllowHeaders: append([]string(nil), cfg.AllowHeaders...),
+ ExposeHeaders: append([]string(nil), cfg.ExposeHeaders...),
+ AllowCredentials: cfg.AllowCredentials,
+ MaxAge: cfg.MaxAge,
+ }
+}
+
+func TestMatchOrigin(t *testing.T) {
+ t.Parallel()
+ tests := []struct {
+ name string
+ origin string
+ allowed []string
+ want bool
+ }{
+ {"exact match", "https://api.example.com",
[]string{"https://api.example.com"}, true},
+ {"wildcard any", "https://foo.bar", []string{"*"}, true},
+ {"subdomain wildcard", "https://a.example.com",
[]string{"https://*.example.com"}, true},
+ {"subdomain no scheme", "https://b.example.com",
[]string{"*.example.com"}, true},
+ {"scheme mismatch", "http://a.example.com",
[]string{"https://*.example.com"}, false},
+ {"not matched", "https://other.com",
[]string{"https://example.com"}, false},
+ {"empty origin", "", []string{"https://api.example.com"},
false},
+ {"default port https", "https://api.example.com:443",
[]string{"https://api.example.com"}, true},
+ {"default port http", "http://api.example.com:80",
[]string{"http://api.example.com"}, true},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ c := &CorsConfig{AllowOrigins: tt.allowed}
+ assert.Equal(t, c.matchOrigin(tt.origin), tt.want)
+ })
+ }
+}
+
+func TestAddCORSHeaders(t *testing.T) {
+ t.Parallel()
+ tests := []struct {
+ name string
+ policy *CorsConfig
+ origin string
+ wantOrigin string
+ wantCreds bool
+ }{
+ {"allowed with creds",
buildCorsPolicy(convertCorsConfigForTest(&global.CorsConfig{
+ AllowOrigins: []string{"https://a.com"},
+ AllowCredentials: true,
+ }), nil), "https://a.com", "https://a.com", true},
+ {"nil policy", nil, "https://a.com", "", false},
+ {"empty origin",
buildCorsPolicy(convertCorsConfigForTest(&global.CorsConfig{
+ AllowOrigins: []string{"https://a.com"},
+ }), nil), "", "", false},
+ {"not allowed",
buildCorsPolicy(convertCorsConfigForTest(&global.CorsConfig{
+ AllowOrigins: []string{"https://a.com"},
+ }), nil), "https://b.com", "", false},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ if tt.origin != "" {
+ req.Header.Set("Origin", tt.origin)
+ }
+ rr := httptest.NewRecorder()
+
+ tt.policy.addCORSHeaders(rr, req)
+ assert.Equal(t,
rr.Header().Get("Access-Control-Allow-Origin"), tt.wantOrigin)
+ if tt.wantCreds {
+ assert.Equal(t,
rr.Header().Get("Access-Control-Allow-Credentials"), "true")
+ }
+ })
+ }
+}
+
+func TestBuildCorsPolicy(t *testing.T) {
+ t.Parallel()
+ assert.Nil(t, buildCorsPolicy(nil, nil))
+ assert.Nil(t,
buildCorsPolicy(convertCorsConfigForTest(&global.CorsConfig{}), nil))
+
+ p := buildCorsPolicy(convertCorsConfigForTest(&global.CorsConfig{
+ AllowOrigins: []string{"https://a.com"},
+ MaxAge: 123,
+ }), nil)
+ assert.NotNil(t, p)
+ assert.Equal(t, p.MaxAge, 123)
+}
+
+func TestServeHTTPCORS(t *testing.T) {
+ t.Parallel()
+ // Allowed origin
+ h1 := NewUnaryHandler("dummy", func() any { return nil }, func(ctx
context.Context, req *Request) (*Response, error) {
+ return &Response{Msg: struct{}{}, header: http.Header{},
trailer: http.Header{}}, nil
+ }, WithCORS(&CorsConfig{AllowOrigins: []string{"https://a.com"},
AllowCredentials: true}))
+
+ req1 := httptest.NewRequest(http.MethodPost, "/", nil)
+ req1.Header.Set("Origin", "https://a.com")
+ req1.Header.Set(tripleServiceGroup, "")
+ req1.Header.Set(tripleServiceVersion, "")
+ rr1 := httptest.NewRecorder()
+ h1.ServeHTTP(rr1, req1)
+ assert.Equal(t, rr1.Header().Get("Access-Control-Allow-Origin"),
"https://a.com")
+ assert.Equal(t, rr1.Header().Get("Access-Control-Allow-Credentials"),
"true")
+
+ // Forbidden origin
+ h2 := NewUnaryHandler("dummy", func() any { return nil }, func(ctx
context.Context, req *Request) (*Response, error) {
+ return &Response{Msg: struct{}{}}, nil
+ }, WithCORS(&CorsConfig{AllowOrigins: []string{"https://a.com"}}))
+
+ req2 := httptest.NewRequest(http.MethodPost, "/", nil)
+ req2.Header.Set("Origin", "https://b.com")
+ req2.Header.Set(tripleServiceGroup, "")
+ req2.Header.Set(tripleServiceVersion, "")
+ rr2 := httptest.NewRecorder()
+ h2.ServeHTTP(rr2, req2)
+ assert.Equal(t, rr2.Code, http.StatusForbidden)
+}
+
+func TestHandleCORS(t *testing.T) {
+ t.Parallel()
+ cors := buildCorsPolicy(convertCorsConfigForTest(&global.CorsConfig{
+ AllowOrigins: []string{"https://a.com"},
+ }), nil)
+
+ tests := []struct {
+ name string
+ cors *CorsConfig
+ method string
+ origin string
+ preflightHeader string
+ wantHandled bool
+ wantCode int
+ }{
+ {"nil config", nil, http.MethodGet, "https://a.com", "", false,
0},
+ {"preflight success", cors, http.MethodOptions,
"https://a.com", http.MethodPost, true, http.StatusNoContent},
+ {"forbidden origin", cors, http.MethodGet, "https://b.com", "",
true, http.StatusForbidden},
+ {"allowed origin", cors, http.MethodGet, "https://a.com", "",
false, 0},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ h := &Handler{cors: tt.cors}
+ req := httptest.NewRequest(tt.method, "/", nil)
+ if tt.origin != "" {
+ req.Header.Set("Origin", tt.origin)
+ }
+ if tt.preflightHeader != "" {
+ req.Header.Set("Access-Control-Request-Method",
tt.preflightHeader)
+ }
+ rr := httptest.NewRecorder()
+
+ handled := h.handleCORS(rr, req)
+ assert.Equal(t, handled, tt.wantHandled)
+ if tt.wantCode != 0 {
+ assert.Equal(t, rr.Code, tt.wantCode)
+ }
+ })
+ }
+}
diff --git a/protocol/triple/triple_protocol/handler.go
b/protocol/triple/triple_protocol/handler.go
index 256a7dcbb..ff12d9bb1 100644
--- a/protocol/triple/triple_protocol/handler.go
+++ b/protocol/triple/triple_protocol/handler.go
@@ -35,8 +35,9 @@ type Handler struct {
// key is group/version
implementations map[string]StreamingHandlerFunc
protocolHandlers []protocolHandler
- allowMethod string // Allow header
- acceptPost string // Accept-Post header
+ allowMethod string // Allow header
+ acceptPost string // Accept-Post header
+ cors *CorsConfig // CORS policy
}
// NewUnaryHandler constructs a [Handler] for a request-response procedure.
@@ -56,6 +57,7 @@ func NewUnaryHandler(
protocolHandlers: protocolHandlers,
allowMethod: sortedAllowMethodValue(protocolHandlers),
acceptPost: sortedAcceptPostValue(protocolHandlers),
+ cors: buildCorsPolicy(config.Cors,
protocolHandlers),
}
hdl.processImplementation(getIdentifier(config.Group, config.Version),
implementation)
return hdl
@@ -141,6 +143,7 @@ func NewClientStreamHandler(
protocolHandlers: protocolHandlers,
allowMethod: sortedAllowMethodValue(protocolHandlers),
acceptPost: sortedAcceptPostValue(protocolHandlers),
+ cors: buildCorsPolicy(config.Cors,
protocolHandlers),
}
hdl.processImplementation(getIdentifier(config.Group, config.Version),
implementation)
@@ -199,6 +202,7 @@ func NewServerStreamHandler(
protocolHandlers: protocolHandlers,
allowMethod: sortedAllowMethodValue(protocolHandlers),
acceptPost: sortedAcceptPostValue(protocolHandlers),
+ cors: buildCorsPolicy(config.Cors,
protocolHandlers),
}
hdl.processImplementation(getIdentifier(config.Group, config.Version),
implementation)
@@ -259,6 +263,7 @@ func NewBidiStreamHandler(
protocolHandlers: protocolHandlers,
allowMethod: sortedAllowMethodValue(protocolHandlers),
acceptPost: sortedAcceptPostValue(protocolHandlers),
+ cors: buildCorsPolicy(config.Cors,
protocolHandlers),
}
hdl.processImplementation(getIdentifier(config.Group, config.Version),
implementation)
@@ -309,6 +314,13 @@ func (h *Handler) ServeHTTP(responseWriter
http.ResponseWriter, request *http.Re
return
}
+ // CORS handling
+ if h.cors != nil {
+ if h.handleCORS(responseWriter, request) {
+ return
+ }
+ }
+
// inspect headers
var protocolHandlers []protocolHandler
for _, handler := range h.protocolHandlers {
@@ -393,6 +405,7 @@ type handlerConfig struct {
SendMaxBytes int
Group string
Version string
+ Cors *CorsConfig
}
func newHandlerConfig(procedure string, options []HandlerOption)
*handlerConfig {
@@ -462,3 +475,30 @@ func (c *handlerConfig) newProtocolHandlers(streamType
StreamType) []protocolHan
func getIdentifier(group, version string) string {
return group + "/" + version
}
+
+// handleCORS processes CORS requests. Returns true if the request was handled
and processing should stop.
+func (h *Handler) handleCORS(w http.ResponseWriter, r *http.Request) bool {
+ if h.cors == nil {
+ return false
+ }
+
+ // Handle preflight requests
+ if r.Method == http.MethodOptions && r.Header.Get(corsRequestMethod) !=
"" {
+ return h.cors.handlePreflight(w, r)
+ }
+
+ // Handle normal requests with Origin header
+ origin := r.Header.Get(corsOrigin)
+ if origin == "" {
+ return false
+ }
+
+ if !h.cors.matchOrigin(origin) {
+ w.Header().Add(corsVary, corsOrigin)
+ w.WriteHeader(http.StatusForbidden)
+ return true
+ }
+
+ h.cors.addCORSHeaders(w, r)
+ return false
+}
diff --git a/protocol/triple/triple_protocol/option.go
b/protocol/triple/triple_protocol/option.go
index 13830e0e5..197df4d95 100644
--- a/protocol/triple/triple_protocol/option.go
+++ b/protocol/triple/triple_protocol/option.go
@@ -178,6 +178,11 @@ func WithRequireTripleProtocolHeader() HandlerOption {
return &requireTripleProtocolHeaderOption{}
}
+// WithCORS configures CORS for the handler.
+func WithCORS(cors *CorsConfig) HandlerOption {
+ return &corsOption{cors: cors}
+}
+
func WithGroup(group string) Option {
return &groupOption{group}
}
@@ -443,6 +448,17 @@ func (o *requireTripleProtocolHeaderOption)
applyToHandler(config *handlerConfig
config.RequireTripleProtocolHeader = true
}
+type corsOption struct {
+ cors *CorsConfig
+}
+
+func (o *corsOption) applyToHandler(config *handlerConfig) {
+ if o.cors == nil {
+ return
+ }
+ config.Cors = o.cors
+}
+
type groupOption struct {
Group string
}