This is an automated email from the ASF dual-hosted git repository. xjlgod pushed a commit to branch feature/saga in repository https://gitbox.apache.org/repos/asf/incubator-seata-go.git
The following commit(s) were added to refs/heads/feature/saga by this push: new 689c5d6f Feature: Database persistence for seata-go Saga state machine (#794) 689c5d6f is described below commit 689c5d6f7ca8cf99655b1eb18607f55caae84790 Author: lxfeng1997 <33981743+lxfeng1...@users.noreply.github.com> AuthorDate: Sat Mar 29 22:27:15 2025 +0800 Feature: Database persistence for seata-go Saga state machine (#794) * implement state machine repository * temporary recording * temporary recording * temporary storage * Supplementary test --- pkg/saga/statemachine/constant/constant.go | 3 + .../engine/core/default_statemachine_config.go | 66 +++- .../statemachine/engine/core/statemachine_store.go | 2 + .../statelang/statemachine_instance.go | 12 + pkg/saga/statemachine/store/db/statelang.go | 2 +- pkg/saga/statemachine/store/db/statelog.go | 354 ++++++++++++++++++++- pkg/saga/statemachine/store/db/statelog_test.go | 7 + .../store/repository/state_log_repository.go | 18 ++ .../store/repository/state_machine_repository.go | 219 ++++++++++++- .../repository/state_machine_repository_test.go | 118 +++++++ pkg/tm/global_transaction.go | 25 ++ 11 files changed, 784 insertions(+), 42 deletions(-) diff --git a/pkg/saga/statemachine/constant/constant.go b/pkg/saga/statemachine/constant/constant.go index 72af09e9..6032f3a4 100644 --- a/pkg/saga/statemachine/constant/constant.go +++ b/pkg/saga/statemachine/constant/constant.go @@ -67,6 +67,7 @@ const ( // TODO: this lock in process context only has one, try to add more to add concurrent VarNameProcessContextMutexLock string = "_current_context_mutex_lock" VarNameFailEndStateFlag string = "_fail_end_state_flag_" + VarNameGlobalTx string = "_global_transaction_" // end region // region of loop @@ -79,10 +80,12 @@ const ( // end region // region others + SeqEntityStateMachine string = "STATE_MACHINE" SeqEntityStateMachineInst string = "STATE_MACHINE_INST" SeqEntityStateInst string = "STATE_INST" OperationNameForward string = "forward" 
LoopStateNamePattern string = "-loop-" + SagaTransNamePrefix string = "$Saga_" // end region SeperatorParentId string = ":" diff --git a/pkg/saga/statemachine/engine/core/default_statemachine_config.go b/pkg/saga/statemachine/engine/core/default_statemachine_config.go index fa2ef157..cd118f55 100644 --- a/pkg/saga/statemachine/engine/core/default_statemachine_config.go +++ b/pkg/saga/statemachine/engine/core/default_statemachine_config.go @@ -25,16 +25,24 @@ import ( ) const ( - DefaultTransOperTimeout = 60000 * 30 - DefaultServiceInvokeTimeout = 60000 * 5 + DefaultTransOperTimeout = 60000 * 30 + DefaultServiceInvokeTimeout = 60000 * 5 + DefaultClientSagaRetryPersistModeUpdate = false + DefaultClientSagaCompensatePersistModeUpdate = false + DefaultClientReportSuccessEnable = false + DefaultClientSagaBranchRegisterEnable = true ) type DefaultStateMachineConfig struct { // Configuration - transOperationTimeout int - serviceInvokeTimeout int - charset string - defaultTenantId string + transOperationTimeout int + serviceInvokeTimeout int + charset string + defaultTenantId string + sagaRetryPersistModeUpdate bool + sagaCompensatePersistModeUpdate bool + sagaBranchRegisterEnable bool + rmReportSuccessEnable bool // Components @@ -202,13 +210,49 @@ func (c *DefaultStateMachineConfig) ServiceInvokeTimeout() int { return c.serviceInvokeTimeout } +func (c *DefaultStateMachineConfig) IsSagaRetryPersistModeUpdate() bool { + return c.sagaRetryPersistModeUpdate +} + +func (c *DefaultStateMachineConfig) SetSagaRetryPersistModeUpdate(sagaRetryPersistModeUpdate bool) { + c.sagaRetryPersistModeUpdate = sagaRetryPersistModeUpdate +} + +func (c *DefaultStateMachineConfig) IsSagaCompensatePersistModeUpdate() bool { + return c.sagaCompensatePersistModeUpdate +} + +func (c *DefaultStateMachineConfig) SetSagaCompensatePersistModeUpdate(sagaCompensatePersistModeUpdate bool) { + c.sagaCompensatePersistModeUpdate = sagaCompensatePersistModeUpdate +} + +func (c *DefaultStateMachineConfig) 
IsSagaBranchRegisterEnable() bool { + return c.sagaBranchRegisterEnable +} + +func (c *DefaultStateMachineConfig) SetSagaBranchRegisterEnable(sagaBranchRegisterEnable bool) { + c.sagaBranchRegisterEnable = sagaBranchRegisterEnable +} + +func (c *DefaultStateMachineConfig) IsRmReportSuccessEnable() bool { + return c.rmReportSuccessEnable +} + +func (c *DefaultStateMachineConfig) SetRmReportSuccessEnable(rmReportSuccessEnable bool) { + c.rmReportSuccessEnable = rmReportSuccessEnable +} + func NewDefaultStateMachineConfig() *DefaultStateMachineConfig { c := &DefaultStateMachineConfig{ - transOperationTimeout: DefaultTransOperTimeout, - serviceInvokeTimeout: DefaultServiceInvokeTimeout, - charset: "UTF-8", - defaultTenantId: "000001", - componentLock: &sync.Mutex{}, + transOperationTimeout: DefaultTransOperTimeout, + serviceInvokeTimeout: DefaultServiceInvokeTimeout, + charset: "UTF-8", + defaultTenantId: "000001", + sagaRetryPersistModeUpdate: DefaultClientSagaRetryPersistModeUpdate, + sagaCompensatePersistModeUpdate: DefaultClientSagaCompensatePersistModeUpdate, + sagaBranchRegisterEnable: DefaultClientSagaBranchRegisterEnable, + rmReportSuccessEnable: DefaultClientReportSuccessEnable, + componentLock: &sync.Mutex{}, } // TODO: init config diff --git a/pkg/saga/statemachine/engine/core/statemachine_store.go b/pkg/saga/statemachine/engine/core/statemachine_store.go index f47065dd..7e0a9e94 100644 --- a/pkg/saga/statemachine/engine/core/statemachine_store.go +++ b/pkg/saga/statemachine/engine/core/statemachine_store.go @@ -55,6 +55,8 @@ type StateLogStore interface { GetStateInstance(stateInstanceId string, stateMachineInstanceId string) (statelang.StateInstance, error) GetStateInstanceListByMachineInstanceId(stateMachineInstanceId string) ([]statelang.StateInstance, error) + + ClearUp(context ProcessContext) } type StateMachineRepository interface { diff --git a/pkg/saga/statemachine/statelang/statemachine_instance.go 
b/pkg/saga/statemachine/statelang/statemachine_instance.go index b794e101..9ba2c06a 100644 --- a/pkg/saga/statemachine/statelang/statemachine_instance.go +++ b/pkg/saga/statemachine/statelang/statemachine_instance.go @@ -72,6 +72,10 @@ type StateMachineInstance interface { SetStatus(status ExecutionStatus) + StateMap() map[string]StateInstance + + SetStateMap(stateMap map[string]StateInstance) + CompensationStatus() ExecutionStatus SetCompensationStatus(compensationStatus ExecutionStatus) @@ -234,6 +238,14 @@ func (s *StateMachineInstanceImpl) SetStatus(status ExecutionStatus) { s.status = status } +func (s *StateMachineInstanceImpl) StateMap() map[string]StateInstance { + return s.stateMap +} + +func (s *StateMachineInstanceImpl) SetStateMap(stateMap map[string]StateInstance) { + s.stateMap = stateMap +} + func (s *StateMachineInstanceImpl) CompensationStatus() ExecutionStatus { return s.compensationStatus } diff --git a/pkg/saga/statemachine/store/db/statelang.go b/pkg/saga/statemachine/store/db/statelang.go index f0c25e33..82e4fd29 100644 --- a/pkg/saga/statemachine/store/db/statelang.go +++ b/pkg/saga/statemachine/store/db/statelang.go @@ -104,7 +104,7 @@ func scanRowsToStateMachine(rows *sql.Rows) (statelang.StateMachine, error) { if recoverStrategy != "" { stateMachine.SetRecoverStrategy(statelang.RecoverStrategy(recoverStrategy)) } - stateMachine.SetTenantId(t) + stateMachine.SetTenantId(tenantId) stateMachine.SetStatus(statelang.StateMachineStatus(status)) return stateMachine, nil } diff --git a/pkg/saga/statemachine/store/db/statelog.go b/pkg/saga/statemachine/store/db/statelog.go index ba60c61f..a9235dbb 100644 --- a/pkg/saga/statemachine/store/db/statelog.go +++ b/pkg/saga/statemachine/store/db/statelog.go @@ -21,17 +21,24 @@ import ( "context" "database/sql" "fmt" + "regexp" + "strconv" + "strings" + "time" + "github.com/pkg/errors" + constant2 "github.com/seata/seata-go/pkg/constant" + "github.com/seata/seata-go/pkg/protocol/branch" + 
"github.com/seata/seata-go/pkg/protocol/message" + "github.com/seata/seata-go/pkg/rm" "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/saga/statemachine/engine/core" "github.com/seata/seata-go/pkg/saga/statemachine/engine/sequence" "github.com/seata/seata-go/pkg/saga/statemachine/engine/serializer" "github.com/seata/seata-go/pkg/saga/statemachine/statelang" + "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" + "github.com/seata/seata-go/pkg/tm" "github.com/seata/seata-go/pkg/util/log" - "regexp" - "strconv" - "strings" - "time" ) const ( @@ -124,14 +131,19 @@ func (s *StateLogStore) RecordStateMachineStarted(ctx context.Context, machineIn parentId := machineInstance.ParentID() if parentId == "" { - //TODO begin transaction + // begin transaction + err = s.beginTransaction(ctx, machineInstance, context) + if err != nil { + return err + } } if machineInstance.ID() == "" && s.seqGenerator != nil { machineInstance.SetID(s.seqGenerator.GenerateId(constant.SeqEntityStateMachineInst, "")) } - //TODO bind SAGA branch type + // bind SAGA branch type + context.SetVariable(constant2.BranchTypeKey, branch.BranchTypeSAGA) serializedStartParams, err := s.paramsSerializer.Serialize(machineInstance.StartParams()) if err != nil { @@ -149,14 +161,48 @@ func (s *StateLogStore) RecordStateMachineStarted(ctx context.Context, machineIn return nil } +func (s *StateLogStore) beginTransaction(ctx context.Context, machineInstance statelang.StateMachineInstance, context core.ProcessContext) error { + cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.StateMachineConfig) + if !ok { + return errors.New("begin transaction fail, stateMachineConfig is required in context") + } + + defer func() { + isAsync, ok := context.GetVariable(constant.VarNameIsAsyncExecution).(bool) + if ok && isAsync { + s.ClearUp(context) + } + }() + + tm.SetTxRole(ctx, tm.Launcher) + tm.SetTxStatus(ctx, message.GlobalStatusUnKnown) + 
tm.SetTxName(ctx, constant.SagaTransNamePrefix+machineInstance.StateMachine().Name()) + + err := tm.GetGlobalTransactionManager().Begin(ctx, time.Duration(cfg.TransOperationTimeout())) + if err != nil { + return err + } + + machineInstance.SetID(tm.GetXID(ctx)) + return nil +} + func (s *StateLogStore) RecordStateMachineFinished(ctx context.Context, machineInstance statelang.StateMachineInstance, context core.ProcessContext) error { if machineInstance == nil { return nil } + defer func() { + s.ClearUp(context) + }() + endParams := machineInstance.EndParams() + if endParams != nil { + delete(endParams, constant.VarNameGlobalTx) + } + // if success, clear exception if statelang.SU == machineInstance.Status() && machineInstance.Exception() != nil { machineInstance.SetException(nil) } @@ -183,10 +229,86 @@ func (s *StateLogStore) RecordStateMachineFinished(ctx context.Context, machineI return nil } - //TODO check if timeout or else report transaction finished + // check if timeout or else report transaction finished + cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.StateMachineConfig) + if !ok { + return errors.New("stateMachineConfig is required in context") + } + + if core.IsTimeout(machineInstance.UpdatedTime(), cfg.TransOperationTimeout()) { + log.Warnf("StateMachineInstance[%s] is execution timeout, skip report transaction finished to server.", machineInstance.ID()) + } else if machineInstance.ParentID() == "" { + //if parentId is not null, machineInstance is a SubStateMachine, do not report global transaction. 
+ err = s.reportTransactionFinished(ctx, machineInstance, context) + if err != nil { + return err + } + } + return nil +} + +func (s *StateLogStore) reportTransactionFinished(ctx context.Context, machineInstance statelang.StateMachineInstance, context core.ProcessContext) error { + var err error + defer func() { + s.ClearUp(context) + if err != nil { + log.Errorf("Report transaction finish to server error: %v, StateMachine: %s, XID: %s, Reason: %s", + err, machineInstance.StateMachine().Name(), machineInstance.ID(), err.Error()) + } + }() + + globalTransaction, err := s.getGlobalTransaction(machineInstance, context) + if err != nil { + log.Errorf("Failed to get global transaction: %v", err) + return err + } + + var globalStatus message.GlobalStatus + if statelang.SU == machineInstance.Status() && machineInstance.CompensationStatus() == "" { + globalStatus = message.GlobalStatusCommitted + } else if statelang.SU == machineInstance.CompensationStatus() { + globalStatus = message.GlobalStatusRollbacked + } else if statelang.FA == machineInstance.CompensationStatus() || statelang.UN == machineInstance.CompensationStatus() { + globalStatus = message.GlobalStatusRollbackRetrying + } else if statelang.FA == machineInstance.Status() && machineInstance.CompensationStatus() == "" { + globalStatus = message.GlobalStatusFinished + } else if statelang.UN == machineInstance.Status() && machineInstance.CompensationStatus() == "" { + globalStatus = message.GlobalStatusCommitRetrying + } else { + globalStatus = message.GlobalStatusUnKnown + } + + globalTransaction.TxStatus = globalStatus + _, err = tm.GetGlobalTransactionManager().GlobalReport(ctx, globalTransaction) + if err != nil { + return err + } return nil } +func (s *StateLogStore) getGlobalTransaction(machineInstance statelang.StateMachineInstance, context core.ProcessContext) (*tm.GlobalTransaction, error) { + globalTransaction, ok := context.GetVariable(constant.VarNameGlobalTx).(*tm.GlobalTransaction) + if ok { + return 
globalTransaction, nil + } + + var xid string + parentId := machineInstance.ParentID() + if parentId == "" { + xid = machineInstance.ID() + } else { + xid = parentId[:strings.LastIndex(parentId, constant.SeperatorParentId)] + } + globalTransaction = &tm.GlobalTransaction{ + Xid: xid, + TxStatus: message.GlobalStatusUnKnown, + TxRole: tm.Launcher, + } + + context.SetVariable(constant.VarNameGlobalTx, globalTransaction) + return globalTransaction, nil +} + func (s *StateLogStore) RecordStateMachineRestarted(ctx context.Context, machineInstance statelang.StateMachineInstance, context core.ProcessContext) error { if machineInstance == nil { @@ -211,18 +333,25 @@ func (s *StateLogStore) RecordStateStarted(ctx context.Context, stateInstance st if stateInstance == nil { return nil } - isUpdateMode := s.isUpdateMode(stateInstance, context) + isUpdateMode, err := s.isUpdateMode(stateInstance, context) + if err != nil { + return err + } + // if this state is for retry, do not register branch if stateInstance.StateIDRetriedFor() != "" { if isUpdateMode { stateInstance.SetID(stateInstance.StateIDRetriedFor()) } else { + // generate id by default stateInstance.SetID(s.generateRetryStateInstanceId(stateInstance)) } } else if stateInstance.StateIDCompensatedFor() != "" { + // if this state is for compensation, do not register branch stateInstance.SetID(s.generateCompensateStateInstanceId(stateInstance, isUpdateMode)) } else { - //TODO register branch + // register branch + s.branchRegister(stateInstance, context) } if stateInstance.ID() == "" && s.seqGenerator != nil { @@ -252,9 +381,45 @@ func (s *StateLogStore) RecordStateStarted(ctx context.Context, stateInstance st return nil } -func (s *StateLogStore) isUpdateMode(instance statelang.StateInstance, context core.ProcessContext) bool { - //TODO implement me, add forward logic - return false +func (s *StateLogStore) isUpdateMode(stateInstance statelang.StateInstance, context core.ProcessContext) (bool, error) { + cfg, ok := 
context.GetVariable(constant.VarNameStateMachineConfig).(core.DefaultStateMachineConfig) + if !ok { + return false, errors.New("stateMachineConfig is required in context") + } + + instruction, ok := context.GetInstruction().(*core.StateInstruction) + if !ok { + return false, errors.New("stateInstruction is required in processContext") + } + instructionState, err := instruction.GetState(context) + if err != nil { + return false, err + } + taskState, _ := instructionState.(*state.AbstractTaskState) + stateMachine := stateInstance.StateMachineInstance().StateMachine() + + if stateInstance.StateIDRetriedFor() != "" { + if taskState != nil && taskState.RetryPersistModeUpdate() { + return taskState.RetryPersistModeUpdate(), nil + } else if stateMachine.IsRetryPersistModeUpdate() { + return stateMachine.IsRetryPersistModeUpdate(), nil + } + return cfg.IsSagaRetryPersistModeUpdate(), nil + } else if stateInstance.StateIDCompensatedFor() != "" { + // find if this compensate has been executed + stateList := stateInstance.StateMachineInstance().StateList() + for _, instance := range stateList { + if instance.IsForCompensation() && instance.Name() == stateInstance.Name() { + if taskState != nil && taskState.CompensatePersistModeUpdate() { + return taskState.CompensatePersistModeUpdate(), nil + } else if stateMachine.IsCompensatePersistModeUpdate() { + return stateMachine.IsCompensatePersistModeUpdate(), nil + } + return cfg.IsSagaCompensatePersistModeUpdate(), nil + } + } + } + return false, nil } func (s *StateLogStore) generateRetryStateInstanceId(stateInstance statelang.StateInstance) string { @@ -293,6 +458,49 @@ func (s *StateLogStore) generateCompensateStateInstanceId(stateInstance statelan return fmt.Sprintf("%s-%d", originalCompensateStateInstId, maxIndex) } +func (s *StateLogStore) branchRegister(stateInstance statelang.StateInstance, context core.ProcessContext) error { + cfg, ok := 
context.GetVariable(constant.VarNameStateMachineConfig).(core.DefaultStateMachineConfig) + if !ok { + return errors.New("stateMachineConfig is required in context") + } + + if !cfg.IsSagaBranchRegisterEnable() { + log.Debugf("sagaBranchRegisterEnable = false, skip branch report. state[%s]", stateInstance.Name()) + return nil + } + + //Register branch + var err error + machineInstance := stateInstance.StateMachineInstance() + defer func() { + if err != nil { + log.Errorf("Branch transaction failure. StateMachine: %s, XID: %s, State: %s, stateId: %s, err: %v", + machineInstance.StateMachine().Name(), machineInstance.ID(), stateInstance.Name(), stateInstance.ID(), err) + } + }() + + globalTransaction, err := s.getGlobalTransaction(machineInstance, context) + if err != nil { + return err + } + if globalTransaction == nil { + err = errors.New("Global transaction is not exists") + return err + } + + branchId, err := rm.GetRMRemotingInstance().BranchRegister(rm.BranchRegisterParam{ + BranchType: branch.BranchTypeSAGA, + ResourceId: machineInstance.StateMachine().Name() + "#" + stateInstance.Name(), + Xid: globalTransaction.Xid, + }) + if err != nil { + return err + } + + stateInstance.SetID(strconv.FormatInt(branchId, 10)) + return nil +} + func (s *StateLogStore) getIdIndex(stateInstanceId string, separator string) int { if stateInstanceId != "" { start := strings.LastIndex(stateInstanceId, separator) @@ -332,11 +540,124 @@ func (s *StateLogStore) RecordStateFinished(ctx context.Context, stateInstance s return err } - //TODO report branch + // A switch to skip branch report on branch success, in order to optimize performance + cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.DefaultStateMachineConfig) + if !(ok && !cfg.IsRmReportSuccessEnable() && statelang.SU == stateInstance.Status()) { + err = s.branchReport(stateInstance, context) + return err + } + return nil } +func (s *StateLogStore) branchReport(stateInstance statelang.StateInstance, 
context core.ProcessContext) error { + cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.DefaultStateMachineConfig) + if ok && !cfg.IsSagaBranchRegisterEnable() { + log.Debugf("sagaBranchRegisterEnable = false, skip branch report. state[%s]", stateInstance.Name()) + return nil + } + + var branchStatus branch.BranchStatus + // find out the original state instance, only the original state instance is registered on the server, + // and its status should be reported. + var originalStateInst statelang.StateInstance + if stateInstance.StateIDRetriedFor() != "" { + isUpdateMode, err := s.isUpdateMode(stateInstance, context) + if err != nil { + return err + } + + if isUpdateMode { + originalStateInst = stateInstance + } else { + originalStateInst = s.findOutOriginalStateInstanceOfRetryState(stateInstance) + } + + if statelang.SU == stateInstance.Status() { + branchStatus = branch.BranchStatusPhasetwoCommitted + } else if statelang.FA == stateInstance.Status() || statelang.UN == stateInstance.Status() { + branchStatus = branch.BranchStatusPhaseoneFailed + } else { + branchStatus = branch.BranchStatusUnknown + } + } else if stateInstance.StateIDCompensatedFor() != "" { + isUpdateMode, err := s.isUpdateMode(stateInstance, context) + if err != nil { + return err + } + + if isUpdateMode { + originalStateInst = stateInstance.StateMachineInstance().StateMap()[stateInstance.StateIDCompensatedFor()] + } else { + originalStateInst = s.findOutOriginalStateInstanceOfCompensateState(stateInstance) + } + } + + if originalStateInst == nil { + originalStateInst = stateInstance + } + + if branchStatus == branch.BranchStatusUnknown { + if statelang.SU == originalStateInst.Status() && originalStateInst.CompensationStatus() == "" { + branchStatus = branch.BranchStatusPhasetwoCommitted + } else if statelang.SU == originalStateInst.CompensationStatus() { + branchStatus = branch.BranchStatusPhasetwoRollbacked + } else if statelang.FA == 
originalStateInst.CompensationStatus() || statelang.UN == originalStateInst.CompensationStatus() { + branchStatus = branch.BranchStatusPhasetwoRollbackFailedRetryable + } else if (statelang.FA == originalStateInst.Status() || statelang.UN == originalStateInst.Status()) && originalStateInst.CompensationStatus() == "" { + branchStatus = branch.BranchStatusPhaseoneFailed + } else { + branchStatus = branch.BranchStatusUnknown + } + } + + var err error + defer func() { + if err != nil { + log.Errorf("Report branch status to server error:%s, StateMachine:%s, StateName:%s, XID:%s, branchId:%s, branchStatus:%s, err:%v", + err.Error(), originalStateInst.StateMachineInstance().StateMachine().Name(), originalStateInst.Name(), + originalStateInst.StateMachineInstance().ID(), originalStateInst.ID(), branchStatus, err) + } + }() + + globalTransaction, err := s.getGlobalTransaction(stateInstance.StateMachineInstance(), context) + if err != nil { + return err + } + if globalTransaction == nil { + err = errors.New("Global transaction is not exists") + return err + } + + branchId, err := strconv.ParseInt(originalStateInst.ID(), 10, 0) + err = rm.GetRMRemotingInstance().BranchReport(rm.BranchReportParam{ + BranchType: branch.BranchTypeSAGA, + Xid: globalTransaction.Xid, + BranchId: branchId, + Status: branchStatus, + }) + return err +} + +func (s *StateLogStore) findOutOriginalStateInstanceOfRetryState(stateInstance statelang.StateInstance) statelang.StateInstance { + stateMap := stateInstance.StateMachineInstance().StateMap() + originalStateInst := stateMap[stateInstance.StateIDRetriedFor()] + for originalStateInst.StateIDRetriedFor() != "" { + originalStateInst = stateMap[stateInstance.StateIDRetriedFor()] + } + return originalStateInst +} + +func (s *StateLogStore) findOutOriginalStateInstanceOfCompensateState(stateInstance statelang.StateInstance) statelang.StateInstance { + stateMap := stateInstance.StateMachineInstance().StateMap() + originalStateInst := 
stateMap[stateInstance.StateIDCompensatedFor()] + for originalStateInst.StateIDRetriedFor() != "" { + originalStateInst = stateMap[stateInstance.StateIDRetriedFor()] + } + return originalStateInst +} + func (s *StateLogStore) GetStateMachineInstance(stateMachineInstanceId string) (statelang.StateMachineInstance, error) { stateMachineInstance, err := SelectOne(s.db, s.getStateMachineInstanceByIdSql, scanRowsToStateMachineInstance, stateMachineInstanceId) @@ -437,7 +758,7 @@ func (s *StateLogStore) GetStateInstanceListByMachineInstanceId(stateMachineInst if lastStateInstance.EndTime().IsZero() { lastStateInstance.SetStatus(statelang.RU) } - //TODO add forward and compensate logic + originStateMap := make(map[string]statelang.StateInstance) compensatedStateMap := make(map[string]statelang.StateInstance) retriedStateMap := make(map[string]statelang.StateInstance) @@ -522,6 +843,11 @@ func (s *StateLogStore) SetSeqGenerator(seqGenerator sequence.SeqGenerator) { s.seqGenerator = seqGenerator } +func (s *StateLogStore) ClearUp(context core.ProcessContext) { + context.RemoveVariable(constant2.XidKey) + context.RemoveVariable(constant2.BranchTypeKey) +} + func execStateMachineInstanceStatementForInsert(obj statelang.StateMachineInstance, stmt *sql.Stmt) (int64, error) { result, err := stmt.Exec( obj.ID(), diff --git a/pkg/saga/statemachine/store/db/statelog_test.go b/pkg/saga/statemachine/store/db/statelog_test.go index fb15b8cb..1445f2e9 100644 --- a/pkg/saga/statemachine/store/db/statelog_test.go +++ b/pkg/saga/statemachine/store/db/statelog_test.go @@ -57,6 +57,12 @@ func mockMachineInstance(stateMachineName string) statelang.StateMachineInstance return inst } +func mockStateMachineConfig(context core.ProcessContext) core.StateMachineConfig { + cfg := core.NewDefaultStateMachineConfig() + context.SetVariable(constant.VarNameStateMachineConfig, cfg) + return cfg +} + func TestStateLogStore_RecordStateMachineStarted(t *testing.T) { prepareDB() @@ -65,6 +71,7 @@ func 
TestStateLogStore_RecordStateMachineStarted(t *testing.T) { expected := mockMachineInstance(stateMachineName) expected.SetBusinessKey("test_started") ctx := mockProcessContext(stateMachineName, expected) + mockStateMachineConfig(ctx) err := stateLogStore.RecordStateMachineStarted(context.Background(), expected, ctx) assert.Nil(t, err) actual, err := stateLogStore.GetStateMachineInstance(expected.ID()) diff --git a/pkg/saga/statemachine/store/repository/state_log_repository.go b/pkg/saga/statemachine/store/repository/state_log_repository.go new file mode 100644 index 00000000..a00f57bf --- /dev/null +++ b/pkg/saga/statemachine/store/repository/state_log_repository.go @@ -0,0 +1,18 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+ */ + +package repository diff --git a/pkg/saga/statemachine/store/repository/state_machine_repository.go b/pkg/saga/statemachine/store/repository/state_machine_repository.go index ea37171b..da31406e 100644 --- a/pkg/saga/statemachine/store/repository/state_machine_repository.go +++ b/pkg/saga/statemachine/store/repository/state_machine_repository.go @@ -18,34 +18,221 @@ package repository import ( - "github.com/seata/seata-go/pkg/saga/statemachine/statelang" "io" + "sync" + "time" + + "github.com/seata/seata-go/pkg/saga/statemachine/constant" + "github.com/seata/seata-go/pkg/saga/statemachine/engine/sequence" + "github.com/seata/seata-go/pkg/saga/statemachine/statelang" + "github.com/seata/seata-go/pkg/saga/statemachine/statelang/parser" + "github.com/seata/seata-go/pkg/saga/statemachine/store/db" + "github.com/seata/seata-go/pkg/util/log" +) + +const ( + DefaultJsonParser = "fastjson" +) + +var ( + stateMachineRepositoryImpl *StateMachineRepositoryImpl + onceStateMachineRepositoryImpl sync.Once ) type StateMachineRepositoryImpl struct { + stateMachineMapById map[string]statelang.StateMachine + stateMachineMapByNameAndTenant map[string]statelang.StateMachine + + stateLangStore *db.StateLangStore + seqGenerator sequence.SeqGenerator + defaultTenantId string + jsonParserName string + charset string + mutex *sync.Mutex +} + +func GetStateMachineRepositoryImpl() *StateMachineRepositoryImpl { + if stateMachineRepositoryImpl == nil { + onceStateMachineRepositoryImpl.Do(func() { + //TODO get charset by config + //TODO charset is not use + //TODO using json parser + stateMachineRepositoryImpl = &StateMachineRepositoryImpl{ + stateMachineMapById: make(map[string]statelang.StateMachine), + stateMachineMapByNameAndTenant: make(map[string]statelang.StateMachine), + seqGenerator: sequence.NewUUIDSeqGenerator(), + jsonParserName: DefaultJsonParser, + charset: "UTF-8", + mutex: &sync.Mutex{}, + } + }) + } + + return stateMachineRepositoryImpl +} + +func (s 
*StateMachineRepositoryImpl) GetStateMachineById(stateMachineId string) (statelang.StateMachine, error) { + stateMachine := s.stateMachineMapById[stateMachineId] + if stateMachine == nil && s.stateLangStore != nil { + s.mutex.Lock() + defer s.mutex.Unlock() + + stateMachine = s.stateMachineMapById[stateMachineId] + if stateMachine == nil { + oldStateMachine, err := s.stateLangStore.GetStateMachineById(stateMachineId) + if err != nil { + return oldStateMachine, err + } + + parseStatMachine, err := parser.NewJSONStateMachineParser().Parse(oldStateMachine.Content()) + if err != nil { + return oldStateMachine, err + } + + oldStateMachine.SetStartState(parseStatMachine.StartState()) + for key, val := range parseStatMachine.States() { + oldStateMachine.States()[key] = val + } + + s.stateMachineMapById[stateMachineId] = oldStateMachine + s.stateMachineMapByNameAndTenant[oldStateMachine.Name()+"_"+oldStateMachine.TenantId()] = oldStateMachine + return oldStateMachine, nil + } + } + return stateMachine, nil +} + +func (s *StateMachineRepositoryImpl) GetStateMachineByNameAndTenantId(stateMachineName string, tenantId string) (statelang.StateMachine, error) { + return s.GetLastVersionStateMachine(stateMachineName, tenantId) +} + +func (s *StateMachineRepositoryImpl) GetLastVersionStateMachine(stateMachineName string, tenantId string) (statelang.StateMachine, error) { + key := stateMachineName + "_" + tenantId + stateMachine := s.stateMachineMapByNameAndTenant[key] // NOTE(review): read without s.mutex; it can race with the locked map writes below if called concurrently - consider sync.RWMutex + if stateMachine == nil && s.stateLangStore != nil { + s.mutex.Lock() + defer s.mutex.Unlock() + + stateMachine = s.stateMachineMapByNameAndTenant[key] // re-check the same name+tenant cache now that the lock is held + if stateMachine == nil { + oldStateMachine, err := s.stateLangStore.GetLastVersionStateMachine(stateMachineName, tenantId) + if err != nil || oldStateMachine == nil { // RegistryStateMachine treats a nil result as "none stored"; pass it through instead of calling Content() on nil + return oldStateMachine, err + } + + parseStatMachine, err := parser.NewJSONStateMachineParser().Parse(oldStateMachine.Content()) + if err != nil { + return oldStateMachine, err + } + 
oldStateMachine.SetStartState(parseStatMachine.StartState()) + for name, state := range parseStatMachine.States() { + oldStateMachine.States()[name] = state + } + + s.stateMachineMapById[oldStateMachine.ID()] = oldStateMachine + s.stateMachineMapByNameAndTenant[key] = oldStateMachine + return oldStateMachine, nil + } + } + return stateMachine, nil +} + +func (s *StateMachineRepositoryImpl) RegistryStateMachine(machine statelang.StateMachine) error { + stateMachineName := machine.Name() + tenantId := machine.TenantId() + + if s.stateLangStore != nil { + oldStateMachine, err := s.stateLangStore.GetLastVersionStateMachine(stateMachineName, tenantId) + if err != nil { + return err + } + + if oldStateMachine != nil { + if oldStateMachine.Content() == machine.Content() && machine.Version() != "" && machine.Version() == oldStateMachine.Version() { + log.Debugf("StateMachine[%s] is already exist a same version", stateMachineName) + machine.SetID(oldStateMachine.ID()) + machine.SetCreateTime(oldStateMachine.CreateTime()) + + s.stateMachineMapById[machine.ID()] = machine + s.stateMachineMapByNameAndTenant[machine.Name()+"_"+machine.TenantId()] = machine + return nil + } + } + + if machine.ID() == "" { + machine.SetID(s.seqGenerator.GenerateId(constant.SeqEntityStateMachine, "")) + } + + machine.SetCreateTime(time.Now()) + + err = s.stateLangStore.StoreStateMachine(machine) + if err != nil { + return err + } + } + + if machine.ID() == "" { + machine.SetID(s.seqGenerator.GenerateId(constant.SeqEntityStateMachine, "")) + } + + s.stateMachineMapById[machine.ID()] = machine + s.stateMachineMapByNameAndTenant[machine.Name()+"_"+machine.TenantId()] = machine + return nil +} + +func (s *StateMachineRepositoryImpl) RegistryStateMachineByReader(reader io.Reader) error { + jsonByte, err := io.ReadAll(reader) + if err != nil { + return err + } + + json := string(jsonByte) + parseStatMachine, err := parser.NewJSONStateMachineParser().Parse(json) + if err != nil { + return err + } + + if 
parseStatMachine == nil { + return nil + } + + parseStatMachine.SetContent(json) + if err = s.RegistryStateMachine(parseStatMachine); err != nil { + return err + } + log.Debugf("===== StateMachine Loaded: %s", json) + return nil +} + +func (s *StateMachineRepositoryImpl) SetStateLangStore(stateLangStore *db.StateLangStore) { + s.stateLangStore = stateLangStore +} + +func (s *StateMachineRepositoryImpl) SetSeqGenerator(seqGenerator sequence.SeqGenerator) { + s.seqGenerator = seqGenerator +} + +func (s *StateMachineRepositoryImpl) SetCharset(charset string) { + s.charset = charset } -func (s StateMachineRepositoryImpl) GetStateMachineById(stateMachineId string) (statelang.StateMachine, error) { - //TODO implement me - panic("implement me") +func (s *StateMachineRepositoryImpl) GetCharset() string { + return s.charset } -func (s StateMachineRepositoryImpl) GetStateMachineByNameAndTenantId(stateMachineName string, tenantId string) (statelang.StateMachine, error) { - //TODO implement me - panic("implement me") +func (s *StateMachineRepositoryImpl) SetDefaultTenantId(defaultTenantId string) { + s.defaultTenantId = defaultTenantId } -func (s StateMachineRepositoryImpl) GetLastVersionStateMachine(stateMachineName string, tenantId string) (statelang.StateMachine, error) { - //TODO implement me - panic("implement me") +func (s *StateMachineRepositoryImpl) GetDefaultTenantId() string { + return s.defaultTenantId } -func (s StateMachineRepositoryImpl) RegistryStateMachine(machine statelang.StateMachine) error { - //TODO implement me - panic("implement me") +func (s *StateMachineRepositoryImpl) SetJsonParserName(jsonParserName string) { + s.jsonParserName = jsonParserName } -func (s StateMachineRepositoryImpl) RegistryStateMachineByReader(reader io.Reader) error { - //TODO implement me - panic("implement me") +func (s *StateMachineRepositoryImpl) GetJsonParserName() string { + return s.jsonParserName } diff --git a/pkg/saga/statemachine/store/repository/state_machine_repository_test.go 
b/pkg/saga/statemachine/store/repository/state_machine_repository_test.go new file mode 100644 index 00000000..cfd9a12a --- /dev/null +++ b/pkg/saga/statemachine/store/repository/state_machine_repository_test.go @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package repository + +import ( + "database/sql" + "os" + "sync" + "testing" + "time" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + + "github.com/seata/seata-go/pkg/saga/statemachine/statelang" + "github.com/seata/seata-go/pkg/saga/statemachine/statelang/parser" + "github.com/seata/seata-go/pkg/saga/statemachine/store/db" +) + +var ( + oncePrepareDB sync.Once + testdb *sql.DB +) + +func prepareDB() { + oncePrepareDB.Do(func() { + var err error + if testdb, err = sql.Open("sqlite3", ":memory:"); err != nil { + panic(err) + } + initScript, err := os.ReadFile("../../../../../testdata/sql/saga/sqlite_init.sql") + if err != nil { + panic(err) + } + if _, err := testdb.Exec(string(initScript)); err != nil { + panic(err) + } + }) +} + +func loadStateMachineByYaml() string { + query, _ := os.ReadFile("../../../../../testdata/saga/statelang/simple_statemachine.json") + return string(query) +} + +func TestStateMachineInMemory(t *testing.T) { + const stateMachineId, stateMachineName, tenantId = "simpleStateMachine", "simpleStateMachine", "test" + stateMachine := statelang.NewStateMachineImpl() + stateMachine.SetID(stateMachineId) + stateMachine.SetName(stateMachineName) + stateMachine.SetTenantId(tenantId) + stateMachine.SetComment("This is a test state machine") + stateMachine.SetCreateTime(time.Now()) + + repository := GetStateMachineRepositoryImpl() + + err := repository.RegistryStateMachine(stateMachine) + assert.Nil(t, err) + + machineById, err := repository.GetStateMachineById(stateMachine.ID()) + assert.Nil(t, err) + assert.Equal(t, stateMachine.Name(), machineById.Name()) + assert.Equal(t, stateMachine.TenantId(), machineById.TenantId()) + assert.Equal(t, stateMachine.Comment(), machineById.Comment()) + assert.Equal(t, stateMachine.CreateTime().UnixNano(), machineById.CreateTime().UnixNano()) + + machineByNameAndTenantId, err := repository.GetLastVersionStateMachine(stateMachine.Name(), stateMachine.TenantId()) + assert.Nil(t, err) + 
assert.Equal(t, stateMachine.ID(), machineByNameAndTenantId.ID()) + assert.Equal(t, stateMachine.Comment(), machineByNameAndTenantId.Comment()) + assert.Equal(t, stateMachine.CreateTime().UnixNano(), machineByNameAndTenantId.CreateTime().UnixNano()) +} + +func TestStateMachineInDb(t *testing.T) { + prepareDB() + const tenantId = "test" + yaml := loadStateMachineByYaml() + stateMachine, err := parser.NewJSONStateMachineParser().Parse(yaml) + assert.Nil(t, err) + stateMachine.SetTenantId(tenantId) + stateMachine.SetContent(yaml) + + repository := GetStateMachineRepositoryImpl() + repository.SetStateLangStore(db.NewStateLangStore(testdb, "seata_")) + + err = repository.RegistryStateMachine(stateMachine) + assert.Nil(t, err) + + repository.stateMachineMapById[stateMachine.ID()] = nil // evict the cached entry so GetStateMachineById reloads it from the DB + machineById, err := repository.GetStateMachineById(stateMachine.ID()) + assert.Nil(t, err) + assert.Equal(t, stateMachine.Name(), machineById.Name()) + assert.Equal(t, stateMachine.TenantId(), machineById.TenantId()) + assert.Equal(t, stateMachine.Comment(), machineById.Comment()) + assert.Equal(t, stateMachine.CreateTime().UnixNano(), machineById.CreateTime().UnixNano()) + + repository.stateMachineMapByNameAndTenant[stateMachine.Name()+"_"+stateMachine.TenantId()] = nil // evict the cached entry so GetLastVersionStateMachine reloads it from the DB + machineByNameAndTenantId, err := repository.GetLastVersionStateMachine(stateMachine.Name(), stateMachine.TenantId()) + assert.Nil(t, err) + assert.Equal(t, stateMachine.ID(), machineByNameAndTenantId.ID()) + assert.Equal(t, stateMachine.Comment(), machineByNameAndTenantId.Comment()) + assert.Equal(t, stateMachine.CreateTime().UnixNano(), machineByNameAndTenantId.CreateTime().UnixNano()) +} diff --git a/pkg/tm/global_transaction.go b/pkg/tm/global_transaction.go index 8ee554ee..b8bb28cf 100644 --- a/pkg/tm/global_transaction.go +++ b/pkg/tm/global_transaction.go @@ -151,3 +151,28 @@ func (g *GlobalTransactionManager) Rollback(ctx context.Context, gtr *GlobalTran return nil } + +// GlobalReport sends a GlobalReportRequest carrying gtr.TxStatus for gtr.Xid and returns the GlobalStatus from the response.
+func (g *GlobalTransactionManager) GlobalReport(ctx context.Context, gtr *GlobalTransaction) (message.GlobalStatus, error) { + if gtr.Xid == "" { + return message.GlobalStatusUnKnown, fmt.Errorf("GlobalReport xid should not be empty") + } + req := message.GlobalReportRequest{ + AbstractGlobalEndRequest: message.AbstractGlobalEndRequest{ + Xid: gtr.Xid, + }, + GlobalStatus: gtr.TxStatus, + } + res, err := getty.GetGettyRemotingClient().SendSyncRequest(req) + if err != nil { + log.Errorf("GlobalReportRequest error %v", err) + return message.GlobalStatusUnKnown, err + } + resp, ok := res.(message.GlobalReportResponse) // comma-ok: a nil or unexpected response is reported as an error below instead of panicking + if !ok || resp.ResultCode == message.ResultCodeFailed { + log.Errorf("GlobalReportRequest result is empty or result code is failed, res %v", res) + return message.GlobalStatusUnKnown, fmt.Errorf("GlobalReportRequest result is empty or result code is failed.") + } + log.Infof("GlobalReportRequest success, res %v", res) + return resp.GlobalStatus, nil +} --------------------------------------------------------------------- To unsubscribe, e-mail: notifications-unsubscr...@seata.apache.org For additional commands, e-mail: notifications-h...@seata.apache.org