This is an automated email from the ASF dual-hosted git repository. zfeng pushed a commit to branch feature/saga in repository https://gitbox.apache.org/repos/asf/incubator-seata-go.git
The following commit(s) were added to refs/heads/feature/saga by this push: new fcc1cf25 Feature/add script manager&actuator (#868) fcc1cf25 is described below commit fcc1cf2527a7defecee7ddb4a6854c8a6093e02e Author: flypiggy <150204110+flypiggyyoy...@users.noreply.github.com> AuthorDate: Tue Aug 12 10:15:10 2025 +0800 Feature/add script manager&actuator (#868) * add script manager & actuator * Otto package replaces Goja package * Resolve the issues raised by Copilot * add test for vm-pool * update GetServiceInvoker method --- go.mod | 10 +- go.sum | 16 +- .../engine/config/default_statemachine_config.go | 30 ++- pkg/saga/statemachine/engine/invoker/invoker.go | 169 +++---------- .../engine/invoker/javascript_script_invoker.go | 162 +++++++++++++ .../invoker/javascript_script_invoker_test.go | 262 +++++++++++++++++++++ .../invoker/{invoker.go => local_invoker.go} | 119 ++++------ .../engine/invoker/local_invoker_test.go | 212 +++++++++++++++++ 8 files changed, 762 insertions(+), 218 deletions(-) diff --git a/go.mod b/go.mod index f886c5ef..42949982 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/seata/seata-go -go 1.18 +go 1.20 require ( dubbo.apache.org/dubbo-go/v3 v3.0.4 @@ -35,7 +35,8 @@ require ( github.com/agiledragon/gomonkey/v2 v2.9.0 github.com/google/cel-go v0.18.0 github.com/mattn/go-sqlite3 v1.14.19 - golang.org/x/sync v0.6.0 + github.com/robertkrimen/otto v0.4.0 + golang.org/x/sync v0.16.0 google.golang.org/protobuf v1.33.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -90,8 +91,9 @@ require ( github.com/yusufpapurcu/wmi v1.2.2 // indirect go.uber.org/multierr v1.8.0 // indirect golang.org/x/arch v0.3.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/text v0.27.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect + gopkg.in/sourcemap.v1 v1.0.5 // indirect ) require ( @@ -106,7 +108,7 @@ require ( golang.org/x/crypto v0.17.0 // indirect golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 // indirect golang.org/x/net v0.10.0 // indirect - golang.org/x/sys v0.15.0 // indirect + golang.org/x/sys v0.32.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20230803162519-f966b187b2e5 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230803162519-f966b187b2e5 // indirect vimagination.zapto.org/memio v0.0.0-20200222190306-588ebc67b97d // indirect diff --git a/go.sum b/go.sum index 68d4931f..61c5eafd 100644 --- a/go.sum +++ b/go.sum @@ -672,6 +672,8 @@ github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40T github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rhnvrm/simples3 v0.6.1/go.mod h1:Y+3vYm2V7Y4VijFoJHHTrja6OgPrJ2cBti8dPGkC3sA= +github.com/robertkrimen/otto v0.4.0 h1:/c0GRrK1XDPcgIasAsnlpBT5DelIeB9U/Z/JCQsgr7E= +github.com/robertkrimen/otto v0.4.0/go.mod h1:uW9yN1CYflmUQYvAMS0m+ZiNo3dMzRUDQJX0jWbzgxw= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= @@ -952,8 +954,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -1031,8 +1033,8 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= +golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1045,8 +1047,8 @@ golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= +golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1251,6 +1253,8 @@ gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= +gopkg.in/sourcemap.v1 v1.0.5 h1:inv58fC9f9J3TK2Y2R1NPntXEn3/wjWHkonhIUODNTI= +gopkg.in/sourcemap.v1 v1.0.5/go.mod h1:2RlvNNSMglmRrcvhfuzp4hQHwOtjxlbjX7UPY/GXb78= gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= diff --git a/pkg/saga/statemachine/engine/config/default_statemachine_config.go b/pkg/saga/statemachine/engine/config/default_statemachine_config.go index 07e4f953..f2902572 100644 --- a/pkg/saga/statemachine/engine/config/default_statemachine_config.go +++ b/pkg/saga/statemachine/engine/config/default_statemachine_config.go @@ -292,8 +292,17 @@ func (c *DefaultStateMachineConfig) GetExpressionFactory(expressionType string) return c.expressionFactoryManager.GetExpressionFactory(expressionType) } -func (c *DefaultStateMachineConfig) GetServiceInvoker(serviceType string) invoker.ServiceInvoker { - return c.serviceInvokerManager.ServiceInvoker(serviceType) +func (c *DefaultStateMachineConfig) GetServiceInvoker(serviceType string) (invoker.ServiceInvoker, error) { + if serviceType == "" { + serviceType = "local" + } + + invoker := c.serviceInvokerManager.ServiceInvoker(serviceType) + if invoker == nil { + return nil, fmt.Errorf("service invoker not found for type: %s", serviceType) + } + + return invoker, nil } func (c *DefaultStateMachineConfig) RegisterStateMachineDef(resources []string) error { @@ -492,9 +501,20 @@ func (c *DefaultStateMachineConfig) initServiceInvokers() error { c.serviceInvokerManager = invoker.NewServiceInvokerManagerImpl() } - defaultServiceType := "local" - if existingInvoker := c.serviceInvokerManager.ServiceInvoker(defaultServiceType); existingInvoker == nil { - c.RegisterServiceInvoker(defaultServiceType, invoker.NewLocalServiceInvoker()) + if existing := c.serviceInvokerManager.ServiceInvoker("local"); existing == nil { + c.RegisterServiceInvoker("local", invoker.NewLocalServiceInvoker()) + } + + if existing := c.serviceInvokerManager.ServiceInvoker("http"); existing == nil { + c.RegisterServiceInvoker("http", invoker.NewHTTPInvoker()) + } + + if existing := c.serviceInvokerManager.ServiceInvoker("grpc"); existing == nil { + c.RegisterServiceInvoker("grpc", invoker.NewGRPCInvoker()) + } + + if existing := c.serviceInvokerManager.ServiceInvoker("func"); existing == nil { + c.RegisterServiceInvoker("func", invoker.NewFuncInvoker()) } return nil diff --git a/pkg/saga/statemachine/engine/invoker/invoker.go b/pkg/saga/statemachine/engine/invoker/invoker.go index d84798b1..40acfc1f 100644 --- a/pkg/saga/statemachine/engine/invoker/invoker.go +++ b/pkg/saga/statemachine/engine/invoker/invoker.go @@ -20,7 +20,6 @@ package invoker import ( "context" "encoding/json" - "fmt" "reflect" "sync" @@ -43,164 +42,72 @@ func (p *DefaultJsonParser) Marshal(v any) ([]byte, error) { } type ScriptInvokerManager interface { + GetInvoker(scriptType string) (ScriptInvoker, error) + RegisterInvoker(invoker ScriptInvoker) + Execute(ctx context.Context, scriptType string, script string, params map[string]interface{}) (interface{}, error) } type ScriptInvoker interface { -} - -type ServiceInvokerManager interface { - ServiceInvoker(serviceType string) ServiceInvoker - PutServiceInvoker(serviceType string, invoker ServiceInvoker) -} - -type ServiceInvoker interface { - Invoke(ctx context.Context, input []any, service state.ServiceTaskState) (output []reflect.Value, err error) + Invoke(ctx context.Context, script string, params map[string]interface{}) (interface{}, error) + Type() string Close(ctx context.Context) error } -type ServiceInvokerManagerImpl struct { - invokers map[string]ServiceInvoker +type ScriptInvokerManagerImpl struct { + invokers map[string]ScriptInvoker mutex sync.Mutex } -type LocalServiceInvoker struct { - serviceRegistry map[string]interface{} - methodCache map[string]*reflect.Method - jsonParser JsonParser - mutex sync.RWMutex -} - -func NewLocalServiceInvoker() *LocalServiceInvoker { - return &LocalServiceInvoker{ - serviceRegistry: make(map[string]interface{}), - methodCache: make(map[string]*reflect.Method), - jsonParser: &DefaultJsonParser{}, - } -} - -func (l *LocalServiceInvoker) RegisterService(serviceName string, instance interface{}) { - l.mutex.Lock() - defer l.mutex.Unlock() - l.serviceRegistry[serviceName] = instance -} - -func (l *LocalServiceInvoker) Invoke(ctx context.Context, input []any, service state.ServiceTaskState) ([]reflect.Value, error) { - serviceName := service.ServiceName() - instance, exists := l.serviceRegistry[serviceName] - if !exists { - return nil, fmt.Errorf("service %s not registered", serviceName) - } - - methodName := service.ServiceMethod() - method, err := l.getMethod(serviceName, methodName, service.ParameterTypes()) - if err != nil { - return nil, err - } - - params, err := l.resolveParameters(input, method.Type) - if err != nil { - return nil, err +func NewScriptInvokerManager() *ScriptInvokerManagerImpl { + return &ScriptInvokerManagerImpl{ + invokers: make(map[string]ScriptInvoker), } - - return l.invokeMethod(instance, method, params), nil } -func (l *LocalServiceInvoker) resolveMethod(key, serviceName, methodName string) (*reflect.Method, error) { - l.mutex.Lock() - defer l.mutex.Unlock() - - if cachedMethod, ok := l.methodCache[key]; ok { - return cachedMethod, nil +func (m *ScriptInvokerManagerImpl) GetInvoker(scriptType string) (ScriptInvoker, error) { + if scriptType == "" { + return nil, nil } + m.mutex.Lock() + defer m.mutex.Unlock() - instance, exists := l.serviceRegistry[serviceName] + invoker, exists := m.invokers[scriptType] if !exists { - return nil, fmt.Errorf("service %s not found", serviceName) - } - - objType := reflect.TypeOf(instance) - method, ok := objType.MethodByName(methodName) - if !ok { - return nil, fmt.Errorf("method %s not found in service %s", methodName, serviceName) + return nil, nil } - - l.methodCache[key] = &method - return &method, nil + return invoker, nil } -func (l *LocalServiceInvoker) getMethod(serviceName, methodName string, paramTypes []string) (*reflect.Method, error) { - key := fmt.Sprintf("%s.%s", serviceName, methodName) - - l.mutex.RLock() - if method, ok := l.methodCache[key]; ok { - l.mutex.RUnlock() - return method, nil +func (m *ScriptInvokerManagerImpl) RegisterInvoker(invoker ScriptInvoker) { + if invoker == nil || invoker.Type() == "" { + return } - l.mutex.RUnlock() - - return l.resolveMethod(key, serviceName, methodName) + m.mutex.Lock() + defer m.mutex.Unlock() + m.invokers[invoker.Type()] = invoker } -func (l *LocalServiceInvoker) resolveParameters(input []any, methodType reflect.Type) ([]reflect.Value, error) { - params := make([]reflect.Value, methodType.NumIn()) - for i := 0; i < methodType.NumIn(); i++ { - paramType := methodType.In(i) - if i >= len(input) { - params[i] = reflect.Zero(paramType) - continue - } - - converted, err := l.convertParam(input[i], paramType) - if err != nil { - return nil, err - } - params[i] = reflect.ValueOf(converted) +func (m *ScriptInvokerManagerImpl) Execute(ctx context.Context, scriptType string, script string, params map[string]interface{}) (interface{}, error) { + invoker, err := m.GetInvoker(scriptType) + if err != nil || invoker == nil { + return nil, err } - return params, nil + return invoker.Invoke(ctx, script, params) } -func (l *LocalServiceInvoker) convertParam(value any, targetType reflect.Type) (any, error) { - if targetType.Kind() == reflect.Ptr { - targetType = targetType.Elem() - value = reflect.ValueOf(value).Interface() - } - - if targetType.Kind() == reflect.Int && reflect.TypeOf(value).Kind() == reflect.Float64 { - return int(value.(float64)), nil - } else if targetType == reflect.TypeOf("") && reflect.TypeOf(value).Kind() == reflect.Int { - return fmt.Sprintf("%d", value), nil - } - - if targetType.Kind() == reflect.Struct { - jsonData, err := l.jsonParser.Marshal(value) - if err != nil { - return nil, err - } - instance := reflect.New(targetType).Interface() - if err := l.jsonParser.Unmarshal(jsonData, instance); err != nil { - return nil, err - } - return instance, nil - } - - return value, nil +type ServiceInvokerManager interface { + ServiceInvoker(serviceType string) ServiceInvoker + PutServiceInvoker(serviceType string, invoker ServiceInvoker) } -func (l *LocalServiceInvoker) invokeMethod(instance interface{}, method *reflect.Method, params []reflect.Value) []reflect.Value { - instanceValue := reflect.ValueOf(instance) - if method.Func.IsValid() { - allParams := append([]reflect.Value{instanceValue}, params...) - return method.Func.Call(allParams) - } - return nil +type ServiceInvoker interface { + Invoke(ctx context.Context, input []any, service state.ServiceTaskState) (output []reflect.Value, err error) + Close(ctx context.Context) error } -func (l *LocalServiceInvoker) Close(ctx context.Context) error { - l.mutex.Lock() - defer l.mutex.Unlock() - l.serviceRegistry = nil - l.methodCache = nil - return nil +type ServiceInvokerManagerImpl struct { + invokers map[string]ServiceInvoker + mutex sync.Mutex } func NewServiceInvokerManagerImpl() *ServiceInvokerManagerImpl { diff --git a/pkg/saga/statemachine/engine/invoker/javascript_script_invoker.go b/pkg/saga/statemachine/engine/invoker/javascript_script_invoker.go new file mode 100644 index 00000000..7817494d --- /dev/null +++ b/pkg/saga/statemachine/engine/invoker/javascript_script_invoker.go @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package invoker + +import ( + "context" + "fmt" + "sync" + + "github.com/robertkrimen/otto" +) + +const defaultPoolSize = 10 + +type JavaScriptScriptInvoker struct { + mutex sync.Mutex + jsonParser JsonParser + closed bool + vmPool chan *otto.Otto + poolSize int +} + +func NewJavaScriptScriptInvoker() *JavaScriptScriptInvoker { + return &JavaScriptScriptInvoker{ + jsonParser: &DefaultJsonParser{}, + closed: false, + poolSize: defaultPoolSize, + vmPool: make(chan *otto.Otto, defaultPoolSize), + } +} + +func NewJavaScriptScriptInvokerWithPoolSize(poolSize int) *JavaScriptScriptInvoker { + if poolSize <= 0 { + poolSize = defaultPoolSize + } + return &JavaScriptScriptInvoker{ + jsonParser: &DefaultJsonParser{}, + closed: false, + poolSize: poolSize, + vmPool: make(chan *otto.Otto, poolSize), + } +} + +func (j *JavaScriptScriptInvoker) Type() string { + return "javascript" +} + +func (j *JavaScriptScriptInvoker) Invoke(ctx context.Context, script string, params map[string]interface{}) (interface{}, error) { + j.mutex.Lock() + closed := j.closed + j.mutex.Unlock() + + if closed { + return nil, fmt.Errorf("javascript invoker has been closed") + } + + var vm *otto.Otto + select { + case vm = <-j.vmPool: + if err := cleanVMState(vm); err != nil { + vm = otto.New() + } + default: + vm = otto.New() + } + + defer func() { + j.mutex.Lock() + defer j.mutex.Unlock() + if !j.closed { + select { + case j.vmPool <- vm: + default: + // Pool full, discard current instance + } + } + }() + + for key, value := range params { + if err := vm.Set(key, value); err != nil { + return nil, fmt.Errorf("javascript set param %s error: %w", key, err) + } + } + + resultChan := make(chan struct { + val otto.Value + err error + }, 1) + + go func() { + defer func() { + if r := recover(); r != nil { + resultChan <- struct { + val otto.Value + err error + }{otto.UndefinedValue(), fmt.Errorf("javascript engine panic: %v", r)} + } + }() + + val, err := vm.Run(script) + resultChan <- struct { + val otto.Value + err error + }{val, err} + }() + + select { + case <-ctx.Done(): + return nil, fmt.Errorf("javascript execution timeout: %w", ctx.Err()) + case res := <-resultChan: + if res.err != nil { + return nil, fmt.Errorf("javascript execute error: %w", res.err) + } + val, err := res.val.Export() + if err != nil { + return nil, fmt.Errorf("failed to export javascript result: %w", err) + } + return val, nil + } +} + +func (j *JavaScriptScriptInvoker) Close(ctx context.Context) error { + j.mutex.Lock() + defer j.mutex.Unlock() + + if j.closed { + return nil + } + + j.closed = true + close(j.vmPool) + for range j.vmPool { + // Let GC recycle VM resources + } + return nil +} + +func cleanVMState(vm *otto.Otto) error { + _, err := vm.Run(` + for (const prop in global) { + if (!['Object', 'Array', 'Function', 'String', 'Number', 'Boolean', 'JSON', 'Date', 'RegExp'].includes(prop)) { + delete global[prop]; + } + } + `) + return err +} diff --git a/pkg/saga/statemachine/engine/invoker/javascript_script_invoker_test.go b/pkg/saga/statemachine/engine/invoker/javascript_script_invoker_test.go new file mode 100644 index 00000000..e41d1e5c --- /dev/null +++ b/pkg/saga/statemachine/engine/invoker/javascript_script_invoker_test.go @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package invoker + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/robertkrimen/otto" + "github.com/stretchr/testify/assert" +) + +func TestJavaScriptScriptInvoker_Type(t *testing.T) { + invoker := NewJavaScriptScriptInvoker() + assert.Equal(t, "javascript", invoker.Type()) +} + +func TestJavaScriptScriptInvoker_Invoke_Basic(t *testing.T) { + tests := []struct { + name string + script string + params map[string]interface{} + expected interface{} + }{ + { + name: "simple expression", + script: "1 + 2", + params: nil, + expected: float64(3), + }, + { + name: "param calculation", + script: "a * b + c", + params: map[string]interface{}{"a": 2, "b": 3, "c": 4}, + expected: float64(10), + }, + { + name: "return string", + script: "['hello', name].join(' ')", + params: map[string]interface{}{"name": "world"}, + expected: "hello world", + }, + { + name: "return object", + script: `var obj = {id: 1, name: name}; obj;`, + params: map[string]interface{}{"name": "test"}, + expected: map[string]interface{}{"id": float64(1), "name": "test"}, + }, + } + + invoker := NewJavaScriptScriptInvoker() + ctx := context.Background() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := invoker.Invoke(ctx, tt.script, tt.params) + assert.NoError(t, err) + + if resultMap, ok := result.(map[string]interface{}); ok { + for k, v := range resultMap { + if intVal, isInt := v.(int64); isInt { + resultMap[k] = float64(intVal) + } + } + } + + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestJavaScriptScriptInvoker_Invoke_Error(t *testing.T) { + tests := []struct { + name string + script string + params map[string]interface{} + errMsg string + }{ + { + name: "syntax error", + script: "1 + ", + params: nil, + errMsg: "javascript execute error", + }, + { + name: "reference undefined variable", + script: "undefinedVar", + params: nil, + errMsg: "javascript execute error", + }, + } + + invoker := NewJavaScriptScriptInvoker() + ctx := context.Background() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := invoker.Invoke(ctx, tt.script, tt.params) + + if err == nil { + t.Fatalf("Test case [%s] expected error but got none", tt.name) + } + assert.Contains(t, err.Error(), tt.errMsg, "Test case [%s] error message mismatch", tt.name) + }) + } +} + +func TestJavaScriptScriptInvoker_Invoke_Timeout(t *testing.T) { + + script := `var target = 300; var start = new Date().getTime(); var elapsed = 0; while (elapsed < target) { elapsed = new Date().getTime() - start; } "done";` + invoker := NewJavaScriptScriptInvoker() + + ctx1, cancel1 := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel1() + _, err := invoker.Invoke(ctx1, script, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "javascript execution timeout") + + ctx2, cancel2 := context.WithTimeout(context.Background(), 400*time.Millisecond) + defer cancel2() + result, err := invoker.Invoke(ctx2, script, nil) + assert.NoError(t, err, "Scenario 2: script execution should not return error") + assert.Equal(t, "done", result, "Scenario 2: should return 'done'") +} + +func TestJavaScriptScriptInvoker_Invoke_Concurrent(t *testing.T) { + invoker := NewJavaScriptScriptInvoker() + ctx := context.Background() + var wg sync.WaitGroup + concurrency := 100 + errChan := make(chan error, concurrency) + + script := `a + b` + params := map[string]interface{}{"a": 10, "b": 20} + + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + result, err := invoker.Invoke(ctx, script, params) + if err != nil { + errChan <- err + return + } + if result != float64(30) { + errChan <- assert.AnError + } + }() + } + + wg.Wait() + close(errChan) + + assert.Empty(t, errChan, "Concurrent execution has errors") +} + +func TestJavaScriptScriptInvoker_Close(t *testing.T) { + invoker := NewJavaScriptScriptInvoker() + ctx := context.Background() + + result, err := invoker.Invoke(ctx, "1 + 1", nil) + assert.NoError(t, err) + assert.Equal(t, float64(2), result) + + err = invoker.Close(ctx) + assert.NoError(t, err) + + _, err = invoker.Invoke(ctx, "1 + 1", nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "javascript invoker has been closed") +} + +func TestOttoScript(t *testing.T) { + vm := otto.New() + script := `var target = 300; var start = new Date().getTime(); var elapsed = 0; while (elapsed < target) { elapsed = new Date().getTime() - start; } "done";` + val, err := vm.Run(script) + if err != nil { + t.Fatalf("otto failed to parse script: %v", err) + } + + result, exportErr := val.Export() + if exportErr != nil { + t.Fatalf("failed to export otto value: %v", exportErr) + } + t.Logf("Script execution result: %v", result) +} + +func TestJavaScriptScriptInvoker_VMPoolReuse(t *testing.T) { + poolSize := 2 + invoker := NewJavaScriptScriptInvokerWithPoolSize(poolSize) + ctx := context.Background() + + vmIDs := make([]string, 0, 5) + + script := ` + if (!this.vmId) { + this.vmId = Math.random().toString(36).substr(2, 8); + } + this.vmId; + ` + + for i := 0; i < 5; i++ { + result, err := invoker.Invoke(ctx, script, nil) + assert.NoError(t, err, "Error occurred while executing script") + + id, ok := result.(string) + assert.True(t, ok, "VM ID should be a string type") + vmIDs = append(vmIDs, id) + } + + uniqueIDs := make(map[string]bool) + for _, id := range vmIDs { + uniqueIDs[id] = true + } + + assert.True(t, len(uniqueIDs) <= 5, "Abnormal number of VM instances created") + assert.True(t, len(uniqueIDs) >= 1, "No VM instances reused from the pool") +} + +func TestJavaScriptScriptInvoker_VMStateClean(t *testing.T) { + invoker := NewJavaScriptScriptInvokerWithPoolSize(1) + ctx := context.Background() + + _, err := invoker.Invoke(ctx, `this.foo = "polluted data"`, nil) + assert.NoError(t, err) + + result, err := invoker.Invoke(ctx, `typeof this.foo`, nil) + assert.NoError(t, err) + assert.Equal(t, "undefined", result, "VM state not cleaned, residual global variable exists") + + _, err = invoker.Invoke(ctx, `this.bar = function() { return "residual function"; }`, nil) + assert.NoError(t, err) + + result, err = invoker.Invoke(ctx, `typeof this.bar`, nil) + assert.NoError(t, err) + assert.Equal(t, "undefined", result, "VM state not cleaned, residual function exists") +} + +func TestJavaScriptScriptInvoker_PoolSizeDefault(t *testing.T) { + invoker := NewJavaScriptScriptInvokerWithPoolSize(0) + assert.Equal(t, defaultPoolSize, invoker.poolSize, "Default pool size not used when pool size is 0") + + invoker = NewJavaScriptScriptInvokerWithPoolSize(-5) + assert.Equal(t, defaultPoolSize, invoker.poolSize, "Default pool size not used when pool size is negative") +} diff --git a/pkg/saga/statemachine/engine/invoker/invoker.go b/pkg/saga/statemachine/engine/invoker/local_invoker.go similarity index 74% copy from pkg/saga/statemachine/engine/invoker/invoker.go copy to pkg/saga/statemachine/engine/invoker/local_invoker.go index d84798b1..14c32241 100644 --- a/pkg/saga/statemachine/engine/invoker/invoker.go +++ b/pkg/saga/statemachine/engine/invoker/local_invoker.go @@ -19,50 +19,12 @@ package invoker import ( "context" - "encoding/json" "fmt" + "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" "reflect" "sync" - - "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" ) -type JsonParser interface { - Unmarshal(data []byte, v any) error - Marshal(v any) ([]byte, error) -} - -type DefaultJsonParser struct{} - -func (p *DefaultJsonParser) Unmarshal(data []byte, v any) error { - return json.Unmarshal(data, v) -} - -func (p *DefaultJsonParser) Marshal(v any) ([]byte, error) { - return json.Marshal(v) -} - -type ScriptInvokerManager interface { -} - -type ScriptInvoker interface { -} - -type ServiceInvokerManager interface { - ServiceInvoker(serviceType string) ServiceInvoker - PutServiceInvoker(serviceType string, invoker ServiceInvoker) -} - -type ServiceInvoker interface { - Invoke(ctx context.Context, input []any, service state.ServiceTaskState) (output []reflect.Value, err error) - Close(ctx context.Context) error -} - -type ServiceInvokerManagerImpl struct { - invokers map[string]ServiceInvoker - mutex sync.Mutex -} - type LocalServiceInvoker struct { serviceRegistry map[string]interface{} methodCache map[string]*reflect.Method @@ -142,45 +104,74 @@ func (l *LocalServiceInvoker) getMethod(serviceName, methodName string, paramTyp } func (l *LocalServiceInvoker) resolveParameters(input []any, methodType reflect.Type) ([]reflect.Value, error) { - params := make([]reflect.Value, methodType.NumIn()) - for i := 0; i < methodType.NumIn(); i++ { - paramType := methodType.In(i) - if i >= len(input) { - params[i] = reflect.Zero(paramType) - continue + numIn := methodType.NumIn() + paramStart, paramCount := 1, 0 + + if numIn > 0 { + paramCount = numIn - paramStart + } + + if paramCount == 0 { + if len(input) > 0 { + return nil, fmt.Errorf("unexpected parameters: expected 0, got %d", len(input)) } + return []reflect.Value{}, nil + } + + if len(input) < paramCount { + return nil, fmt.Errorf("insufficient parameters: expected %d, got %d", paramCount, len(input)) + } + + if len(input) > paramCount { + return nil, fmt.Errorf("too many parameters: expected %d, got %d", paramCount, len(input)) + } + + params := make([]reflect.Value, paramCount) + for i := 0; i < paramCount; i++ { + methodParamIndex := i + paramStart + paramType := methodType.In(methodParamIndex) converted, err := l.convertParam(input[i], paramType) if err != nil { - return nil, err + return nil, fmt.Errorf("parameter %d conversion error: %w", i, err) } + params[i] = reflect.ValueOf(converted) } + return params, nil } func (l *LocalServiceInvoker) convertParam(value any, targetType reflect.Type) (any, error) { if targetType.Kind() == reflect.Ptr { - targetType = targetType.Elem() - value = reflect.ValueOf(value).Interface() - } - - if targetType.Kind() == reflect.Int && reflect.TypeOf(value).Kind() == reflect.Float64 { - return int(value.(float64)), nil - } else if targetType == reflect.TypeOf("") && reflect.TypeOf(value).Kind() == reflect.Int { - return fmt.Sprintf("%d", value), nil + elemType := targetType.Elem() + instance := reflect.New(elemType).Interface() + jsonData, err := l.jsonParser.Marshal(value) + if err != nil { + return nil, err + } + if err := l.jsonParser.Unmarshal(jsonData, instance); err != nil { + return nil, err + } + return instance, nil } if targetType.Kind() == reflect.Struct { + instance := reflect.New(targetType).Interface() jsonData, err := l.jsonParser.Marshal(value) if err != nil { return nil, err } - instance := reflect.New(targetType).Interface() if err := l.jsonParser.Unmarshal(jsonData, instance); err != nil { return nil, err } - return instance, nil + return reflect.ValueOf(instance).Elem().Interface(), nil + } + + if targetType.Kind() == reflect.Int && reflect.TypeOf(value).Kind() == reflect.Float64 { + return int(value.(float64)), nil + } else if targetType == reflect.TypeOf("") && reflect.TypeOf(value).Kind() == reflect.Int { + return fmt.Sprintf("%d", value), nil } return value, nil @@ -202,19 +193,3 @@ func (l *LocalServiceInvoker) Close(ctx context.Context) error { l.methodCache = nil return nil } - -func NewServiceInvokerManagerImpl() *ServiceInvokerManagerImpl { - return &ServiceInvokerManagerImpl{ - invokers: make(map[string]ServiceInvoker), - } -} - -func (manager *ServiceInvokerManagerImpl) ServiceInvoker(serviceType string) ServiceInvoker { - return manager.invokers[serviceType] -} - -func (manager *ServiceInvokerManagerImpl) PutServiceInvoker(serviceType string, invoker ServiceInvoker) { - manager.mutex.Lock() - defer manager.mutex.Unlock() - manager.invokers[serviceType] = invoker -} diff --git a/pkg/saga/statemachine/engine/invoker/local_invoker_test.go b/pkg/saga/statemachine/engine/invoker/local_invoker_test.go new file mode 100644 index 00000000..b1ed6ee9 --- /dev/null +++ b/pkg/saga/statemachine/engine/invoker/local_invoker_test.go @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package invoker + +import ( + "context" + "errors" + "fmt" + "reflect" + "testing" + + "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" +) + +type MockLocalService struct { + invokeCount int +} + +func (m *MockLocalService) GetServiceName() string { + return "MockLocalService" +} + +func (m *MockLocalService) Add(a, b int) int { + m.invokeCount++ + return a + b +} + +func (m *MockLocalService) Multiply(f float64, i int) float64 { + m.invokeCount++ + return f * float64(i) +} + +type User struct { + Name string `json:"name"` + Age int `json:"age"` +} + +func (m *MockLocalService) GetUserName(user User) string { + m.invokeCount++ + return user.Name +} + +func (m *MockLocalService) ErrorMethod() error { + return errors.New("expected error") +} + +func TestLocalInvoker_ServiceNotRegistered(t *testing.T) { + invoker := NewLocalServiceInvoker() + ctx := context.Background() + taskState := newLocalServiceTaskState("unregisteredService", "AnyMethod") + + _, err := invoker.Invoke(ctx, []any{}, taskState) + if err == nil { + t.Error("expected error when service not registered, but got nil") + } + if err.Error() != "service unregisteredService not registered" { + t.Errorf("unexpected error message: %v", err) + } +} + +func TestLocalInvoker_MethodNotFound(t *testing.T) { + invoker := NewLocalServiceInvoker() + service := &MockLocalService{} + invoker.RegisterService("mockService", service) + + ctx := context.Background() + taskState := newLocalServiceTaskState("mockService", "NonExistentMethod") + + _, err := invoker.Invoke(ctx, []any{}, taskState) + if err == nil { + t.Error("expected error when method not found, but got nil") + } + if err.Error() != "method NonExistentMethod not found in service mockService" { + t.Errorf("unexpected error message: %v", err) + } +} + +func TestLocalInvoker_InvokeSuccess(t *testing.T) { + tests := []struct { + name string + service interface{} + serviceName string + methodName string + input []any + expected interface{} + }{ + { + name: "test basic method call", + service: &MockLocalService{}, + serviceName: "mockService", + methodName: "GetServiceName", + input: []any{}, + expected: "MockLocalService", + }, + { + name: "test method with parameters", + service: &MockLocalService{}, + serviceName: "mockService", + methodName: "Add", + input: []any{2, 3}, + expected: 5, + }, + { + name: "test parameter type conversion", + service: &MockLocalService{}, + serviceName: "mockService", + methodName: "Multiply", + input: []any{2.5, 4}, + expected: 10.0, + }, + } + + invoker := NewLocalServiceInvoker() + ctx := context.Background() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + invoker.RegisterService(tt.serviceName, tt.service) + taskState := newLocalServiceTaskState(tt.serviceName, tt.methodName) + + results, err := invoker.Invoke(ctx, tt.input, taskState) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(results) == 0 { + t.Fatal("no results returned") + } + + result := results[0].Interface() + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestLocalInvoker_StructParameterConversion(t *testing.T) { + invoker := NewLocalServiceInvoker() + service := &MockLocalService{} + invoker.RegisterService("userService", service) + + ctx := context.Background() + taskState := newLocalServiceTaskState("userService", "GetUserName") + + input := []any{map[string]interface{}{"name": "Alice", "age": 30}} + results, err := invoker.Invoke(ctx, input, taskState) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(results) == 0 { + t.Fatal("no results returned") + } + + result := results[0].Interface() + if result != "Alice" { + t.Errorf("expected 'Alice', got %v", result) + } +} + +func TestLocalInvoker_MethodCaching(t *testing.T) { + invoker := NewLocalServiceInvoker() + service := &MockLocalService{} + invoker.RegisterService("cacheTestService", service) + + ctx := context.Background() + taskState := newLocalServiceTaskState("cacheTestService", "Add") + + _, err := invoker.Invoke(ctx, []any{1, 1}, taskState) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + results, err := invoker.Invoke(ctx, []any{2, 3}, taskState) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if results[0].Interface() != 5 { + t.Errorf("expected 5, got %v", results[0].Interface()) + } + + if service.invokeCount != 2 { + t.Errorf("expected 2 invocations, got %d", service.invokeCount) + } +} + +func newLocalServiceTaskState(serviceName, methodName string) state.ServiceTaskState { + serviceTaskStateImpl := state.NewServiceTaskStateImpl() + serviceTaskStateImpl.SetName(fmt.Sprintf("%s_%s", serviceName, methodName)) + serviceTaskStateImpl.SetIsAsync(false) + serviceTaskStateImpl.SetServiceName(serviceName) + serviceTaskStateImpl.SetServiceType("local") + serviceTaskStateImpl.SetServiceMethod(methodName) + return serviceTaskStateImpl +} --------------------------------------------------------------------- To unsubscribe, e-mail: notifications-unsubscr...@seata.apache.org For additional commands, e-mail: notifications-h...@seata.apache.org