This is an automated email from the ASF dual-hosted git repository.

alexstocks pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/dubbo-go.git


The following commit(s) were added to refs/heads/develop by this push:
     new 82957c516 fix(core): resolve #3247 lock misuse and map exposure races (#3265)
82957c516 is described below

commit 82957c516cba9a391718921eb3b38794a1015909
Author: xxs <[email protected]>
AuthorDate: Sat Mar 21 12:24:53 2026 +0800

    fix(core): resolve #3247 lock misuse and map exposure races (#3265)
    
    * fix(dubbo): use correct lock for provider service and lock consumer connection access
    
    Fixes: #3247
    
    * fix(router): guard RouterChain.Route invoker reads with RLock
    
    Fixes: #3247
    
    * fix(config): return copied service maps instead of internal references
    
    Fixes: #3247
    
    * test(core): add non-duplicate coverage for #3247 race fixes
    
    * test(router): resolve Sonar findings in chain tests
---
 cluster/router/chain/chain.go      |  10 ++-
 cluster/router/chain/chain_test.go | 138 +++++++++++++++++++++++++++++++++++++
 config/service.go                  |  18 ++++-
 config/service_test.go             |  73 ++++++++++++++++++++
 dubbo.go                           |   6 +-
 dubbo_test.go                      |  70 +++++++++++++++++++
 6 files changed, 308 insertions(+), 7 deletions(-)

diff --git a/cluster/router/chain/chain.go b/cluster/router/chain/chain.go
index 3159f9c39..d849c5804 100644
--- a/cluster/router/chain/chain.go
+++ b/cluster/router/chain/chain.go
@@ -53,17 +53,21 @@ type RouterChain struct {
 
 // Route Loop routers in RouterChain and call Route method to determine the target invokers list.
 func (c *RouterChain) Route(url *common.URL, invocation base.Invocation) []base.Invoker {
-       finalInvokers := make([]base.Invoker, 0, len(c.invokers))
+       c.mutex.RLock()
+       invokers := c.invokers
+       c.mutex.RUnlock()
+
+       finalInvokers := make([]base.Invoker, 0, len(invokers))
        // multiple invoker may include different methods, find correct invoker otherwise
        // will return the invoker without methods
-       for _, invoker := range c.invokers {
+       for _, invoker := range invokers {
                if invoker.GetURL().ServiceKey() == url.ServiceKey() {
                        finalInvokers = append(finalInvokers, invoker)
                }
        }
 
        if len(finalInvokers) == 0 {
-               finalInvokers = c.invokers
+               finalInvokers = invokers
        }
 
        for _, r := range c.copyRouters() {
diff --git a/cluster/router/chain/chain_test.go b/cluster/router/chain/chain_test.go
new file mode 100644
index 000000000..afdc39420
--- /dev/null
+++ b/cluster/router/chain/chain_test.go
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package chain
+
+import (
+       "testing"
+)
+
+import (
+       "github.com/stretchr/testify/assert"
+       "github.com/stretchr/testify/require"
+)
+
+import (
+       "dubbo.apache.org/dubbo-go/v3/cluster/router"
+       "dubbo.apache.org/dubbo-go/v3/common"
+       "dubbo.apache.org/dubbo-go/v3/protocol/base"
+       "dubbo.apache.org/dubbo-go/v3/protocol/invocation"
+)
+
+const testConsumerServiceURL = "consumer://127.0.0.1/com.demo.Service"
+
+type testPriorityRouter struct {
+       priority int64
+       called   int
+       lastSize int
+
+       notifyFn func([]base.Invoker)
+       routeFn  func([]base.Invoker, *common.URL, base.Invocation) []base.Invoker
+}
+
+func (r *testPriorityRouter) Route(invokers []base.Invoker, url *common.URL, inv base.Invocation) []base.Invoker {
+       r.called++
+       r.lastSize = len(invokers)
+       if r.routeFn != nil {
+               return r.routeFn(invokers, url, inv)
+       }
+       return invokers
+}
+
+func (r *testPriorityRouter) URL() *common.URL {
+       return nil
+}
+
+func (r *testPriorityRouter) Priority() int64 {
+       return r.priority
+}
+
+func (r *testPriorityRouter) Notify(invokers []base.Invoker) {
+       if r.notifyFn != nil {
+               r.notifyFn(invokers)
+       }
+}
+
+func buildInvoker(t *testing.T, rawURL string) base.Invoker {
+       u, err := common.NewURL(rawURL)
+       require.NoError(t, err)
+       return base.NewBaseInvoker(u)
+}
+
+func TestRouteUsesServiceKeyMatchWhenAvailable(t *testing.T) {
+       consumerURL, err := common.NewURL(testConsumerServiceURL)
+       require.NoError(t, err)
+
+       match := buildInvoker(t, "dubbo://127.0.0.1:20000/com.demo.Service")
+       nonMatch := buildInvoker(t, "dubbo://127.0.0.1:20001/com.other.Service")
+
+       r := &testPriorityRouter{priority: 1}
+       chain := &RouterChain{
+               invokers: []base.Invoker{match, nonMatch},
+               routers:  []router.PriorityRouter{r},
+       }
+
+       result := chain.Route(consumerURL, invocation.NewRPCInvocation("Say", nil, nil))
+       assert.Len(t, result, 1)
+       assert.Equal(t, match.GetURL().String(), result[0].GetURL().String())
+       assert.Equal(t, 1, r.called)
+       assert.Equal(t, 1, r.lastSize)
+}
+
+func TestRouteFallsBackToAllInvokersWhenNoMatch(t *testing.T) {
+       consumerURL, err := common.NewURL(testConsumerServiceURL)
+       require.NoError(t, err)
+
+       invokerA := buildInvoker(t, "dubbo://127.0.0.1:20000/com.foo.Service")
+       invokerB := buildInvoker(t, "dubbo://127.0.0.1:20001/com.bar.Service")
+
+       r := &testPriorityRouter{priority: 1}
+       chain := &RouterChain{
+               invokers: []base.Invoker{invokerA, invokerB},
+               routers:  []router.PriorityRouter{r},
+       }
+
+       result := chain.Route(consumerURL, invocation.NewRPCInvocation("Say", nil, nil))
+       assert.Len(t, result, 2)
+       assert.Equal(t, 1, r.called)
+       assert.Equal(t, 2, r.lastSize)
+}
+
+func TestRouteAppliesRoutersOnSnapshot(t *testing.T) {
+       consumerURL, err := common.NewURL(testConsumerServiceURL)
+       require.NoError(t, err)
+
+       invokerA := buildInvoker(t, "dubbo://127.0.0.1:20000/com.demo.Service")
+       invokerB := buildInvoker(t, "dubbo://127.0.0.1:20001/com.demo.Service")
+
+       r1 := &testPriorityRouter{priority: 1, routeFn: func(invokers []base.Invoker, _ *common.URL, _ base.Invocation) []base.Invoker {
+               return invokers[:1]
+       }}
+       r2 := &testPriorityRouter{priority: 2}
+
+       chain := &RouterChain{
+               invokers: []base.Invoker{invokerA, invokerB},
+               routers:  []router.PriorityRouter{r1, r2},
+       }
+
+       result := chain.Route(consumerURL, invocation.NewRPCInvocation("Say", nil, nil))
+       assert.Len(t, result, 1)
+       assert.Equal(t, invokerA.GetURL().String(), result[0].GetURL().String())
+       assert.Equal(t, 1, r1.called)
+       assert.Equal(t, 1, r2.called)
+       assert.Equal(t, 1, r2.lastSize)
+}
diff --git a/config/service.go b/config/service.go
index 2a241fa7b..cdcfbc52a 100644
--- a/config/service.go
+++ b/config/service.go
@@ -89,7 +89,14 @@ func GetProviderService(name string) common.RPCService {
 
 // GetProviderServiceMap gets ProviderServiceMap
 func GetProviderServiceMap() map[string]common.RPCService {
-       return proServices
+       proServicesLock.Lock()
+       defer proServicesLock.Unlock()
+
+       m := make(map[string]common.RPCService, len(proServices))
+       for k, v := range proServices {
+               m[k] = v
+       }
+       return m
 }
 
 func GetProviderServiceInfo(name string) any {
@@ -100,7 +107,14 @@ func GetProviderServiceInfo(name string) any {
 
 // GetConsumerServiceMap gets ProviderServiceMap
 func GetConsumerServiceMap() map[string]common.RPCService {
-       return conServices
+       conServicesLock.Lock()
+       defer conServicesLock.Unlock()
+
+       m := make(map[string]common.RPCService, len(conServices))
+       for k, v := range conServices {
+               m[k] = v
+       }
+       return m
 }
 
 // SetConsumerServiceByInterfaceName is used by pb serialization
diff --git a/config/service_test.go b/config/service_test.go
index be94b05c7..179ffa2f6 100644
--- a/config/service_test.go
+++ b/config/service_test.go
@@ -18,13 +18,23 @@
 package config
 
 import (
+       "maps"
        "testing"
 )
 
 import (
        "github.com/stretchr/testify/assert"
+       "github.com/stretchr/testify/require"
 )
 
+import (
+       "dubbo.apache.org/dubbo-go/v3/common"
+)
+
+func cloneRPCServiceMap(src map[string]common.RPCService) map[string]common.RPCService {
+       return maps.Clone(src)
+}
+
 func TestGetConsumerService(t *testing.T) {
 
        SetConsumerService(&HelloService{})
@@ -43,3 +53,66 @@ func TestGetConsumerService(t *testing.T) {
        callback := GetCallback(reference)
        assert.Nil(t, callback)
 }
+
+func TestGetProviderServiceMapReturnsCopy(t *testing.T) {
+       proServicesLock.Lock()
+       originalProServices := cloneRPCServiceMap(proServices)
+       originalProServicesInfo := maps.Clone(proServicesInfo)
+       proServices = map[string]common.RPCService{}
+       proServicesInfo = map[string]any{}
+       proServicesLock.Unlock()
+
+       defer func() {
+               proServicesLock.Lock()
+               proServices = originalProServices
+               proServicesInfo = originalProServicesInfo
+               proServicesLock.Unlock()
+       }()
+
+       svc := &HelloService{}
+       SetProviderService(svc)
+
+       got := GetProviderServiceMap()
+       require.Len(t, got, 1)
+
+       got["Injected"] = &HelloService{}
+       got["HelloService"] = &HelloService{}
+
+       proServicesLock.Lock()
+       _, hasInjected := proServices["Injected"]
+       _, hasHelloService := proServices["HelloService"]
+       proServicesLock.Unlock()
+
+       assert.False(t, hasInjected)
+       assert.True(t, hasHelloService)
+}
+
+func TestGetConsumerServiceMapReturnsCopy(t *testing.T) {
+       conServicesLock.Lock()
+       originalConServices := cloneRPCServiceMap(conServices)
+       conServices = map[string]common.RPCService{}
+       conServicesLock.Unlock()
+
+       defer func() {
+               conServicesLock.Lock()
+               conServices = originalConServices
+               conServicesLock.Unlock()
+       }()
+
+       svc := &HelloService{}
+       SetConsumerService(svc)
+
+       got := GetConsumerServiceMap()
+       require.Len(t, got, 1)
+
+       got["Injected"] = &HelloService{}
+       got["HelloService"] = &HelloService{}
+
+       conServicesLock.Lock()
+       _, hasInjected := conServices["Injected"]
+       stored := conServices["HelloService"]
+       conServicesLock.Unlock()
+
+       assert.False(t, hasInjected)
+       assert.Equal(t, svc, stored)
+}
diff --git a/dubbo.go b/dubbo.go
index 2a132d3cb..125af0385 100644
--- a/dubbo.go
+++ b/dubbo.go
@@ -298,13 +298,15 @@ func SetConsumerService(svc common.RPCService) {
 }
 
 func SetProviderService(svc common.RPCService) {
-       conLock.Lock()
-       defer conLock.Unlock()
+       proLock.Lock()
+       defer proLock.Unlock()
        providerServices[common.GetReference(svc)] = &server.ServiceDefinition{
                Handler: svc,
        }
 }
 
 func GetConsumerConnection(interfaceName string) (*client.Connection, error) {
+       conLock.RLock()
+       defer conLock.RUnlock()
        return consumerServices[interfaceName].GetConnection()
 }
diff --git a/dubbo_test.go b/dubbo_test.go
index e23c97394..01d3bc5d4 100644
--- a/dubbo_test.go
+++ b/dubbo_test.go
@@ -18,11 +18,13 @@
 package dubbo
 
 import (
+       "maps"
        "testing"
 )
 
 import (
        "github.com/stretchr/testify/assert"
+       "github.com/stretchr/testify/require"
 )
 
 import (
@@ -32,6 +34,22 @@ import (
        "dubbo.apache.org/dubbo-go/v3/server"
 )
 
+type testRPCService struct {
+       ref string
+}
+
+func (s *testRPCService) Reference() string {
+       return s.ref
+}
+
+func cloneClientDefinitions(src map[string]*client.ClientDefinition) map[string]*client.ClientDefinition {
+       return maps.Clone(src)
+}
+
+func cloneServiceDefinitions(src map[string]*server.ServiceDefinition) map[string]*server.ServiceDefinition {
+       return maps.Clone(src)
+}
+
 // TestIndependentConfig tests the configurations of the `instance`, `client`, and `server` are independent.
 func TestIndependentConfig(t *testing.T) {
        // instance configuration
@@ -96,3 +114,55 @@ func TestIndependentConfig(t *testing.T) {
                panic(err)
        }
 }
+
+func TestSetProviderServiceRegistersByReference(t *testing.T) {
+       proLock.Lock()
+       original := cloneServiceDefinitions(providerServices)
+       providerServices = make(map[string]*server.ServiceDefinition)
+       proLock.Unlock()
+
+       defer func() {
+               proLock.Lock()
+               providerServices = original
+               proLock.Unlock()
+       }()
+
+       svc := &testRPCService{ref: "provider.test.Service"}
+       SetProviderService(svc)
+
+       proLock.RLock()
+       defer proLock.RUnlock()
+       def, ok := providerServices[svc.Reference()]
+       require.True(t, ok)
+       require.NotNil(t, def)
+       assert.Equal(t, svc, def.Handler)
+}
+
+func TestGetConsumerConnectionFromConsumerServices(t *testing.T) {
+       conLock.Lock()
+       original := cloneClientDefinitions(consumerServices)
+       consumerServices = make(map[string]*client.ClientDefinition)
+       conLock.Unlock()
+
+       defer func() {
+               conLock.Lock()
+               consumerServices = original
+               conLock.Unlock()
+       }()
+
+       svc := &testRPCService{ref: "consumer.test.Service"}
+       SetConsumerService(svc)
+
+       conn, err := GetConsumerConnection(svc.Reference())
+       require.Error(t, err)
+       require.Nil(t, conn)
+
+       expectedConn := &client.Connection{}
+       conLock.Lock()
+       consumerServices[svc.Reference()].SetConnection(expectedConn)
+       conLock.Unlock()
+
+       conn, err = GetConsumerConnection(svc.Reference())
+       require.NoError(t, err)
+       assert.Equal(t, expectedConn, conn)
+}

Reply via email to