This is an automated email from the ASF dual-hosted git repository.
alexstocks pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/dubbo-go.git
The following commit(s) were added to refs/heads/develop by this push:
new 82957c516 fix(core): resolve #3247 lock misuse and map exposure races (#3265)
82957c516 is described below
commit 82957c516cba9a391718921eb3b38794a1015909
Author: xxs <[email protected]>
AuthorDate: Sat Mar 21 12:24:53 2026 +0800
fix(core): resolve #3247 lock misuse and map exposure races (#3265)
* fix(dubbo): use the correct lock for provider services and lock consumer connection access
Fixes: #3247
* fix(router): guard RouterChain.Route invoker reads with RLock
Fixes: #3247
* fix(config): return copied service maps instead of internal references
Fixes: #3247
* test(core): add non-duplicate coverage for #3247 race fixes
* test(router): resolve Sonar findings in chain tests
---
cluster/router/chain/chain.go | 10 ++-
cluster/router/chain/chain_test.go | 138 +++++++++++++++++++++++++++++++++++++
config/service.go | 18 ++++-
config/service_test.go | 73 ++++++++++++++++++++
dubbo.go | 6 +-
dubbo_test.go | 70 +++++++++++++++++++
6 files changed, 308 insertions(+), 7 deletions(-)
diff --git a/cluster/router/chain/chain.go b/cluster/router/chain/chain.go
index 3159f9c39..d849c5804 100644
--- a/cluster/router/chain/chain.go
+++ b/cluster/router/chain/chain.go
@@ -53,17 +53,21 @@ type RouterChain struct {
// Route Loop routers in RouterChain and call Route method to determine the
target invokers list.
func (c *RouterChain) Route(url *common.URL, invocation base.Invocation)
[]base.Invoker {
- finalInvokers := make([]base.Invoker, 0, len(c.invokers))
+ c.mutex.RLock()
+ invokers := c.invokers
+ c.mutex.RUnlock()
+
+ finalInvokers := make([]base.Invoker, 0, len(invokers))
// multiple invoker may include different methods, find correct invoker
otherwise
// will return the invoker without methods
- for _, invoker := range c.invokers {
+ for _, invoker := range invokers {
if invoker.GetURL().ServiceKey() == url.ServiceKey() {
finalInvokers = append(finalInvokers, invoker)
}
}
if len(finalInvokers) == 0 {
- finalInvokers = c.invokers
+ finalInvokers = invokers
}
for _, r := range c.copyRouters() {
diff --git a/cluster/router/chain/chain_test.go b/cluster/router/chain/chain_test.go
new file mode 100644
index 000000000..afdc39420
--- /dev/null
+++ b/cluster/router/chain/chain_test.go
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package chain
+
+import (
+ "testing"
+)
+
+import (
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+import (
+ "dubbo.apache.org/dubbo-go/v3/cluster/router"
+ "dubbo.apache.org/dubbo-go/v3/common"
+ "dubbo.apache.org/dubbo-go/v3/protocol/base"
+ "dubbo.apache.org/dubbo-go/v3/protocol/invocation"
+)
+
+// testConsumerServiceURL is the consumer URL shared by the Route tests; invokers
+// whose service key matches it (com.demo.Service) are treated as matches.
+const testConsumerServiceURL = "consumer://127.0.0.1/com.demo.Service"
+
+// testPriorityRouter is a configurable router.PriorityRouter stub that records
+// how often Route is called and how many invokers it last received.
+type testPriorityRouter struct {
+	priority int64
+	called   int // number of Route calls so far
+	lastSize int // len(invokers) seen by the most recent Route call
+
+	// Optional hooks: when nil, Notify does nothing and Route passes its input through.
+	notifyFn func([]base.Invoker)
+	routeFn  func([]base.Invoker, *common.URL, base.Invocation) []base.Invoker
+}
+
+// Route records the call and the size of its input, then delegates to routeFn
+// when one is configured; otherwise it returns invokers unchanged.
+func (r *testPriorityRouter) Route(invokers []base.Invoker, url *common.URL, inv base.Invocation) []base.Invoker {
+	r.called++
+	r.lastSize = len(invokers)
+	if r.routeFn != nil {
+		return r.routeFn(invokers, url, inv)
+	}
+	return invokers
+}
+
+// URL returns nil: the stub is not backed by any router URL.
+func (r *testPriorityRouter) URL() *common.URL {
+	return nil
+}
+
+// Priority reports the priority the stub was constructed with.
+func (r *testPriorityRouter) Priority() int64 {
+	return r.priority
+}
+
+// Notify forwards the invoker list to notifyFn, if one is configured.
+func (r *testPriorityRouter) Notify(invokers []base.Invoker) {
+	if fn := r.notifyFn; fn != nil {
+		fn(invokers)
+	}
+}
+
+// buildInvoker parses rawURL and wraps it in a base invoker, failing the
+// calling test immediately if the URL cannot be parsed.
+func buildInvoker(t *testing.T, rawURL string) base.Invoker {
+	t.Helper() // report parse failures at the caller's line, not here
+	u, err := common.NewURL(rawURL)
+	require.NoError(t, err)
+	return base.NewBaseInvoker(u)
+}
+
+// TestRouteUsesServiceKeyMatchWhenAvailable checks that only invokers whose
+// service key matches the consumer URL are handed to the routers.
+func TestRouteUsesServiceKeyMatchWhenAvailable(t *testing.T) {
+	consumerURL, err := common.NewURL(testConsumerServiceURL)
+	require.NoError(t, err)
+
+	match := buildInvoker(t, "dubbo://127.0.0.1:20000/com.demo.Service")
+	nonMatch := buildInvoker(t, "dubbo://127.0.0.1:20001/com.other.Service")
+
+	r := &testPriorityRouter{priority: 1}
+	chain := &RouterChain{
+		invokers: []base.Invoker{match, nonMatch},
+		routers:  []router.PriorityRouter{r},
+	}
+
+	result := chain.Route(consumerURL, invocation.NewRPCInvocation("Say", nil, nil))
+	// require, not assert: result[0] below would panic on an empty result.
+	require.Len(t, result, 1)
+	assert.Equal(t, match.GetURL().String(), result[0].GetURL().String())
+	assert.Equal(t, 1, r.called)
+	assert.Equal(t, 1, r.lastSize)
+}
+
+// TestRouteFallsBackToAllInvokersWhenNoMatch checks that when no invoker matches
+// the consumer's service key, the full invoker list is passed to the routers.
+func TestRouteFallsBackToAllInvokersWhenNoMatch(t *testing.T) {
+	consumerURL, err := common.NewURL(testConsumerServiceURL)
+	require.NoError(t, err)
+
+	invokerA := buildInvoker(t, "dubbo://127.0.0.1:20000/com.foo.Service")
+	invokerB := buildInvoker(t, "dubbo://127.0.0.1:20001/com.bar.Service")
+
+	r := &testPriorityRouter{priority: 1}
+	chain := &RouterChain{
+		invokers: []base.Invoker{invokerA, invokerB},
+		routers:  []router.PriorityRouter{r},
+	}
+
+	result := chain.Route(consumerURL, invocation.NewRPCInvocation("Say", nil, nil))
+	assert.Len(t, result, 2)
+	assert.Equal(t, 1, r.called)
+	assert.Equal(t, 2, r.lastSize)
+}
+
+// TestRouteAppliesRoutersOnSnapshot checks that routers are chained: r1 trims
+// the invoker list to its first element, and r2 must receive exactly that
+// single invoker and return it as the final result.
+func TestRouteAppliesRoutersOnSnapshot(t *testing.T) {
+	consumerURL, err := common.NewURL(testConsumerServiceURL)
+	require.NoError(t, err)
+
+	invokerA := buildInvoker(t, "dubbo://127.0.0.1:20000/com.demo.Service")
+	invokerB := buildInvoker(t, "dubbo://127.0.0.1:20001/com.demo.Service")
+
+	r1 := &testPriorityRouter{priority: 1, routeFn: func(invokers []base.Invoker, _ *common.URL, _ base.Invocation) []base.Invoker {
+		return invokers[:1]
+	}}
+	r2 := &testPriorityRouter{priority: 2}
+
+	chain := &RouterChain{
+		invokers: []base.Invoker{invokerA, invokerB},
+		routers:  []router.PriorityRouter{r1, r2},
+	}
+
+	result := chain.Route(consumerURL, invocation.NewRPCInvocation("Say", nil, nil))
+	// require, not assert: result[0] below would panic on an empty result.
+	require.Len(t, result, 1)
+	assert.Equal(t, invokerA.GetURL().String(), result[0].GetURL().String())
+	assert.Equal(t, 1, r1.called)
+	assert.Equal(t, 1, r2.called)
+	assert.Equal(t, 1, r2.lastSize)
+}
diff --git a/config/service.go b/config/service.go
index 2a241fa7b..cdcfbc52a 100644
--- a/config/service.go
+++ b/config/service.go
@@ -89,7 +89,14 @@ func GetProviderService(name string) common.RPCService {
// GetProviderServiceMap gets ProviderServiceMap
+// The result is a new map filled while holding proServicesLock, so callers can
+// iterate or modify it without touching the shared proServices map. The copy is
+// shallow: the RPCService values themselves are not copied.
func GetProviderServiceMap() map[string]common.RPCService {
- return proServices
+ proServicesLock.Lock()
+ defer proServicesLock.Unlock()
+
+ // Copy rather than return proServices so no caller holds the shared map.
+ m := make(map[string]common.RPCService, len(proServices))
+ for k, v := range proServices {
+ m[k] = v
+ }
+ return m
}
func GetProviderServiceInfo(name string) any {
@@ -100,7 +107,14 @@ func GetProviderServiceInfo(name string) any {
-// GetConsumerServiceMap gets ProviderServiceMap
+// GetConsumerServiceMap returns a copy of the consumer service map (conServices),
+// built while holding conServicesLock; the RPCService values themselves are not copied.
func GetConsumerServiceMap() map[string]common.RPCService {
- return conServices
+ conServicesLock.Lock()
+ defer conServicesLock.Unlock()
+
+ // Copy rather than return conServices so no caller holds the shared map.
+ m := make(map[string]common.RPCService, len(conServices))
+ for k, v := range conServices {
+ m[k] = v
+ }
+ return m
}
// SetConsumerServiceByInterfaceName is used by pb serialization
diff --git a/config/service_test.go b/config/service_test.go
index be94b05c7..179ffa2f6 100644
--- a/config/service_test.go
+++ b/config/service_test.go
@@ -18,13 +18,23 @@
package config
import (
+ "maps"
"testing"
)
import (
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
+import (
+ "dubbo.apache.org/dubbo-go/v3/common"
+)
+
+// cloneRPCServiceMap returns a shallow copy of src (nil stays nil) so a test can
+// restore a package-level service map after swapping it out.
+func cloneRPCServiceMap(src map[string]common.RPCService) map[string]common.RPCService {
+	return maps.Clone(src)
+}
+
func TestGetConsumerService(t *testing.T) {
SetConsumerService(&HelloService{})
@@ -43,3 +53,66 @@ func TestGetConsumerService(t *testing.T) {
callback := GetCallback(reference)
assert.Nil(t, callback)
}
+
+// TestGetProviderServiceMapReturnsCopy checks that mutating the map returned by
+// GetProviderServiceMap (adding a key, overwriting an existing one) leaves the
+// package-level proServices map untouched.
+func TestGetProviderServiceMapReturnsCopy(t *testing.T) {
+	// Swap in empty maps so the test is independent of services registered
+	// elsewhere; the originals are restored on exit.
+	proServicesLock.Lock()
+	originalProServices := cloneRPCServiceMap(proServices)
+	originalProServicesInfo := maps.Clone(proServicesInfo)
+	proServices = map[string]common.RPCService{}
+	proServicesInfo = map[string]any{}
+	proServicesLock.Unlock()
+
+	defer func() {
+		proServicesLock.Lock()
+		proServices = originalProServices
+		proServicesInfo = originalProServicesInfo
+		proServicesLock.Unlock()
+	}()
+
+	svc := &HelloService{}
+	SetProviderService(svc)
+
+	got := GetProviderServiceMap()
+	require.Len(t, got, 1)
+
+	got["Injected"] = &HelloService{}
+	got["HelloService"] = &HelloService{}
+
+	proServicesLock.Lock()
+	_, hasInjected := proServices["Injected"]
+	stored := proServices["HelloService"]
+	proServicesLock.Unlock()
+
+	assert.False(t, hasInjected)
+	// Same checks pointer identity. Both svc and the overwrite are zero-valued
+	// &HelloService{}, so a key-presence or Equal check cannot detect the overwrite.
+	assert.Same(t, svc, stored)
+}
+
+// TestGetConsumerServiceMapReturnsCopy checks that mutating the map returned by
+// GetConsumerServiceMap (adding a key, overwriting an existing one) leaves the
+// package-level conServices map untouched.
+func TestGetConsumerServiceMapReturnsCopy(t *testing.T) {
+	// Swap in an empty map for isolation; the original is restored on exit.
+	conServicesLock.Lock()
+	originalConServices := cloneRPCServiceMap(conServices)
+	conServices = map[string]common.RPCService{}
+	conServicesLock.Unlock()
+
+	defer func() {
+		conServicesLock.Lock()
+		conServices = originalConServices
+		conServicesLock.Unlock()
+	}()
+
+	svc := &HelloService{}
+	SetConsumerService(svc)
+
+	got := GetConsumerServiceMap()
+	require.Len(t, got, 1)
+
+	got["Injected"] = &HelloService{}
+	got["HelloService"] = &HelloService{}
+
+	conServicesLock.Lock()
+	_, hasInjected := conServices["Injected"]
+	stored := conServices["HelloService"]
+	conServicesLock.Unlock()
+
+	assert.False(t, hasInjected)
+	// Same checks pointer identity. Equal compares deeply and would also accept
+	// the zero-valued &HelloService{} written into the copy above.
+	assert.Same(t, svc, stored)
+}
diff --git a/dubbo.go b/dubbo.go
index 2a132d3cb..125af0385 100644
--- a/dubbo.go
+++ b/dubbo.go
@@ -298,13 +298,15 @@ func SetConsumerService(svc common.RPCService) {
}
+// SetProviderService registers svc as a provider handler in providerServices,
+// keyed by common.GetReference(svc). It takes proLock, the lock for
+// providerServices (conLock is the consumer-side lock).
func SetProviderService(svc common.RPCService) {
- conLock.Lock()
- defer conLock.Unlock()
+ proLock.Lock()
+ defer proLock.Unlock()
providerServices[common.GetReference(svc)] = &server.ServiceDefinition{
Handler: svc,
}
}
+// GetConsumerConnection returns the connection of the consumer service
+// registered under interfaceName, reading consumerServices under conLock.
+// NOTE(review): an unregistered interfaceName yields a nil *ClientDefinition
+// here; whether GetConnection tolerates a nil receiver is not visible in this
+// patch. Consider an explicit lookup that returns a "not found" error.
func GetConsumerConnection(interfaceName string) (*client.Connection, error) {
+ conLock.RLock()
+ defer conLock.RUnlock()
return consumerServices[interfaceName].GetConnection()
}
diff --git a/dubbo_test.go b/dubbo_test.go
index e23c97394..01d3bc5d4 100644
--- a/dubbo_test.go
+++ b/dubbo_test.go
@@ -18,11 +18,13 @@
package dubbo
import (
+ "maps"
"testing"
)
import (
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
import (
@@ -32,6 +34,22 @@ import (
"dubbo.apache.org/dubbo-go/v3/server"
)
+// testRPCService is a minimal RPC service whose Reference key is chosen by the
+// test, so the registry key it lands under is known in advance.
+type testRPCService struct {
+	ref string
+}
+
+// Reference returns the key supplied at construction.
+func (s *testRPCService) Reference() string {
+	return s.ref
+}
+
+// cloneClientDefinitions returns a shallow copy of src (nil stays nil) so a test
+// can restore consumerServices after swapping it out.
+func cloneClientDefinitions(src map[string]*client.ClientDefinition) map[string]*client.ClientDefinition {
+	return maps.Clone(src)
+}
+
+// cloneServiceDefinitions returns a shallow copy of src (nil stays nil) so a test
+// can restore providerServices after swapping it out.
+func cloneServiceDefinitions(src map[string]*server.ServiceDefinition) map[string]*server.ServiceDefinition {
+	return maps.Clone(src)
+}
+
// TestIndependentConfig tests the configurations of the `instance`, `client`,
and `server` are independent.
func TestIndependentConfig(t *testing.T) {
// instance configuration
@@ -96,3 +114,55 @@ func TestIndependentConfig(t *testing.T) {
panic(err)
}
}
+
+// TestSetProviderServiceRegistersByReference checks that SetProviderService
+// stores the very same handler under its Reference() key in providerServices.
+func TestSetProviderServiceRegistersByReference(t *testing.T) {
+	proLock.Lock()
+	original := cloneServiceDefinitions(providerServices)
+	providerServices = make(map[string]*server.ServiceDefinition)
+	proLock.Unlock()
+
+	defer func() {
+		proLock.Lock()
+		providerServices = original
+		proLock.Unlock()
+	}()
+
+	svc := &testRPCService{ref: "provider.test.Service"}
+	SetProviderService(svc)
+
+	// Read under the lock, then release it before asserting.
+	proLock.RLock()
+	def, ok := providerServices[svc.Reference()]
+	proLock.RUnlock()
+
+	require.True(t, ok)
+	require.NotNil(t, def)
+	// Same checks pointer identity rather than deep equality.
+	assert.Same(t, svc, def.Handler)
+}
+
+// TestGetConsumerConnectionFromConsumerServices checks that GetConsumerConnection
+// reports an error until a connection is attached to the registered definition,
+// and afterwards returns exactly that connection.
+func TestGetConsumerConnectionFromConsumerServices(t *testing.T) {
+	conLock.Lock()
+	original := cloneClientDefinitions(consumerServices)
+	consumerServices = make(map[string]*client.ClientDefinition)
+	conLock.Unlock()
+
+	defer func() {
+		conLock.Lock()
+		consumerServices = original
+		conLock.Unlock()
+	}()
+
+	svc := &testRPCService{ref: "consumer.test.Service"}
+	SetConsumerService(svc)
+
+	conn, err := GetConsumerConnection(svc.Reference())
+	require.Error(t, err)
+	require.Nil(t, conn)
+
+	// Look the definition up first. Calling SetConnection on a nil entry while
+	// holding conLock would panic with the lock held and deadlock the restoring defer.
+	conLock.Lock()
+	def := consumerServices[svc.Reference()]
+	conLock.Unlock()
+	require.NotNil(t, def)
+
+	expectedConn := &client.Connection{}
+	conLock.Lock()
+	def.SetConnection(expectedConn)
+	conLock.Unlock()
+
+	conn, err = GetConsumerConnection(svc.Reference())
+	require.NoError(t, err)
+	// Same checks pointer identity. Equal would accept any zero-valued Connection.
+	assert.Same(t, expectedConn, conn)
+}