This is an automated email from the ASF dual-hosted git repository.

luky116 pushed a commit to branch feature/saga
in repository https://gitbox.apache.org/repos/asf/incubator-seata-go.git


The following commit(s) were added to refs/heads/feature/saga by this push:
     new 1385f9e8 feat: add func invoker (#744)
1385f9e8 is described below

commit 1385f9e856123857e9c01e81b57d7f7bdccabc41
Author: marsevilspirit <marsevilspi...@gmail.com>
AuthorDate: Sat Dec 21 15:38:04 2024 +0800

    feat: add func invoker (#744)
    
    * add func task state
---
 .../statemachine/engine/invoker/func_invoker.go    | 208 +++++++++++++++++++++
 .../engine/invoker/func_invoker_test.go            | 151 +++++++++++++++
 2 files changed, 359 insertions(+)

diff --git a/pkg/saga/statemachine/engine/invoker/func_invoker.go b/pkg/saga/statemachine/engine/invoker/func_invoker.go
new file mode 100644
index 00000000..085decb8
--- /dev/null
+++ b/pkg/saga/statemachine/engine/invoker/func_invoker.go
@@ -0,0 +1,208 @@
+package invoker
+
+import (
+       "context"
+       "errors"
+       "fmt"
+       "reflect"
+       "strings"
+       "sync"
+       "time"
+
+       "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
+       "github.com/seata/seata-go/pkg/util/log"
+)
+
// FuncInvoker invokes saga service tasks whose services are Go values
// registered in-process under a service name (see RegisterService).
type FuncInvoker struct {
	// ServicesMapLock guards servicesMap.
	ServicesMapLock sync.Mutex
	// servicesMap maps a service name to its registered FuncService.
	servicesMap     map[string]FuncService
}
+
+func NewFuncInvoker() *FuncInvoker {
+       return &FuncInvoker{
+               servicesMap: make(map[string]FuncService),
+       }
+}
+
+func (f *FuncInvoker) RegisterService(serviceName string, service FuncService) 
{
+       f.ServicesMapLock.Lock()
+       defer f.ServicesMapLock.Unlock()
+       f.servicesMap[serviceName] = service
+}
+
+func (f *FuncInvoker) GetService(serviceName string) FuncService {
+       f.ServicesMapLock.Lock()
+       defer f.ServicesMapLock.Unlock()
+       return f.servicesMap[serviceName]
+}
+
+func (f *FuncInvoker) Invoke(ctx context.Context, input []any, service 
state.ServiceTaskState) (output []reflect.Value, err error) {
+       serviceTaskStateImpl := service.(*state.ServiceTaskStateImpl)
+       FuncService := f.GetService(serviceTaskStateImpl.ServiceName())
+       if FuncService == nil {
+               return nil, errors.New("no func service " + 
serviceTaskStateImpl.ServiceName() + " for service task state")
+       }
+
+       if serviceTaskStateImpl.IsAsync() {
+               go func() {
+                       _, err := FuncService.CallMethod(serviceTaskStateImpl, 
input)
+                       if err != nil {
+                               log.Errorf("invoke Service[%s].%s failed, err 
is %s", serviceTaskStateImpl.ServiceName(), 
serviceTaskStateImpl.ServiceMethod(), err.Error())
+                       }
+               }()
+               return nil, nil
+       }
+
+       return FuncService.CallMethod(serviceTaskStateImpl, input)
+}
+
+func (f *FuncInvoker) Close(ctx context.Context) error {
+       return nil
+}
+
+type FuncService interface {
+       CallMethod(ServiceTaskStateImpl *state.ServiceTaskStateImpl, input 
[]any) ([]reflect.Value, error)
+}
+
// FuncServiceImpl is a reflection-based FuncService. method is either a Go
// func, which is called directly, or a value whose method named by the task
// state's ServiceMethod is called (see initMethod).
type FuncServiceImpl struct {
	// serviceName is used in the error messages produced by this type.
	serviceName string
	// methodLock is held while initMethod resolves the method to call.
	methodLock  sync.Mutex
	// method is the func or receiver value supplied to NewFuncService.
	method      any
}
+
+func NewFuncService(serviceName string, method any) *FuncServiceImpl {
+       return &FuncServiceImpl{
+               serviceName: serviceName,
+               method:      method,
+       }
+}
+
+func (f *FuncServiceImpl) getMethod(serviceTaskStateImpl 
*state.ServiceTaskStateImpl) (*reflect.Value, error) {
+       method := serviceTaskStateImpl.Method()
+       if method == nil {
+               return f.initMethod(serviceTaskStateImpl)
+       }
+       return method, nil
+}
+
+func (f *FuncServiceImpl) prepareArguments(input []any) []reflect.Value {
+       args := make([]reflect.Value, len(input))
+       for i, arg := range input {
+               args[i] = reflect.ValueOf(arg)
+       }
+       return args
+}
+
+func (f *FuncServiceImpl) CallMethod(serviceTaskStateImpl 
*state.ServiceTaskStateImpl, input []any) ([]reflect.Value, error) {
+       method, err := f.getMethod(serviceTaskStateImpl)
+       if err != nil {
+               return nil, err
+       }
+
+       args := f.prepareArguments(input)
+
+       retryCountMap := make(map[state.Retry]int)
+       for {
+               res, err, shouldRetry := f.invokeMethod(method, args, 
serviceTaskStateImpl, retryCountMap)
+
+               if !shouldRetry {
+                       if err != nil {
+                               return nil, errors.New("invoke service[" + 
serviceTaskStateImpl.ServiceName() + "]." + 
serviceTaskStateImpl.ServiceMethod() + " failed, err is " + err.Error())
+                       }
+                       return res, nil
+               }
+       }
+}
+
+func (f *FuncServiceImpl) initMethod(serviceTaskStateImpl 
*state.ServiceTaskStateImpl) (*reflect.Value, error) {
+       methodName := serviceTaskStateImpl.ServiceMethod()
+       f.methodLock.Lock()
+       defer f.methodLock.Unlock()
+       methodValue := reflect.ValueOf(f.method)
+       if methodValue.IsZero() {
+               return nil, errors.New("invalid method when func call, 
serviceName: " + f.serviceName)
+       }
+
+       if methodValue.Kind() == reflect.Func {
+               serviceTaskStateImpl.SetMethod(&methodValue)
+               return &methodValue, nil
+       }
+
+       method := methodValue.MethodByName(methodName)
+       if method.IsZero() {
+               return nil, errors.New("invalid method name when func call, 
serviceName: " + f.serviceName + ", methodName: " + methodName)
+       }
+       serviceTaskStateImpl.SetMethod(&method)
+       return &method, nil
+}
+
+func (f *FuncServiceImpl) invokeMethod(method *reflect.Value, args 
[]reflect.Value, serviceTaskStateImpl *state.ServiceTaskStateImpl, 
retryCountMap map[state.Retry]int) ([]reflect.Value, error, bool) {
+       var res []reflect.Value
+       var resErr error
+       var shouldRetry bool
+
+       defer func() {
+               if r := recover(); r != nil {
+                       errStr := fmt.Sprintf("%v", r)
+                       retry := f.matchRetry(serviceTaskStateImpl, errStr)
+                       resErr = errors.New(errStr)
+                       if retry != nil {
+                               shouldRetry = f.needRetry(serviceTaskStateImpl, 
retryCountMap, retry, resErr)
+                       }
+               }
+       }()
+
+       outs := method.Call(args)
+       if err, ok := outs[len(outs)-1].Interface().(error); ok {
+               resErr = err
+               errStr := err.Error()
+               retry := f.matchRetry(serviceTaskStateImpl, errStr)
+               if retry != nil {
+                       shouldRetry = f.needRetry(serviceTaskStateImpl, 
retryCountMap, retry, resErr)
+               }
+               return nil, resErr, shouldRetry
+       }
+
+       res = outs
+       return res, nil, false
+}
+
+func (f *FuncServiceImpl) matchRetry(impl *state.ServiceTaskStateImpl, str 
string) state.Retry {
+       if impl.Retry() != nil {
+               for _, retry := range impl.Retry() {
+                       if retry.Exceptions() != nil {
+                               for _, exception := range retry.Exceptions() {
+                                       if strings.Contains(str, exception) {
+                                               return retry
+                                       }
+                               }
+                       }
+               }
+       }
+       return nil
+}
+
+func (f *FuncServiceImpl) needRetry(impl *state.ServiceTaskStateImpl, countMap 
map[state.Retry]int, retry state.Retry, err error) bool {
+       attempt, exist := countMap[retry]
+       if !exist {
+               countMap[retry] = 0
+       }
+
+       if attempt >= retry.MaxAttempt() {
+               return false
+       }
+
+       interval := retry.IntervalSecond()
+       backoffRate := retry.BackoffRate()
+       curInterval := int64(interval * 1000)
+       if attempt != 0 {
+               curInterval = int64(interval * backoffRate * float64(attempt) * 
1000)
+       }
+
+       log.Warnf("invoke service[%s.%s] failed, will retry after %s millis, 
current retry count: %s, current err: %s",
+               impl.ServiceName(), impl.ServiceMethod(), curInterval, attempt, 
err)
+
+       time.Sleep(time.Duration(curInterval) * time.Millisecond)
+       countMap[retry] = attempt + 1
+       return true
+}
diff --git a/pkg/saga/statemachine/engine/invoker/func_invoker_test.go b/pkg/saga/statemachine/engine/invoker/func_invoker_test.go
new file mode 100644
index 00000000..e800fdc1
--- /dev/null
+++ b/pkg/saga/statemachine/engine/invoker/func_invoker_test.go
@@ -0,0 +1,151 @@
+package invoker
+
+import (
+       "context"
+       "errors"
+       "fmt"
+       "testing"
+
+       "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
+)
+
// mockFuncImpl is a test service whose exported methods are invoked by name
// through FuncServiceImpl. invokeCount counts calls across all of its methods.
type mockFuncImpl struct {
	invokeCount int
}

// record notes one more invocation and returns the updated total.
func (m *mockFuncImpl) record() int {
	m.invokeCount++
	return m.invokeCount
}

// SayHelloRight always succeeds and echoes word back.
func (m *mockFuncImpl) SayHelloRight(word string) (string, error) {
	m.record()
	fmt.Println("invoke right")
	return word, nil
}

// SayHelloRightLater succeeds (echoing word) only on the call whose running
// invocation count equals delay; every other call fails with "invoke failed".
func (m *mockFuncImpl) SayHelloRightLater(word string, delay int) (string, error) {
	if m.record() != delay {
		fmt.Println("invoke fail")
		return "", errors.New("invoke failed")
	}
	fmt.Println("invoke right")
	return word, nil
}
+
+func TestFuncInvokerInvokeSucceed(t *testing.T) {
+       tests := []struct {
+               name      string
+               input     []any
+               taskState state.ServiceTaskState
+               expected  string
+               expectErr bool
+       }{
+               {
+                       name:      "Invoke Struct Succeed",
+                       input:     []any{"hello"},
+                       taskState: newFuncHelloServiceTaskState(),
+                       expected:  "hello",
+                       expectErr: false,
+               },
+               {
+                       name:      "Invoke Struct In Retry",
+                       input:     []any{"hello", 2},
+                       taskState: newFuncHelloServiceTaskStateWithRetry(),
+                       expected:  "hello",
+                       expectErr: false,
+               },
+       }
+
+       ctx := context.Background()
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       invoker := newFuncServiceInvoker()
+                       values, err := invoker.Invoke(ctx, tt.input, 
tt.taskState)
+
+                       if (err != nil) != tt.expectErr {
+                               t.Errorf("expected error: %v, got: %v", 
tt.expectErr, err)
+                       }
+
+                       if values == nil || len(values) == 0 {
+                               t.Fatal("no value in values")
+                       }
+
+                       if resultString, ok := values[0].Interface().(string); 
ok {
+                               if resultString != tt.expected {
+                                       t.Errorf("expect %s, but got %s", 
tt.expected, resultString)
+                               }
+                       } else {
+                               t.Errorf("expected string, but got %v", 
values[0].Interface())
+                       }
+
+                       if resultError, ok := values[1].Interface().(error); ok 
{
+                               if resultError != nil {
+                                       t.Errorf("expect nil, but got %s", 
resultError)
+                               }
+                       }
+               })
+       }
+}
+
+func TestFuncInvokerInvokeFailed(t *testing.T) {
+       tests := []struct {
+               name      string
+               input     []any
+               taskState state.ServiceTaskState
+               expected  string
+               expectErr bool
+       }{
+               {
+                       name:      "Invoke Struct Failed In Retry",
+                       input:     []any{"hello", 5},
+                       taskState: newFuncHelloServiceTaskStateWithRetry(),
+                       expected:  "",
+                       expectErr: true,
+               },
+       }
+
+       ctx := context.Background()
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       invoker := newFuncServiceInvoker()
+                       _, err := invoker.Invoke(ctx, tt.input, tt.taskState)
+
+                       if (err != nil) != tt.expectErr {
+                               t.Errorf("expected error: %v, got: %v", 
tt.expectErr, err)
+                       }
+               })
+       }
+}
+
+func newFuncServiceInvoker() ServiceInvoker {
+       mockFuncInvoker := NewFuncInvoker()
+       mockFuncService := &mockFuncImpl{}
+       mockService := NewFuncService("hello", mockFuncService)
+       mockFuncInvoker.RegisterService("hello", mockService)
+       return mockFuncInvoker
+}
+
+func newFuncHelloServiceTaskState() state.ServiceTaskState {
+       serviceTaskStateImpl := state.NewServiceTaskStateImpl()
+       serviceTaskStateImpl.SetName("hello")
+       serviceTaskStateImpl.SetIsAsync(false)
+       serviceTaskStateImpl.SetServiceName("hello")
+       serviceTaskStateImpl.SetServiceType("func")
+       serviceTaskStateImpl.SetServiceMethod("SayHelloRight")
+       return serviceTaskStateImpl
+}
+
+func newFuncHelloServiceTaskStateWithRetry() state.ServiceTaskState {
+       serviceTaskStateImpl := state.NewServiceTaskStateImpl()
+       serviceTaskStateImpl.SetName("hello")
+       serviceTaskStateImpl.SetIsAsync(false)
+       serviceTaskStateImpl.SetServiceName("hello")
+       serviceTaskStateImpl.SetServiceType("func")
+       serviceTaskStateImpl.SetServiceMethod("SayHelloRightLater")
+
+       retryImpl := &state.RetryImpl{}
+       retryImpl.SetExceptions([]string{"fail"})
+       retryImpl.SetIntervalSecond(1)
+       retryImpl.SetMaxAttempt(3)
+       retryImpl.SetBackoffRate(0.9)
+       serviceTaskStateImpl.SetRetry([]state.Retry{retryImpl})
+       return serviceTaskStateImpl
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: notifications-unsubscr...@seata.apache.org
For additional commands, e-mail: notifications-h...@seata.apache.org

Reply via email to