This is an automated email from the ASF dual-hosted git repository.
xuetaoli pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/dubbo-go.git
The following commit(s) were added to refs/heads/develop by this push:
new 334569647 fix(extension): eliminate race-prone global registries in common/extension (#3264)
334569647 is described below
commit 334569647b57f789d75e44043ba2ab7fc4715cad
Author: xxs <[email protected]>
AuthorDate: Sat Mar 21 16:35:50 2026 +0800
fix(extension): eliminate race-prone global registries in common/extension (#3264)
* feat(extension): add thread-safe generic Registry[T] container
Fixes: #3247
* fix(extension): migrate filter/protocol/registry to Registry[T]
Fixes: #3247
* fix(extension): migrate cluster/loadbalance/config_center registries
Fixes: #3247
* fix(extension): migrate configurator/config/config_center_factory
Fixes: #3247
* fix(extension): migrate metadata/proxy/tps registries
Fixes: #3247
* fix(extension): migrate auth/rest client and server registries
Fixes: #3247
* fix(extension): migrate config reader/post processor/router factory
Fixes: #3247
* fix(extension): migrate discovery selector logger and otel registries
Fixes: #3247
* fix(extension): guard directory/customizer/shutdown/name mapping state
Fixes: #3247
* style(extension): apply repository formatter output for registry synchronization files
Fixes: #3247
* test(extension): use assert.Len to satisfy testifylint
Fixes: #3247
* test(extension): add coverage for synchronized registry helper paths
Fixes: #3247
* fix(extension): clarify config center factory registry name
Align registry name with config center factory semantics for clearer
diagnostics in extension registration lookup.
Fixes: #3247
---
common/extension/auth.go | 17 +-
common/extension/cluster.go | 9 +-
common/extension/concurrency_guards_test.go | 229 +++++++++++++++++++++
common/extension/config.go | 4 +-
common/extension/config_center.go | 8 +-
common/extension/config_center_factory.go | 9 +-
common/extension/config_post_processor.go | 12 +-
common/extension/config_reader.go | 15 +-
common/extension/configurator.go | 21 +-
common/extension/filter.go | 25 +--
common/extension/graceful_shutdown.go | 14 +-
common/extension/loadbalance.go | 18 +-
common/extension/logger.go | 8 +-
common/extension/metadata_report_factory.go | 9 +-
common/extension/otel_trace.go | 8 +-
common/extension/protocol.go | 17 +-
common/extension/proxy_factory.go | 9 +-
common/extension/registry.go | 9 +-
common/extension/registry_directory.go | 22 +-
common/extension/registry_type.go | 85 ++++++++
common/extension/registry_type_test.go | 84 ++++++++
common/extension/rest_client.go | 9 +-
common/extension/rest_server.go | 9 +-
common/extension/router_factory.go | 11 +-
common/extension/service_discovery.go | 6 +-
common/extension/service_instance_customizer.go | 12 +-
.../extension/service_instance_selector_factory.go | 6 +-
common/extension/service_name_mapping.go | 14 +-
common/extension/tps_limit.go | 12 +-
29 files changed, 557 insertions(+), 154 deletions(-)
diff --git a/common/extension/auth.go b/common/extension/auth.go
index 430e0697f..05d910dc5 100644
--- a/common/extension/auth.go
+++ b/common/extension/auth.go
@@ -26,34 +26,35 @@ import (
)
var (
- authenticators = make(map[string]func() filter.Authenticator)
- accessKeyStorages = make(map[string]func() filter.AccessKeyStorage)
+ authenticators = NewRegistry[func()
filter.Authenticator]("authenticator")
+ accessKeyStorages = NewRegistry[func() filter.AccessKeyStorage]("access
key storage")
)
// SetAuthenticator puts the @fcn into map with name
func SetAuthenticator(name string, fcn func() filter.Authenticator) {
- authenticators[name] = fcn
+ authenticators.Register(name, fcn)
}
// GetAuthenticator finds the Authenticator with @name
// Panic if not found
func GetAuthenticator(name string) (filter.Authenticator, bool) {
- if authenticators[name] == nil {
+ fcn, ok := authenticators.Get(name)
+ if !ok {
return nil, false
}
- return authenticators[name](), true
+ return fcn(), true
}
// SetAccessKeyStorages will set the @fcn into map with this name
func SetAccessKeyStorages(name string, fcn func() filter.AccessKeyStorage) {
- accessKeyStorages[name] = fcn
+ accessKeyStorages.Register(name, fcn)
}
// GetAccessKeyStorages finds the storage with the @name.
// Panic if not found
func GetAccessKeyStorages(name string) (filter.AccessKeyStorage, error) {
- f := accessKeyStorages[name]
- if f == nil {
+ f, ok := accessKeyStorages.Get(name)
+ if !ok {
return nil, errors.New("accessKeyStorages for " + name + " is
not existing, make sure you have import the package.")
}
return f(), nil
diff --git a/common/extension/cluster.go b/common/extension/cluster.go
index 62db3e9ff..238a3b934 100644
--- a/common/extension/cluster.go
+++ b/common/extension/cluster.go
@@ -30,18 +30,19 @@ import (
"dubbo.apache.org/dubbo-go/v3/common/constant"
)
-var clusters = make(map[string]func() cluster.Cluster)
+var clusters = NewRegistry[func() cluster.Cluster]("cluster")
// SetCluster sets the cluster fault-tolerant mode with @name
// For example: available/failfast/broadcast/failfast/failsafe/...
func SetCluster(name string, fcn func() cluster.Cluster) {
- clusters[name] = fcn
+ clusters.Register(name, fcn)
}
// GetCluster finds the cluster fault-tolerant mode with @name
func GetCluster(name string) (cluster.Cluster, error) {
- if clusters[name] == nil {
+ fcn, ok := clusters.Get(name)
+ if !ok {
return nil,
errors.New(fmt.Sprintf(constant.NonImportErrorMsgFormat,
constant.ClusterKeyFailover))
}
- return clusters[name](), nil
+ return fcn(), nil
}
diff --git a/common/extension/concurrency_guards_test.go
b/common/extension/concurrency_guards_test.go
new file mode 100644
index 000000000..ed2aaf389
--- /dev/null
+++ b/common/extension/concurrency_guards_test.go
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package extension
+
+import (
+ "container/list"
+ "sync"
+ "sync/atomic"
+ "testing"
+)
+
+import (
+ gxset "github.com/dubbogo/gost/container/set"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+import (
+ "dubbo.apache.org/dubbo-go/v3/cluster/directory"
+ "dubbo.apache.org/dubbo-go/v3/common"
+ commonconfig "dubbo.apache.org/dubbo-go/v3/common/config"
+ "dubbo.apache.org/dubbo-go/v3/metadata/mapping"
+ "dubbo.apache.org/dubbo-go/v3/protocol/base"
+ "dubbo.apache.org/dubbo-go/v3/registry"
+)
+
+type mockDir struct{}
+
+func (m *mockDir) GetURL() *common.URL { return
&common.URL{} }
+func (m *mockDir) IsAvailable() bool { return
true }
+func (m *mockDir) Destroy() {}
+func (m *mockDir) List(invocation base.Invocation) []base.Invoker { return nil
}
+func (m *mockDir) Subscribe(url *common.URL) error { return nil
}
+
+type mockCustomizer struct{ priority int }
+
+func (m mockCustomizer) GetPriority() int { return
m.priority }
+func (m mockCustomizer) Customize(instance registry.ServiceInstance) {}
+
+type mockServiceNameMapping struct{}
+
+func (m *mockServiceNameMapping) Map(url *common.URL) error { return nil }
+func (m *mockServiceNameMapping) Get(url *common.URL, listener
mapping.MappingListener) (*gxset.HashSet, error) {
+ return nil, nil
+}
+func (m *mockServiceNameMapping) Remove(url *common.URL) error { return nil }
+
+type mockPostProcessor struct{}
+
+func (m mockPostProcessor) PostProcessReferenceConfig(url *common.URL) {}
+func (m mockPostProcessor) PostProcessServiceConfig(url *common.URL) {}
+
+func TestGetAllCustomShutdownCallbacksReturnsCopy(t *testing.T) {
+ customShutdownCallbacksLock.Lock()
+ original := customShutdownCallbacks
+ customShutdownCallbacks = list.New()
+ customShutdownCallbacksLock.Unlock()
+
+ t.Cleanup(func() {
+ customShutdownCallbacksLock.Lock()
+ customShutdownCallbacks = original
+ customShutdownCallbacksLock.Unlock()
+ })
+
+ AddCustomShutdownCallback(func() {})
+ AddCustomShutdownCallback(func() {})
+
+ callbacks := GetAllCustomShutdownCallbacks()
+ assert.Len(t, asSlice(callbacks), 2)
+
+ callbacks.PushBack(func() {})
+ callbacksAgain := GetAllCustomShutdownCallbacks()
+ assert.Len(t, asSlice(callbacksAgain), 2)
+}
+
+func TestGetDirectoryInstanceUsesProtocolAndFallback(t *testing.T) {
+ originalDirectories := directories
+ directories = NewRegistry[registryDirectory]("registry directory test")
+ t.Cleanup(func() {
+ directories = originalDirectories
+ })
+
+ if oldDefault := defaultDirectory.Load(); oldDefault != nil {
+ t.Cleanup(func() {
defaultDirectory.Store(oldDefault.(registryDirectory)) })
+ } else {
+ t.Cleanup(func() { defaultDirectory = atomic.Value{} })
+ }
+
+ defaultHit := 0
+ protocolHit := 0
+
+ SetDefaultRegistryDirectory(func(url *common.URL, reg
registry.Registry) (directory.Directory, error) {
+ defaultHit++
+ return &mockDir{}, nil
+ })
+
+ SetDirectory("polaris", func(url *common.URL, reg registry.Registry)
(directory.Directory, error) {
+ protocolHit++
+ return &mockDir{}, nil
+ })
+
+ _, err := GetDirectoryInstance(&common.URL{Protocol: "polaris"}, nil)
+ require.NoError(t, err)
+ assert.Equal(t, 1, protocolHit)
+ assert.Equal(t, 0, defaultHit)
+
+ _, err = GetDirectoryInstance(&common.URL{Protocol: ""}, nil)
+ require.NoError(t, err)
+ assert.Equal(t, 1, defaultHit)
+
+ _, err = GetDirectoryInstance(&common.URL{Protocol: "unknown"}, nil)
+ require.NoError(t, err)
+ assert.Equal(t, 2, defaultHit)
+}
+
+func TestCustomizersAreSortedAndReturnedAsCopy(t *testing.T) {
+ customizersLock.Lock()
+ original := customizers
+ customizers = make([]registry.ServiceInstanceCustomizer, 0, 8)
+ customizersLock.Unlock()
+
+ t.Cleanup(func() {
+ customizersLock.Lock()
+ customizers = original
+ customizersLock.Unlock()
+ })
+
+ AddCustomizers(mockCustomizer{priority: 20})
+ AddCustomizers(mockCustomizer{priority: 10})
+
+ got := GetCustomizers()
+ require.Len(t, got, 2)
+ assert.Equal(t, 10, got[0].GetPriority())
+ assert.Equal(t, 20, got[1].GetPriority())
+
+ _ = append(got, mockCustomizer{priority: 1})
+ assert.Len(t, GetCustomizers(), 2)
+}
+
+func TestGlobalServiceNameMappingCreator(t *testing.T) {
+ if old := globalNameMappingCreator.Load(); old != nil {
+ t.Cleanup(func() {
globalNameMappingCreator.Store(old.(ServiceNameMappingCreator)) })
+ }
+
+ expected := &mockServiceNameMapping{}
+ SetGlobalServiceNameMapping(func() mapping.ServiceNameMapping {
+ return expected
+ })
+
+ got := GetGlobalServiceNameMapping()
+ assert.Same(t, expected, got)
+}
+
+func TestConfigPostProcessorRegistrySnapshot(t *testing.T) {
+ originalProcessors := processors
+ processors = NewRegistry[commonconfig.ConfigPostProcessor]("config post
processor test")
+ t.Cleanup(func() {
+ processors = originalProcessors
+ })
+
+ SetConfigPostProcessor("p1", mockPostProcessor{})
+ SetConfigPostProcessor("p2", mockPostProcessor{})
+
+ assert.NotNil(t, GetConfigPostProcessor("p1"))
+ all := GetConfigPostProcessors()
+ assert.Len(t, all, 2)
+}
+
+func TestConcurrentCustomShutdownCallbacksAndCustomizers(t *testing.T) {
+ customShutdownCallbacksLock.Lock()
+ originalCallbacks := customShutdownCallbacks
+ customShutdownCallbacks = list.New()
+ customShutdownCallbacksLock.Unlock()
+
+ customizersLock.Lock()
+ originalCustomizers := customizers
+ customizers = make([]registry.ServiceInstanceCustomizer, 0, 8)
+ customizersLock.Unlock()
+
+ t.Cleanup(func() {
+ customShutdownCallbacksLock.Lock()
+ customShutdownCallbacks = originalCallbacks
+ customShutdownCallbacksLock.Unlock()
+
+ customizersLock.Lock()
+ customizers = originalCustomizers
+ customizersLock.Unlock()
+ })
+
+ var wg sync.WaitGroup
+ for i := 0; i < 20; i++ {
+ wg.Add(1)
+ go func(p int) {
+ defer wg.Done()
+ AddCustomShutdownCallback(func() {})
+ AddCustomizers(mockCustomizer{priority: p})
+ _ = GetAllCustomShutdownCallbacks()
+ _ = GetCustomizers()
+ }(i)
+ }
+ wg.Wait()
+
+ assert.Len(t, asSlice(GetAllCustomShutdownCallbacks()), 20)
+ assert.Len(t, GetCustomizers(), 20)
+}
+
+func asSlice(l *list.List) []any {
+ ret := make([]any, 0, l.Len())
+ for e := l.Front(); e != nil; e = e.Next() {
+ ret = append(ret, e.Value)
+ }
+ return ret
+}
diff --git a/common/extension/config.go b/common/extension/config.go
index fe2828505..da69939c0 100644
--- a/common/extension/config.go
+++ b/common/extension/config.go
@@ -18,7 +18,7 @@
package extension
var (
- configs = map[string]Config{}
+ configs = NewRegistry[Config]("config")
)
type Config interface {
@@ -26,5 +26,5 @@ type Config interface {
}
func SetConfig(c Config) {
- configs[c.Prefix()] = c
+ configs.Register(c.Prefix(), c)
}
diff --git a/common/extension/config_center.go
b/common/extension/config_center.go
index 09c43326b..11f8a3b1a 100644
--- a/common/extension/config_center.go
+++ b/common/extension/config_center.go
@@ -26,17 +26,17 @@ import (
"dubbo.apache.org/dubbo-go/v3/config_center"
)
-var configCenters = make(map[string]func(config *common.URL)
(config_center.DynamicConfiguration, error))
+var configCenters = NewRegistry[func(config *common.URL)
(config_center.DynamicConfiguration, error)]("config center")
// SetConfigCenter sets the DynamicConfiguration with @name
func SetConfigCenter(name string, v func(*common.URL)
(config_center.DynamicConfiguration, error)) {
- configCenters[name] = v
+ configCenters.Register(name, v)
}
// GetConfigCenter finds the DynamicConfiguration with @name
func GetConfigCenter(name string, config *common.URL)
(config_center.DynamicConfiguration, error) {
- configCenterFactory := configCenters[name]
- if configCenterFactory == nil {
+ configCenterFactory, ok := configCenters.Get(name)
+ if !ok {
return nil, errors.New("config center for " + name + " is not
existing, make sure you have import the package.")
}
configCenter, err := configCenterFactory(config)
diff --git a/common/extension/config_center_factory.go
b/common/extension/config_center_factory.go
index 826a1793e..77e0509a3 100644
--- a/common/extension/config_center_factory.go
+++ b/common/extension/config_center_factory.go
@@ -25,17 +25,18 @@ import (
"dubbo.apache.org/dubbo-go/v3/config_center"
)
-var configCenterFactories = make(map[string]func()
config_center.DynamicConfigurationFactory)
+var configCenterFactories = NewRegistry[func()
config_center.DynamicConfigurationFactory]("config center factory")
// SetConfigCenterFactory sets the DynamicConfigurationFactory with @name
func SetConfigCenterFactory(name string, v func()
config_center.DynamicConfigurationFactory) {
- configCenterFactories[name] = v
+ configCenterFactories.Register(name, v)
}
// GetConfigCenterFactory finds the DynamicConfigurationFactory with @name
func GetConfigCenterFactory(name string)
(config_center.DynamicConfigurationFactory, error) {
- if configCenterFactories[name] == nil {
+ v, ok := configCenterFactories.Get(name)
+ if !ok {
return nil, errors.New("config center for " + name + " is not
existing, make sure you have import the package.")
}
- return configCenterFactories[name](), nil
+ return v(), nil
}
diff --git a/common/extension/config_post_processor.go
b/common/extension/config_post_processor.go
index 81df3fa28..2f6becfb2 100644
--- a/common/extension/config_post_processor.go
+++ b/common/extension/config_post_processor.go
@@ -21,22 +21,24 @@ import (
"dubbo.apache.org/dubbo-go/v3/common/config"
)
-var processors = make(map[string]config.ConfigPostProcessor)
+var processors = NewRegistry[config.ConfigPostProcessor]("config post
processor")
// SetConfigPostProcessor registers a ConfigPostProcessor with the given name.
func SetConfigPostProcessor(name string, processor config.ConfigPostProcessor)
{
- processors[name] = processor
+ processors.Register(name, processor)
}
// GetConfigPostProcessor finds a ConfigPostProcessor by name.
func GetConfigPostProcessor(name string) config.ConfigPostProcessor {
- return processors[name]
+ v, _ := processors.Get(name)
+ return v
}
// GetConfigPostProcessors returns all registered instances of
ConfigPostProcessor.
func GetConfigPostProcessors() []config.ConfigPostProcessor {
- ret := make([]config.ConfigPostProcessor, 0, len(processors))
- for _, v := range processors {
+ snapshot := processors.Snapshot()
+ ret := make([]config.ConfigPostProcessor, 0, len(snapshot))
+ for _, v := range snapshot {
ret = append(ret, v)
}
return ret
diff --git a/common/extension/config_reader.go
b/common/extension/config_reader.go
index 6cc2bc13f..e493d3c3f 100644
--- a/common/extension/config_reader.go
+++ b/common/extension/config_reader.go
@@ -22,29 +22,26 @@ import (
)
var (
- configReaders = make(map[string]func() interfaces.ConfigReader)
- defaults = make(map[string]string)
+ configReaders = NewRegistry[func() interfaces.ConfigReader]("config
reader")
+ defaults = NewRegistry[string]("default config reader")
)
// SetConfigReaders sets a creator of config reader with @name
func SetConfigReaders(name string, v func() interfaces.ConfigReader) {
- configReaders[name] = v
+ configReaders.Register(name, v)
}
// GetConfigReaders gets a config reader with @name
func GetConfigReaders(name string) interfaces.ConfigReader {
- if configReaders[name] == nil {
- panic("config reader for " + name + " is not existing, make
sure you have imported the package.")
- }
- return configReaders[name]()
+ return configReaders.MustGet(name)()
}
// SetDefaultConfigReader sets @name for @module in default config reader
func SetDefaultConfigReader(module, name string) {
- defaults[module] = name
+ defaults.Register(module, name)
}
// GetDefaultConfigReader gets default config reader
func GetDefaultConfigReader() map[string]string {
- return defaults
+ return defaults.Snapshot()
}
diff --git a/common/extension/configurator.go b/common/extension/configurator.go
index 5c6fda818..65b11d98b 100644
--- a/common/extension/configurator.go
+++ b/common/extension/configurator.go
@@ -29,38 +29,29 @@ const (
type getConfiguratorFunc func(url *common.URL) config_center.Configurator
-var configurator = make(map[string]getConfiguratorFunc)
+var configurators = NewRegistry[getConfiguratorFunc]("configurator")
// SetConfigurator sets the getConfiguratorFunc with @name
func SetConfigurator(name string, v getConfiguratorFunc) {
- configurator[name] = v
+ configurators.Register(name, v)
}
// GetConfigurator finds the Configurator with @name
func GetConfigurator(name string, url *common.URL) config_center.Configurator {
- if configurator[name] == nil {
- panic("configurator for " + name + " is not existing, make sure
you have import the package.")
- }
- return configurator[name](url)
+ return configurators.MustGet(name)(url)
}
// SetDefaultConfigurator sets the default Configurator
func SetDefaultConfigurator(v getConfiguratorFunc) {
- configurator[DefaultKey] = v
+ configurators.Register(DefaultKey, v)
}
// GetDefaultConfigurator gets default configurator
func GetDefaultConfigurator(url *common.URL) config_center.Configurator {
- if configurator[DefaultKey] == nil {
- panic("configurator for default is not existing, make sure you
have import the package.")
- }
- return configurator[DefaultKey](url)
+ return configurators.MustGet(DefaultKey)(url)
}
// GetDefaultConfiguratorFunc gets default configurator function
func GetDefaultConfiguratorFunc() getConfiguratorFunc {
- if configurator[DefaultKey] == nil {
- panic("configurator for default is not existing, make sure you
have import the package.")
- }
- return configurator[DefaultKey]
+ return configurators.MustGet(DefaultKey)
}
diff --git a/common/extension/filter.go b/common/extension/filter.go
index 826646a08..da0458ee8 100644
--- a/common/extension/filter.go
+++ b/common/extension/filter.go
@@ -26,32 +26,33 @@ import (
)
var (
- filters = make(map[string]func() filter.Filter)
- rejectedExecutionHandler = make(map[string]func()
filter.RejectedExecutionHandler)
+ filters = NewRegistry[func() filter.Filter]("filter")
+ rejectedExecutionHandler = NewRegistry[func()
filter.RejectedExecutionHandler]("rejected execution handler")
)
// SetFilter sets the filter extension with @name
// For example: hystrix/metrics/token/tracing/limit/...
func SetFilter(name string, v func() filter.Filter) {
- filters[name] = v
+ filters.Register(name, v)
}
// GetFilter finds the filter extension with @name
func GetFilter(name string) (filter.Filter, bool) {
- if filters[name] == nil {
+ creator, ok := filters.Get(name)
+ if !ok {
return nil, false
}
- return filters[name](), true
+ return creator(), true
}
// SetRejectedExecutionHandler sets the RejectedExecutionHandler with @name
func SetRejectedExecutionHandler(name string, creator func()
filter.RejectedExecutionHandler) {
- rejectedExecutionHandler[name] = creator
+ rejectedExecutionHandler.Register(name, creator)
}
// GetRejectedExecutionHandler finds the RejectedExecutionHandler with @name
func GetRejectedExecutionHandler(name string)
(filter.RejectedExecutionHandler, error) {
- creator, ok := rejectedExecutionHandler[name]
+ creator, ok := rejectedExecutionHandler.Get(name)
if !ok {
return nil, errors.New("RejectedExecutionHandler for " + name +
" is not existing, make sure you have import the package " +
"and you have register it by invoking
extension.SetRejectedExecutionHandler.")
@@ -62,19 +63,15 @@ func GetRejectedExecutionHandler(name string)
(filter.RejectedExecutionHandler,
// UnregisterFilter removes the filter extension with @name
// This helps prevent memory leaks in dynamic extension scenarios
func UnregisterFilter(name string) {
- delete(filters, name)
+ filters.Unregister(name)
}
// UnregisterRejectedExecutionHandler removes the RejectedExecutionHandler
with @name
func UnregisterRejectedExecutionHandler(name string) {
- delete(rejectedExecutionHandler, name)
+ rejectedExecutionHandler.Unregister(name)
}
// GetAllFilterNames returns all registered filter names
func GetAllFilterNames() []string {
- names := make([]string, 0, len(filters))
- for name := range filters {
- names = append(names, name)
- }
- return names
+ return filters.Names()
}
diff --git a/common/extension/graceful_shutdown.go
b/common/extension/graceful_shutdown.go
index 8c98192b5..1de3e583f 100644
--- a/common/extension/graceful_shutdown.go
+++ b/common/extension/graceful_shutdown.go
@@ -19,9 +19,11 @@ package extension
import (
"container/list"
+ "sync"
)
var customShutdownCallbacks = list.New()
+var customShutdownCallbacksLock sync.Mutex
/**
* AddCustomShutdownCallback
@@ -44,10 +46,20 @@ var customShutdownCallbacks = list.New()
* And it may introduce much complication for another users.
*/
func AddCustomShutdownCallback(callback func()) {
+ customShutdownCallbacksLock.Lock()
+ defer customShutdownCallbacksLock.Unlock()
+
customShutdownCallbacks.PushBack(callback)
}
// GetAllCustomShutdownCallbacks gets all custom shutdown callbacks
func GetAllCustomShutdownCallbacks() *list.List {
- return customShutdownCallbacks
+ customShutdownCallbacksLock.Lock()
+ defer customShutdownCallbacksLock.Unlock()
+
+ ret := list.New()
+ for e := customShutdownCallbacks.Front(); e != nil; e = e.Next() {
+ ret.PushBack(e.Value)
+ }
+ return ret
}
diff --git a/common/extension/loadbalance.go b/common/extension/loadbalance.go
index d04236f37..9c7174e44 100644
--- a/common/extension/loadbalance.go
+++ b/common/extension/loadbalance.go
@@ -21,33 +21,25 @@ import (
"dubbo.apache.org/dubbo-go/v3/cluster/loadbalance"
)
-var loadbalances = make(map[string]func() loadbalance.LoadBalance)
+var loadbalances = NewRegistry[func() loadbalance.LoadBalance]("loadbalance")
// SetLoadbalance sets the loadbalance extension with @name
// For example: random/round_robin/consistent_hash/least_active/...
func SetLoadbalance(name string, fcn func() loadbalance.LoadBalance) {
- loadbalances[name] = fcn
+ loadbalances.Register(name, fcn)
}
// GetLoadbalance finds the loadbalance extension with @name
func GetLoadbalance(name string) loadbalance.LoadBalance {
- if loadbalances[name] == nil {
- panic("loadbalance for " + name + " is not existing, make sure
you have import the package.")
- }
-
- return loadbalances[name]()
+ return loadbalances.MustGet(name)()
}
// UnregisterLoadbalance removes the loadbalance extension with @name
func UnregisterLoadbalance(name string) {
- delete(loadbalances, name)
+ loadbalances.Unregister(name)
}
// GetAllLoadbalanceNames returns all registered loadbalance names
func GetAllLoadbalanceNames() []string {
- names := make([]string, 0, len(loadbalances))
- for name := range loadbalances {
- names = append(names, name)
- }
- return names
+ return loadbalances.Names()
}
diff --git a/common/extension/logger.go b/common/extension/logger.go
index fab8864b1..27487db65 100644
--- a/common/extension/logger.go
+++ b/common/extension/logger.go
@@ -26,15 +26,15 @@ import (
"dubbo.apache.org/dubbo-go/v3/logger"
)
-var logs = make(map[string]func(config *common.URL) (logger.Logger, error))
+var logs = NewRegistry[func(config *common.URL) (logger.Logger,
error)]("logger")
func SetLogger(driver string, log func(config *common.URL) (logger.Logger,
error)) {
- logs[driver] = log
+ logs.Register(driver, log)
}
func GetLogger(driver string, config *common.URL) (logger.Logger, error) {
- if logs[driver] != nil {
- return logs[driver](config)
+ if logCreator, ok := logs.Get(driver); ok {
+ return logCreator(config)
} else {
return nil, errors.Errorf("logger for %s does not exist. "+
"please make sure that you have imported the package "+
diff --git a/common/extension/metadata_report_factory.go
b/common/extension/metadata_report_factory.go
index cd6559873..5644ef5b7 100644
--- a/common/extension/metadata_report_factory.go
+++ b/common/extension/metadata_report_factory.go
@@ -21,17 +21,18 @@ import (
"dubbo.apache.org/dubbo-go/v3/metadata/report"
)
-var metaDataReportFactories = make(map[string]func()
report.MetadataReportFactory, 8)
+var metaDataReportFactories = NewRegistry[func()
report.MetadataReportFactory]("metadata report factory")
// SetMetadataReportFactory sets the MetadataReportFactory with @name
func SetMetadataReportFactory(name string, v func()
report.MetadataReportFactory) {
- metaDataReportFactories[name] = v
+ metaDataReportFactories.Register(name, v)
}
// GetMetadataReportFactory finds the MetadataReportFactory with @name
func GetMetadataReportFactory(name string) report.MetadataReportFactory {
- if metaDataReportFactories[name] == nil {
+ v, ok := metaDataReportFactories.Get(name)
+ if !ok {
return nil
}
- return metaDataReportFactories[name]()
+ return v()
}
diff --git a/common/extension/otel_trace.go b/common/extension/otel_trace.go
index cee4cea80..e84ed3981 100644
--- a/common/extension/otel_trace.go
+++ b/common/extension/otel_trace.go
@@ -29,14 +29,14 @@ import (
"dubbo.apache.org/dubbo-go/v3/otel/trace"
)
-var traceExporterMap = make(map[string]func(config *trace.ExporterConfig)
(trace.Exporter, error), 4)
+var traceExporterMap = NewRegistry[func(config *trace.ExporterConfig)
(trace.Exporter, error)]("trace exporter")
func SetTraceExporter(name string, createFunc func(config
*trace.ExporterConfig) (trace.Exporter, error)) {
- traceExporterMap[name] = createFunc
+ traceExporterMap.Register(name, createFunc)
}
func GetTraceExporter(name string, config *trace.ExporterConfig)
(trace.Exporter, error) {
- createFunc, ok := traceExporterMap[name]
+ createFunc, ok := traceExporterMap.Get(name)
if !ok {
panic("Cannot find the trace provider with name " + name)
}
@@ -45,7 +45,7 @@ func GetTraceExporter(name string, config
*trace.ExporterConfig) (trace.Exporter
func GetTraceShutdownCallback() func() {
return func() {
- for name, createFunc := range traceExporterMap {
+ for name, createFunc := range traceExporterMap.Snapshot() {
if exporter, err := createFunc(nil); err == nil {
if err :=
exporter.GetTracerProvider().Shutdown(context.Background()); err != nil {
logger.Errorf("Graceful shutdown ---
Failed to shutdown trace provider %s, error: %s", name, err.Error())
diff --git a/common/extension/protocol.go b/common/extension/protocol.go
index 6a4f33c91..5dd66c2e9 100644
--- a/common/extension/protocol.go
+++ b/common/extension/protocol.go
@@ -21,32 +21,25 @@ import (
"dubbo.apache.org/dubbo-go/v3/protocol/base"
)
-var protocols = make(map[string]func() base.Protocol)
+var protocols = NewRegistry[func() base.Protocol]("protocol")
// SetProtocol sets the protocol extension with @name
func SetProtocol(name string, v func() base.Protocol) {
- protocols[name] = v
+ protocols.Register(name, v)
}
// GetProtocol finds the protocol extension with @name
func GetProtocol(name string) base.Protocol {
- if protocols[name] == nil {
- panic("protocol for [" + name + "] is not existing, make sure
you have import the package.")
- }
- return protocols[name]()
+ return protocols.MustGet(name)()
}
// UnregisterProtocol removes the protocol extension with @name
// This helps prevent memory leaks in dynamic extension scenarios
func UnregisterProtocol(name string) {
- delete(protocols, name)
+ protocols.Unregister(name)
}
// GetAllProtocolNames returns all registered protocol names
func GetAllProtocolNames() []string {
- names := make([]string, 0, len(protocols))
- for name := range protocols {
- names = append(names, name)
- }
- return names
+ return protocols.Names()
}
diff --git a/common/extension/proxy_factory.go
b/common/extension/proxy_factory.go
index 5dfe0c9da..5a7c2a957 100644
--- a/common/extension/proxy_factory.go
+++ b/common/extension/proxy_factory.go
@@ -25,11 +25,11 @@ import (
"dubbo.apache.org/dubbo-go/v3/proxy"
)
-var proxyFactories = make(map[string]func(...proxy.Option) proxy.ProxyFactory)
+var proxyFactories = NewRegistry[func(...proxy.Option)
proxy.ProxyFactory]("proxy factory")
// SetProxyFactory sets the ProxyFactory extension with @name
func SetProxyFactory(name string, f func(...proxy.Option) proxy.ProxyFactory) {
- proxyFactories[name] = f
+ proxyFactories.Register(name, f)
}
// GetProxyFactory finds the ProxyFactory extension with @name
@@ -37,9 +37,10 @@ func GetProxyFactory(name string) proxy.ProxyFactory {
if name == "" {
name = "default"
}
- if proxyFactories[name] == nil {
+ v, ok := proxyFactories.Get(name)
+ if !ok {
logger.Warn("proxy factory for " + name + " is not existing,
make sure you have import the package.")
return nil
}
- return proxyFactories[name]()
+ return v()
}
diff --git a/common/extension/registry.go b/common/extension/registry.go
index abce206f2..1320daea6 100644
--- a/common/extension/registry.go
+++ b/common/extension/registry.go
@@ -22,17 +22,18 @@ import (
"dubbo.apache.org/dubbo-go/v3/registry"
)
-var registries = make(map[string]func(config *common.URL) (registry.Registry,
error))
+var registries = NewRegistry[func(config *common.URL) (registry.Registry,
error)]("registry")
// SetRegistry sets the registry extension with @name
func SetRegistry(name string, v func(_ *common.URL) (registry.Registry,
error)) {
- registries[name] = v
+ registries.Register(name, v)
}
// GetRegistry finds the registry extension with @name
func GetRegistry(name string, config *common.URL) (registry.Registry, error) {
- if registries[name] == nil {
+ v, ok := registries.Get(name)
+ if !ok {
panic("registry for " + name + " does not exist. please make
sure that you have imported the package dubbo.apache.org/dubbo-go/v3/registry/"
+ name + ".")
}
- return registries[name](config)
+ return v(config)
}
diff --git a/common/extension/registry_directory.go
b/common/extension/registry_directory.go
index d886c6d73..6ecb21d6c 100644
--- a/common/extension/registry_directory.go
+++ b/common/extension/registry_directory.go
@@ -17,6 +17,10 @@
package extension
+import (
+ "sync/atomic"
+)
+
import (
"github.com/dubbogo/gost/log/logger"
)
@@ -29,25 +33,26 @@ import (
type registryDirectory func(url *common.URL, registry registry.Registry)
(directory.Directory, error)
-var directories = make(map[string]registryDirectory)
-var defaultDirectory registryDirectory
+var directories = NewRegistry[registryDirectory]("registry directory")
+var defaultDirectory atomic.Value
// SetDefaultRegistryDirectory sets the default registryDirectory
func SetDefaultRegistryDirectory(v registryDirectory) {
- defaultDirectory = v
+ defaultDirectory.Store(v)
}
// SetDirectory sets the default registryDirectory
func SetDirectory(key string, v registryDirectory) {
- directories[key] = v
+ directories.Register(key, v)
}
// GetDefaultRegistryDirectory finds the registryDirectory with url and
registry
func GetDefaultRegistryDirectory(url *common.URL, registry registry.Registry)
(directory.Directory, error) {
- if defaultDirectory == nil {
+ v := defaultDirectory.Load()
+ if v == nil {
panic("registry directory is not existing, make sure you have
import the package.")
}
- return defaultDirectory(url, registry)
+ return v.(registryDirectory)(url, registry)
}
// GetDirectoryInstance finds the registryDirectory with url and registry
@@ -56,9 +61,10 @@ func GetDirectoryInstance(url *common.URL, registry
registry.Registry) (director
if key == "" {
return GetDefaultRegistryDirectory(url, registry)
}
- if directories[key] == nil {
+ v, ok := directories.Get(key)
+ if !ok {
logger.Warn("registry directory " + key + " does not exist,
make sure you have import the package, will use the default directory type.")
return GetDefaultRegistryDirectory(url, registry)
}
- return directories[key](url, registry)
+ return v(url, registry)
}
diff --git a/common/extension/registry_type.go
b/common/extension/registry_type.go
new file mode 100644
index 000000000..9e601998d
--- /dev/null
+++ b/common/extension/registry_type.go
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package extension
+
+import (
+ "sync"
+)
+
+// Registry is a thread-safe generic container for extension registrations.
+type Registry[T any] struct {
+ mu sync.RWMutex
+ items map[string]T
+ name string
+}
+
+func NewRegistry[T any](name string) *Registry[T] {
+ return &Registry[T]{
+ items: make(map[string]T),
+ name: name,
+ }
+}
+
+func (r *Registry[T]) Register(name string, v T) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.items[name] = v
+}
+
+func (r *Registry[T]) Get(name string) (T, bool) {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ v, ok := r.items[name]
+ return v, ok
+}
+
+func (r *Registry[T]) MustGet(name string) T {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ v, ok := r.items[name]
+ if !ok {
+ panic(r.name + " for [" + name + "] is not existing, make sure
you have import the package.")
+ }
+ return v
+}
+
+func (r *Registry[T]) Unregister(name string) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ delete(r.items, name)
+}
+
+func (r *Registry[T]) Snapshot() map[string]T {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ m := make(map[string]T, len(r.items))
+ for k, v := range r.items {
+ m[k] = v
+ }
+ return m
+}
+
+func (r *Registry[T]) Names() []string {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ names := make([]string, 0, len(r.items))
+ for k := range r.items {
+ names = append(names, k)
+ }
+ return names
+}
diff --git a/common/extension/registry_type_test.go
b/common/extension/registry_type_test.go
new file mode 100644
index 000000000..0a7aedadc
--- /dev/null
+++ b/common/extension/registry_type_test.go
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package extension
+
+import (
+ "strconv"
+ "sync"
+ "testing"
+)
+
+import (
+ "github.com/stretchr/testify/assert"
+)
+
+func TestRegistryBasicOps(t *testing.T) {
+ r := NewRegistry[int]("test-registry")
+
+ _, ok := r.Get("missing")
+ assert.False(t, ok)
+
+ r.Register("a", 1)
+ v, ok := r.Get("a")
+ assert.True(t, ok)
+ assert.Equal(t, 1, v)
+
+ must := r.MustGet("a")
+ assert.Equal(t, 1, must)
+
+ snapshot := r.Snapshot()
+ assert.Equal(t, map[string]int{"a": 1}, snapshot)
+ snapshot["a"] = 99
+
+ v, ok = r.Get("a")
+ assert.True(t, ok)
+ assert.Equal(t, 1, v)
+
+ names := r.Names()
+ assert.Len(t, names, 1)
+ assert.Equal(t, "a", names[0])
+
+ r.Unregister("a")
+ _, ok = r.Get("a")
+ assert.False(t, ok)
+}
+
+func TestRegistryConcurrentAccess(t *testing.T) {
+ r := NewRegistry[int]("concurrent")
+
+ var wg sync.WaitGroup
+ workers := 32
+ iterations := 200
+
+ for i := 0; i < workers; i++ {
+ wg.Add(1)
+ go func(worker int) {
+ defer wg.Done()
+ for j := 0; j < iterations; j++ {
+ key := "k-" + strconv.Itoa(worker) + "-" +
strconv.Itoa(j)
+ r.Register(key, j)
+ _, _ = r.Get(key)
+ _ = r.Snapshot()
+ _ = r.Names()
+ r.Unregister(key)
+ }
+ }(i)
+ }
+
+ wg.Wait()
+}
diff --git a/common/extension/rest_client.go b/common/extension/rest_client.go
index bb04c53ae..83a3dc6e2 100644
--- a/common/extension/rest_client.go
+++ b/common/extension/rest_client.go
@@ -21,17 +21,14 @@ import (
"dubbo.apache.org/dubbo-go/v3/protocol/rest/client"
)
-var restClients = make(map[string]func(restOptions *client.RestOptions)
client.RestClient, 8)
+var restClients = NewRegistry[func(restOptions *client.RestOptions)
client.RestClient]("rest client")
// SetRestClient sets the RestClient with @name
func SetRestClient(name string, fun func(_ *client.RestOptions)
client.RestClient) {
- restClients[name] = fun
+ restClients.Register(name, fun)
}
// GetNewRestClient finds the RestClient with @name
func GetNewRestClient(name string, restOptions *client.RestOptions)
client.RestClient {
- if restClients[name] == nil {
- panic("restClient for " + name + " is not existing, make sure
you have import the package.")
- }
- return restClients[name](restOptions)
+ return restClients.MustGet(name)(restOptions)
}
diff --git a/common/extension/rest_server.go b/common/extension/rest_server.go
index 34eb408dd..536d13d2c 100644
--- a/common/extension/rest_server.go
+++ b/common/extension/rest_server.go
@@ -21,17 +21,14 @@ import (
"dubbo.apache.org/dubbo-go/v3/protocol/rest/server"
)
-var restServers = make(map[string]func() server.RestServer, 8)
+var restServers = NewRegistry[func() server.RestServer]("rest server")
// SetRestServer sets the RestServer with @name
func SetRestServer(name string, fun func() server.RestServer) {
- restServers[name] = fun
+ restServers.Register(name, fun)
}
// GetNewRestServer finds the RestServer with @name
func GetNewRestServer(name string) server.RestServer {
- if restServers[name] == nil {
- panic("restServer for " + name + " is not existing, make sure
you have import the package.")
- }
- return restServers[name]()
+ return restServers.MustGet(name)()
}
diff --git a/common/extension/router_factory.go
b/common/extension/router_factory.go
index 454fe00df..945c088c3 100644
--- a/common/extension/router_factory.go
+++ b/common/extension/router_factory.go
@@ -22,23 +22,20 @@ import (
)
var (
- routers = make(map[string]func() router.PriorityRouterFactory)
+ routers = NewRegistry[func() router.PriorityRouterFactory]("router
factory")
)
// SetRouterFactory sets create router factory function with @name
func SetRouterFactory(name string, fun func() router.PriorityRouterFactory) {
- routers[name] = fun
+ routers.Register(name, fun)
}
// GetRouterFactory gets create router factory function by @name
func GetRouterFactory(name string) router.PriorityRouterFactory {
- if routers[name] == nil {
- panic("router_factory for " + name + " is not existing, make
sure you have import the package.")
- }
- return routers[name]()
+ return routers.MustGet(name)()
}
// GetRouterFactories gets all create router factory function
func GetRouterFactories() map[string]func() router.PriorityRouterFactory {
- return routers
+ return routers.Snapshot()
}
diff --git a/common/extension/service_discovery.go
b/common/extension/service_discovery.go
index 9f4fdf6ce..2caad1005 100644
--- a/common/extension/service_discovery.go
+++ b/common/extension/service_discovery.go
@@ -27,13 +27,13 @@ import (
"dubbo.apache.org/dubbo-go/v3/registry"
)
-var discoveryCreatorMap = make(map[string]func(url *common.URL)
(registry.ServiceDiscovery, error), 4)
+var discoveryCreatorMap = NewRegistry[func(url *common.URL)
(registry.ServiceDiscovery, error)]("service discovery")
// SetServiceDiscovery will store the @creator and @name
// protocol indicate the implementation, like nacos
// the name like nacos-1...
func SetServiceDiscovery(protocol string, creator func(url *common.URL)
(registry.ServiceDiscovery, error)) {
- discoveryCreatorMap[protocol] = creator
+ discoveryCreatorMap.Register(protocol, creator)
}
// GetServiceDiscovery will return the registry.ServiceDiscovery
@@ -42,7 +42,7 @@ func SetServiceDiscovery(protocol string, creator func(url
*common.URL) (registr
// if not found, or initialize instance failed, it will return error.
func GetServiceDiscovery(url *common.URL) (registry.ServiceDiscovery, error) {
protocol := url.GetParam(constant.RegistryKey, "")
- creator, ok := discoveryCreatorMap[protocol]
+ creator, ok := discoveryCreatorMap.Get(protocol)
if !ok {
return nil, perrors.New("Could not find the service discovery
with discovery protocol: " + protocol)
}
diff --git a/common/extension/service_instance_customizer.go
b/common/extension/service_instance_customizer.go
index 1129d578f..0c3a2f561 100644
--- a/common/extension/service_instance_customizer.go
+++ b/common/extension/service_instance_customizer.go
@@ -19,6 +19,7 @@ package extension
import (
"sort"
+ "sync"
)
import (
@@ -26,10 +27,14 @@ import (
)
var customizers = make([]registry.ServiceInstanceCustomizer, 0, 8)
+var customizersLock sync.RWMutex
// AddCustomizers will put the customizer into slices and then sort them;
// this method will be invoked several time, so we sort them here.
func AddCustomizers(cus registry.ServiceInstanceCustomizer) {
+ customizersLock.Lock()
+ defer customizersLock.Unlock()
+
customizers = append(customizers, cus)
sort.Stable(customizerSlice(customizers))
}
@@ -37,7 +42,12 @@ func AddCustomizers(cus registry.ServiceInstanceCustomizer) {
// GetCustomizers will return the sorted customizer
// the result won't be nil
func GetCustomizers() []registry.ServiceInstanceCustomizer {
- return customizers
+ customizersLock.RLock()
+ defer customizersLock.RUnlock()
+
+ ret := make([]registry.ServiceInstanceCustomizer, len(customizers))
+ copy(ret, customizers)
+ return ret
}
type customizerSlice []registry.ServiceInstanceCustomizer
diff --git a/common/extension/service_instance_selector_factory.go
b/common/extension/service_instance_selector_factory.go
index e7c636501..13cdd991e 100644
--- a/common/extension/service_instance_selector_factory.go
+++ b/common/extension/service_instance_selector_factory.go
@@ -25,17 +25,17 @@ import (
"dubbo.apache.org/dubbo-go/v3/registry/servicediscovery/instance"
)
-var serviceInstanceSelectorMappings = make(map[string]func()
instance.ServiceInstanceSelector, 2)
+var serviceInstanceSelectorMappings = NewRegistry[func()
instance.ServiceInstanceSelector]("service instance selector")
// SetServiceInstanceSelector registers a factory for ServiceInstanceSelector.
func SetServiceInstanceSelector(name string, f func()
instance.ServiceInstanceSelector) {
- serviceInstanceSelectorMappings[name] = f
+ serviceInstanceSelectorMappings.Register(name, f)
}
// GetServiceInstanceSelector will create an instance
// it will panic if selector with the @name not found
func GetServiceInstanceSelector(name string)
(instance.ServiceInstanceSelector, error) {
- serviceInstanceSelector, ok := serviceInstanceSelectorMappings[name]
+ serviceInstanceSelector, ok := serviceInstanceSelectorMappings.Get(name)
if !ok {
return nil, perrors.New("Could not find service instance
selector with" +
"name:" + name)
diff --git a/common/extension/service_name_mapping.go
b/common/extension/service_name_mapping.go
index 5089998b9..e64a9f64c 100644
--- a/common/extension/service_name_mapping.go
+++ b/common/extension/service_name_mapping.go
@@ -17,18 +17,26 @@
package extension
+import (
+ "sync/atomic"
+)
+
import (
"dubbo.apache.org/dubbo-go/v3/metadata/mapping"
)
type ServiceNameMappingCreator func() mapping.ServiceNameMapping
-var globalNameMappingCreator ServiceNameMappingCreator
+var globalNameMappingCreator atomic.Value
func SetGlobalServiceNameMapping(nameMappingCreator ServiceNameMappingCreator)
{
- globalNameMappingCreator = nameMappingCreator
+ globalNameMappingCreator.Store(nameMappingCreator)
}
func GetGlobalServiceNameMapping() mapping.ServiceNameMapping {
- return globalNameMappingCreator()
+ v := globalNameMappingCreator.Load()
+ if v == nil {
+ panic("global service name mapping creator is not existing")
+ }
+ return v.(ServiceNameMappingCreator)()
}
diff --git a/common/extension/tps_limit.go b/common/extension/tps_limit.go
index 8f678d18b..7acfa9779 100644
--- a/common/extension/tps_limit.go
+++ b/common/extension/tps_limit.go
@@ -26,18 +26,18 @@ import (
)
var (
- tpsLimitStrategy = make(map[string]filter.TpsLimitStrategyCreator)
- tpsLimiter = make(map[string]func() filter.TpsLimiter)
+ tpsLimitStrategy = NewRegistry[filter.TpsLimitStrategyCreator]("tps
limit strategy")
+ tpsLimiter = NewRegistry[func() filter.TpsLimiter]("tps limiter")
)
// SetTpsLimiter sets the TpsLimiter with @name
func SetTpsLimiter(name string, creator func() filter.TpsLimiter) {
- tpsLimiter[name] = creator
+ tpsLimiter.Register(name, creator)
}
// GetTpsLimiter finds the TpsLimiter with @name
func GetTpsLimiter(name string) (filter.TpsLimiter, error) {
- creator, ok := tpsLimiter[name]
+ creator, ok := tpsLimiter.Get(name)
if !ok {
return nil, errors.New("TpsLimiter for " + name + " is not
existing, make sure you have import the package " +
"and you have register it by invoking
extension.SetTpsLimiter.")
@@ -47,12 +47,12 @@ func GetTpsLimiter(name string) (filter.TpsLimiter, error) {
// SetTpsLimitStrategy sets the TpsLimitStrategyCreator with @name
func SetTpsLimitStrategy(name string, creator filter.TpsLimitStrategyCreator) {
- tpsLimitStrategy[name] = creator
+ tpsLimitStrategy.Register(name, creator)
}
// GetTpsLimitStrategyCreator finds the TpsLimitStrategyCreator with @name
func GetTpsLimitStrategyCreator(name string) (filter.TpsLimitStrategyCreator,
error) {
- creator, ok := tpsLimitStrategy[name]
+ creator, ok := tpsLimitStrategy.Get(name)
if !ok {
return nil, errors.New("TpsLimitStrategy for " + name + " is
not existing, make sure you have import the package " +
"and you have register it by invoking
extension.SetTpsLimitStrategy.")