This is an automated email from the ASF dual-hosted git repository. yixia pushed a commit to branch feature/saga in repository https://gitbox.apache.org/repos/asf/incubator-seata-go.git
The following commit(s) were added to refs/heads/feature/saga by this push: new 33ff4e59 feature: add saga persistence layer (#649) 33ff4e59 is described below commit 33ff4e59cc15f98da6494421fb10ec5bc2502bc3 Author: Xiangkun Yin <32592585+pt...@users.noreply.github.com> AuthorDate: Thu Feb 22 10:31:58 2024 +0800 feature: add saga persistence layer (#649) --- go.mod | 1 + go.sum | 2 + pkg/remoting/loadbalance/loadbalance.go | 2 + .../loadbalance/round_robin_loadbalance.go | 69 ++ .../loadbalance/round_robin_loadbalance_test.go | 100 +++ pkg/saga/statemachine/constant/constant.go | 1 + .../engine/process_ctrl/process_context.go | 6 + .../statemachine/engine/serializer/serializer.go | 42 ++ .../engine/serializer/serializer_test.go | 17 + pkg/saga/statemachine/engine/store/db/db.go | 102 +++ pkg/saga/statemachine/engine/store/db/db_test.go | 28 + pkg/saga/statemachine/engine/store/db/statelang.go | 114 ++++ .../statemachine/engine/store/db/statelang_test.go | 73 ++ pkg/saga/statemachine/engine/store/db/statelog.go | 741 +++++++++++++++++++++ .../statemachine/engine/store/db/statelog_test.go | 279 ++++++++ .../engine/store/statemachine_store.go | 6 +- pkg/saga/statemachine/engine/utils.go | 2 +- pkg/saga/statemachine/statelang/state_instance.go | 10 + .../statelang/statemachine_instance.go | 2 +- testdata/sql/saga/sqlite_init.sql | 81 +++ 20 files changed, 1673 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 29267878..62144e81 100644 --- a/go.mod +++ b/go.mod @@ -33,6 +33,7 @@ require ( require ( github.com/agiledragon/gomonkey v2.0.2+incompatible github.com/agiledragon/gomonkey/v2 v2.9.0 + github.com/mattn/go-sqlite3 v1.14.19 ) require ( diff --git a/go.sum b/go.sum index 9bfa347d..9c0fbc0b 100644 --- a/go.sum +++ b/go.sum @@ -515,6 +515,8 @@ github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27k github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= +github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbWfGI= +github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= diff --git a/pkg/remoting/loadbalance/loadbalance.go b/pkg/remoting/loadbalance/loadbalance.go index c5ddb679..8451e170 100644 --- a/pkg/remoting/loadbalance/loadbalance.go +++ b/pkg/remoting/loadbalance/loadbalance.go @@ -37,6 +37,8 @@ func Select(loadBalanceType string, sessions *sync.Map, xid string) getty.Sessio return RandomLoadBalance(sessions, xid) case xidLoadBalance: return XidLoadBalance(sessions, xid) + case roundRobinLoadBalance: + return RoundRobinLoadBalance(sessions, xid) default: return RandomLoadBalance(sessions, xid) } diff --git a/pkg/remoting/loadbalance/round_robin_loadbalance.go b/pkg/remoting/loadbalance/round_robin_loadbalance.go new file mode 100644 index 00000000..9cebc926 --- /dev/null +++ b/pkg/remoting/loadbalance/round_robin_loadbalance.go @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package loadbalance + +import ( + "math" + "sort" + "sync" + "sync/atomic" + + getty "github.com/apache/dubbo-getty" +) + +var sequence int32 + +func RoundRobinLoadBalance(sessions *sync.Map, s string) getty.Session { + // collect sync.Map adderToSession + // filter out closed session instance + adderToSession := make(map[string]getty.Session, 0) + // map has no sequence, we should sort it to make sure the sequence is always the same + adders := make([]string, 0) + sessions.Range(func(key, value interface{}) bool { + session := key.(getty.Session) + if session.IsClosed() { + sessions.Delete(key) + } else { + adderToSession[session.RemoteAddr()] = session + adders = append(adders, session.RemoteAddr()) + } + return true + }) + sort.Strings(adders) + // adderToSession eq 0 means there are no available session + if len(adderToSession) == 0 { + return nil + } + index := getPositiveSequence() % len(adderToSession) + return adderToSession[adders[index]] +} + +func getPositiveSequence() int { + for { + current := atomic.LoadInt32(&sequence) + var next int32 + if current == math.MaxInt32 { + next = 0 + } else { + next = current + 1 + } + if atomic.CompareAndSwapInt32(&sequence, current, next) { + return int(current) + } + } +} diff --git a/pkg/remoting/loadbalance/round_robin_loadbalance_test.go b/pkg/remoting/loadbalance/round_robin_loadbalance_test.go new file mode 100644 index 00000000..c5826570 --- /dev/null +++ b/pkg/remoting/loadbalance/round_robin_loadbalance_test.go @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package loadbalance + +import ( + "fmt" + "math" + "sync" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/seata/seata-go/pkg/remoting/mock" +) + +func TestRoundRobinLoadBalance_Normal(t *testing.T) { + ctrl := gomock.NewController(t) + sessions := &sync.Map{} + + for i := 0; i < 10; i++ { + session := mock.NewMockTestSession(ctrl) + session.EXPECT().IsClosed().Return(i == 2).AnyTimes() + session.EXPECT().RemoteAddr().Return(fmt.Sprintf("%d", i)).AnyTimes() + sessions.Store(session, fmt.Sprintf("session-%d", i+1)) + } + + for i := 0; i < 10; i++ { + if i == 2 { + continue + } + result := RoundRobinLoadBalance(sessions, "some_xid") + assert.Equal(t, fmt.Sprintf("%d", i), result.RemoteAddr()) + assert.NotNil(t, result) + assert.False(t, result.IsClosed()) + } +} + +func TestRoundRobinLoadBalance_OverSequence(t *testing.T) { + ctrl := gomock.NewController(t) + sessions := &sync.Map{} + sequence = math.MaxInt32 + + for i := 0; i < 10; i++ { + session := mock.NewMockTestSession(ctrl) + session.EXPECT().IsClosed().Return(false).AnyTimes() + session.EXPECT().RemoteAddr().Return(fmt.Sprintf("%d", i)).AnyTimes() + sessions.Store(session, fmt.Sprintf("session-%d", i+1)) + } + + for i := 0; i < 10; i++ { + // over sequence here + if i == 0 { + result := RoundRobinLoadBalance(sessions, "some_xid") + assert.Equal(t, "7", result.RemoteAddr()) + assert.NotNil(t, result) + assert.False(t, result.IsClosed()) + continue + } + result := RoundRobinLoadBalance(sessions, "some_xid") + assert.Equal(t, fmt.Sprintf("%d", i-1), result.RemoteAddr()) + assert.NotNil(t, result) + assert.False(t, result.IsClosed()) + } +} + +func TestRoundRobinLoadBalance_All_Closed(t *testing.T) { + ctrl := gomock.NewController(t) + sessions := &sync.Map{} + for i := 0; i < 10; i++ { + session := mock.NewMockTestSession(ctrl) + session.EXPECT().IsClosed().Return(true).AnyTimes() + sessions.Store(session, fmt.Sprintf("session-%d", i+1)) + } + if result := RoundRobinLoadBalance(sessions, "some_xid"); result != nil { + t.Errorf("Expected nil, actual got %+v", result) + } +} + +func TestRoundRobinLoadBalance_Empty(t *testing.T) { + sessions := &sync.Map{} + if result := RoundRobinLoadBalance(sessions, "some_xid"); result != nil { + t.Errorf("Expected nil, actual got %+v", result) + } +} diff --git a/pkg/saga/statemachine/constant/constant.go b/pkg/saga/statemachine/constant/constant.go index f9e1eeba..ef332c18 100644 --- a/pkg/saga/statemachine/constant/constant.go +++ b/pkg/saga/statemachine/constant/constant.go @@ -13,6 +13,7 @@ const ( VarNameIsAsyncExecution string = "_is_async_execution_" VarNameStateInst string = "_current_state_instance_" SeqEntityStateMachineInst string = "STATE_MACHINE_INST" + SeqEntityStateInst string = "STATE_INST" VarNameBusinesskey string = "_business_key_" VarNameParentId string = "_parent_id_" StateTypeServiceTask string = "ServiceTask" diff --git a/pkg/saga/statemachine/engine/process_ctrl/process_context.go b/pkg/saga/statemachine/engine/process_ctrl/process_context.go index c13faa3f..7d02e470 100644 --- a/pkg/saga/statemachine/engine/process_ctrl/process_context.go +++ b/pkg/saga/statemachine/engine/process_ctrl/process_context.go @@ -47,6 +47,12 @@ type ProcessContextImpl struct { instruction Instruction } +func NewProcessContextImpl() *ProcessContextImpl { + return &ProcessContextImpl{ + mp: make(map[string]interface{}), + } +} + func (p *ProcessContextImpl) GetVariable(name string) interface{} { p.mu.RLock() defer p.mu.RUnlock() diff --git a/pkg/saga/statemachine/engine/serializer/serializer.go b/pkg/saga/statemachine/engine/serializer/serializer.go new file mode 100644 index 00000000..4d668ef5 --- /dev/null +++ b/pkg/saga/statemachine/engine/serializer/serializer.go @@ -0,0 +1,42 @@ +package serializer + +import ( + "bytes" + "encoding/gob" + "encoding/json" + "github.com/pkg/errors" +) + +type ParamsSerializer struct{} + +func (ParamsSerializer) Serialize(object any) (string, error) { + result, err := json.Marshal(object) + return string(result), err +} + +func (ParamsSerializer) Deserialize(object string) (any, error) { + var result any + err := json.Unmarshal([]byte(object), &result) + return result, err +} + +type ErrorSerializer struct{} + +func (ErrorSerializer) Serialize(object error) ([]byte, error) { + var buffer bytes.Buffer + encoder := gob.NewEncoder(&buffer) + if object != nil { + err := encoder.Encode(object.Error()) + return buffer.Bytes(), err + } + return nil, nil +} + +func (ErrorSerializer) Deserialize(object []byte) (error, error) { + var errorMsg string + buffer := bytes.NewReader(object) + encoder := gob.NewDecoder(buffer) + err := encoder.Decode(&errorMsg) + + return errors.New(errorMsg), err +} diff --git a/pkg/saga/statemachine/engine/serializer/serializer_test.go b/pkg/saga/statemachine/engine/serializer/serializer_test.go new file mode 100644 index 00000000..9dba88a0 --- /dev/null +++ b/pkg/saga/statemachine/engine/serializer/serializer_test.go @@ -0,0 +1,17 @@ +package serializer + +import ( + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestErrorSerializer(t *testing.T) { + serializer := ErrorSerializer{} + expected := errors.New("This is a test error") + serialized, err := serializer.Serialize(expected) + assert.Nil(t, err) + actual, err := serializer.Deserialize(serialized) + assert.Nil(t, err) + assert.Equal(t, expected.Error(), actual.Error()) +} diff --git a/pkg/saga/statemachine/engine/store/db/db.go b/pkg/saga/statemachine/engine/store/db/db.go new file mode 100644 index 00000000..39d62e97 --- /dev/null +++ b/pkg/saga/statemachine/engine/store/db/db.go @@ -0,0 +1,102 @@ +package db + +import ( + "database/sql" + "github.com/pkg/errors" + "github.com/seata/seata-go/pkg/util/log" +) + +const TimeLayout = "2006-01-02 15:04:05.999999999-07:00" + +type ExecStatement[T any] func(obj T, stmt *sql.Stmt) (int64, error) + +type ScanRows[T any] func(rows *sql.Rows) (T, error) + +type Store struct { + db *sql.DB +} + +func SelectOne[T any](db *sql.DB, sql string, fn ScanRows[T], args ...any) (T, error) { + var result T + log.Debugf("Preparing SQL: %s", sql) + stmt, err := db.Prepare(sql) + defer stmt.Close() + if err != nil { + return result, err + } + + log.Debugf("setting params to Stmt: %v", args) + rows, err := stmt.Query(args...) + defer rows.Close() + if err != nil { + return result, nil + } + + if rows.Next() { + return fn(rows) + } + return result, errors.New("no target selected") +} + +func SelectList[T any](db *sql.DB, sql string, fn ScanRows[T], args ...any) ([]T, error) { + result := make([]T, 0) + + log.Debugf("Preparing SQL: %s", sql) + stmt, err := db.Prepare(sql) + defer stmt.Close() + if err != nil { + return result, err + } + + log.Debugf("setting params to Stmt: %v", args) + rows, err := stmt.Query(args...) + defer rows.Close() + if err != nil { + return result, err + } + + for rows.Next() { + obj, err := fn(rows) + if err != nil { + return result, err + } + result = append(result, obj) + } + + return result, nil +} + +func ExecuteUpdate[T any](db *sql.DB, sql string, fn ExecStatement[T], obj T) (int64, error) { + log.Debugf("Preparing SQL: %s", sql) + stmt, err := db.Prepare(sql) + defer stmt.Close() + if err != nil { + return 0, err + } + + log.Debugf("setting params to Stmt: %v", obj) + + rowsAffected, err := fn(obj, stmt) + if err != nil { + return rowsAffected, err + } + + return rowsAffected, nil +} + +func ExecuteUpdateArgs(db *sql.DB, sql string, args ...any) (int64, error) { + log.Debugf("Preparing SQL: %s", sql) + stmt, err := db.Prepare(sql) + defer stmt.Close() + if err != nil { + return 0, err + } + + log.Debugf("setting params to Stmt: %v", args) + + result, err := stmt.Exec(args...) + if err != nil { + return 0, err + } + return result.RowsAffected() +} diff --git a/pkg/saga/statemachine/engine/store/db/db_test.go b/pkg/saga/statemachine/engine/store/db/db_test.go new file mode 100644 index 00000000..a5e3edc1 --- /dev/null +++ b/pkg/saga/statemachine/engine/store/db/db_test.go @@ -0,0 +1,28 @@ +package db + +import ( + "database/sql" + "os" + "sync" +) + +var ( + oncePrepareDB sync.Once + db *sql.DB +) + +func prepareDB() { + oncePrepareDB.Do(func() { + var err error + db, err = sql.Open("sqlite3", ":memory:") + query_, err := os.ReadFile("testdata/sql/saga/sqlite_init.sql") + initScript := string(query_) + if err != nil { + panic(err) + } + if _, err := db.Exec(initScript); err != nil { + panic(err) + } + }) + +} diff --git a/pkg/saga/statemachine/engine/store/db/statelang.go b/pkg/saga/statemachine/engine/store/db/statelang.go new file mode 100644 index 00000000..e49bf96a --- /dev/null +++ b/pkg/saga/statemachine/engine/store/db/statelang.go @@ -0,0 +1,114 @@ +package db + +import ( + "database/sql" + "github.com/pkg/errors" + "github.com/seata/seata-go/pkg/saga/statemachine/statelang" + "regexp" + "time" +) + +const ( + StateMachineFields = "id, tenant_id, app_name, name, status, gmt_create, ver, type, content, recover_strategy, comment_" + GetStateMachineByIdSql = "SELECT " + StateMachineFields + " FROM ${TABLE_PREFIX}state_machine_def WHERE id = ?" + QueryStateMachinesByNameAndTenantSql = "SELECT " + StateMachineFields + " FROM ${TABLE_PREFIX}state_machine_def WHERE name = ? AND tenant_id = ? ORDER BY gmt_create DESC" + InsertStateMachineSql = "INSERT INTO ${TABLE_PREFIX}state_machine_def (" + StateMachineFields + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" + TablePrefix = "\\$\\{TABLE_PREFIX}" +) + +type StateLangStore struct { + Store + tablePrefix string + getStateMachineByIdSql string + queryStateMachinesByNameAndTenantSql string + insertStateMachineSql string +} + +func NewStateLangStore(db *sql.DB, tablePrefix string) *StateLangStore { + r := regexp.MustCompile(TablePrefix) + + stateLangStore := &StateLangStore{ + Store: Store{db}, + tablePrefix: tablePrefix, + getStateMachineByIdSql: r.ReplaceAllString(GetStateMachineByIdSql, tablePrefix), + queryStateMachinesByNameAndTenantSql: r.ReplaceAllString(QueryStateMachinesByNameAndTenantSql, tablePrefix), + insertStateMachineSql: r.ReplaceAllString(InsertStateMachineSql, tablePrefix), + } + + return stateLangStore +} + +func (s *StateLangStore) GetStateMachineById(stateMachineId string) (statelang.StateMachine, error) { + return SelectOne(s.db, s.getStateMachineByIdSql, scanRowsToStateMachine, stateMachineId) +} + +func (s *StateLangStore) GetLastVersionStateMachine(stateMachineName string, tenantId string) (statelang.StateMachine, error) { + stateMachineList, err := SelectList(s.db, s.queryStateMachinesByNameAndTenantSql, scanRowsToStateMachine, stateMachineName, tenantId) + if err != nil { + return nil, err + } + + if len(stateMachineList) > 0 { + return stateMachineList[0], nil + } + return nil, nil +} + +func (s *StateLangStore) StoreStateMachine(stateMachine statelang.StateMachine) error { + rows, err := ExecuteUpdate(s.db, s.insertStateMachineSql, execStateMachineStatement, stateMachine) + if err != nil { + return err + } + if rows <= 0 { + return errors.New("affected rows is smaller than 0") + } + + return nil +} + +func scanRowsToStateMachine(rows *sql.Rows) (statelang.StateMachine, error) { + stateMachine := statelang.NewStateMachineImpl() + //var id, name, comment, version, appName, content, t, recoverStrategy, tenantId, status string + var id, tenantId, appName, name, status, created, version, t, content, recoverStrategy, comment string + //var created int64 + err := rows.Scan(&id, &tenantId, &appName, &name, &status, &created, &version, &t, &content, &recoverStrategy, &comment) + if err != nil { + return stateMachine, err + } + stateMachine.SetID(id) + stateMachine.SetName(name) + stateMachine.SetComment(comment) + stateMachine.SetVersion(version) + stateMachine.SetAppName(appName) + stateMachine.SetContent(content) + createdTime, _ := time.Parse(TimeLayout, created) + stateMachine.SetCreateTime(createdTime) + stateMachine.SetType(t) + if recoverStrategy != "" { + stateMachine.SetRecoverStrategy(statelang.RecoverStrategy(recoverStrategy)) + } + stateMachine.SetTenantId(t) + stateMachine.SetStatus(statelang.StateMachineStatus(status)) + return stateMachine, nil +} + +func execStateMachineStatement(obj statelang.StateMachine, stmt *sql.Stmt) (int64, error) { + result, err := stmt.Exec( + obj.ID(), + obj.TenantId(), + obj.AppName(), + obj.Name(), + obj.Status(), + obj.CreateTime(), + obj.Version(), + obj.Type(), + obj.Content(), + obj.RecoverStrategy(), + obj.Comment(), + ) + if err != nil { + return 0, err + } + rowsAffected, err := result.RowsAffected() + return rowsAffected, err +} diff --git a/pkg/saga/statemachine/engine/store/db/statelang_test.go b/pkg/saga/statemachine/engine/store/db/statelang_test.go new file mode 100644 index 00000000..87f868d8 --- /dev/null +++ b/pkg/saga/statemachine/engine/store/db/statelang_test.go @@ -0,0 +1,73 @@ +package db + +import ( + "github.com/seata/seata-go/pkg/saga/statemachine/statelang" + "github.com/stretchr/testify/assert" + "testing" + "time" + + _ "github.com/mattn/go-sqlite3" +) + +func TestStoreAndGetStateMachine(t *testing.T) { + prepareDB() + + stateLangStore := NewStateLangStore(db, "seata_") + const stateMachineId = "simpleStateMachine" + expected := statelang.NewStateMachineImpl() + expected.SetID(stateMachineId) + expected.SetName("simpleStateMachine") + expected.SetComment("This is a test state machine") + expected.SetCreateTime(time.Now()) + err := stateLangStore.StoreStateMachine(expected) + if err != nil { + t.Error(err) + return + } + + actual, err := stateLangStore.GetStateMachineById(stateMachineId) + if err != nil { + t.Error(err) + return + } + assert.Equal(t, expected.ID(), actual.ID()) + assert.Equal(t, expected.Name(), actual.Name()) + assert.Equal(t, expected.Comment(), actual.Comment()) + assert.Equal(t, expected.CreateTime().UnixNano(), actual.CreateTime().UnixNano()) +} + +func TestStoreAndGetLastVersionStateMachine(t *testing.T) { + prepareDB() + + stateLangStore := NewStateLangStore(db, "seata_") + const stateMachineName, tenantId = "simpleStateMachine", "test" + stateMachineV1 := statelang.NewStateMachineImpl() + stateMachineV1.SetID("simpleStateMachineV1") + stateMachineV1.SetName(stateMachineName) + stateMachineV1.SetTenantId(tenantId) + stateMachineV1.SetCreateTime(time.Now().Add(time.Duration(-1) * time.Millisecond)) + + stateMachineV2 := statelang.NewStateMachineImpl() + stateMachineV2.SetID("simpleStateMachineV2") + stateMachineV2.SetName(stateMachineName) + stateMachineV2.SetTenantId(tenantId) + stateMachineV2.SetCreateTime(time.Now()) + + err := stateLangStore.StoreStateMachine(stateMachineV1) + if err != nil { + t.Error(err) + return + } + err = stateLangStore.StoreStateMachine(stateMachineV2) + if err != nil { + t.Error(err) + return + } + + actual, err := stateLangStore.GetLastVersionStateMachine(stateMachineName, tenantId) + if err != nil { + t.Error(err) + return + } + assert.Equal(t, stateMachineV2.ID(), actual.ID()) +} diff --git a/pkg/saga/statemachine/engine/store/db/statelog.go b/pkg/saga/statemachine/engine/store/db/statelog.go new file mode 100644 index 00000000..26b89ed6 --- /dev/null +++ b/pkg/saga/statemachine/engine/store/db/statelog.go @@ -0,0 +1,741 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + "github.com/pkg/errors" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" + "github.com/seata/seata-go/pkg/saga/statemachine/engine/process_ctrl" + "github.com/seata/seata-go/pkg/saga/statemachine/engine/sequence" + "github.com/seata/seata-go/pkg/saga/statemachine/engine/serializer" + "github.com/seata/seata-go/pkg/saga/statemachine/statelang" + "github.com/seata/seata-go/pkg/util/log" + "regexp" + "strconv" + "strings" + "time" +) + +const ( + StateMachineInstanceFields = "id, machine_id, tenant_id, parent_id, business_key, gmt_started, gmt_end, status, compensation_status, is_running, gmt_updated, start_params, end_params, excep" + StateMachineInstanceFieldsWithoutParams = "id, machine_id, tenant_id, parent_id, business_key, gmt_started, gmt_end, status, compensation_status, is_running, gmt_updated" + RecordStateMachineStartedSql = "INSERT INTO ${TABLE_PREFIX}state_machine_inst\n(id, machine_id, tenant_id, parent_id, gmt_started, business_key, start_params, is_running, status, gmt_updated)\n VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" + RecordStateMachineFinishedSql = "UPDATE ${TABLE_PREFIX}state_machine_inst SET gmt_end = ?, excep = ?, end_params = ?,status = ?, compensation_status = ?, is_running = ?, gmt_updated = ? WHERE id = ? and gmt_updated = ?" + UpdateStateMachineRunningStatusSql = "UPDATE ${TABLE_PREFIX}state_machine_inst SET\nis_running = ?, gmt_updated = ? where id = ? and gmt_updated = ?" + GetStateMachineInstanceByIdSql = "SELECT " + StateMachineInstanceFields + " FROM ${TABLE_PREFIX}state_machine_inst WHERE id = ?" + GetStateMachineInstanceByBusinessKeySql = "SELECT " + StateMachineInstanceFields + " FROM ${TABLE_PREFIX}state_machine_inst WHERE business_key = ? AND tenant_id = ?" + QueryStateMachineInstancesByParentIdSql = "SELECT " + StateMachineInstanceFieldsWithoutParams + " FROM ${TABLE_PREFIX}state_machine_inst WHERE parent_id = ? ORDER BY gmt_started DESC" + + StateInstanceFields = "id, machine_inst_id, name, type, business_key, gmt_started, service_name, service_method, service_type, is_for_update, status, input_params, output_params, excep, gmt_end, state_id_compensated_for, state_id_retried_for" + RecordStateStartedSql = "INSERT INTO ${TABLE_PREFIX}state_inst (id, machine_inst_id, name, type, gmt_started, service_name, service_method, service_type, is_for_update, input_params, status, business_key, state_id_compensated_for, state_id_retried_for, gmt_updated)\nVALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" + RecordStateFinishedSql = "UPDATE ${TABLE_PREFIX}state_inst SET gmt_end = ?, excep = ?, status = ?, output_params = ?, gmt_updated = ? WHERE id = ? AND machine_inst_id = ?" + UpdateStateExecutionStatusSql = "UPDATE ${TABLE_PREFIX}state_inst SET status = ?, gmt_updated = ? WHERE machine_inst_id = ? AND id = ?" + QueryStateInstancesByMachineInstanceIdSql = "SELECT " + StateInstanceFields + " FROM ${TABLE_PREFIX}state_inst WHERE machine_inst_id = ? ORDER BY gmt_started, ID ASC" + GetStateInstanceByIdAndMachineInstanceIdSql = "SELECT " + StateInstanceFields + " FROM ${TABLE_PREFIX}state_inst WHERE machine_inst_id = ? AND id = ?" +) + +type StateLogStore struct { + Store + seqGenerator sequence.SeqGenerator + paramsSerializer serializer.ParamsSerializer + errorSerializer serializer.ErrorSerializer + + tablePrefix string + defaultTenantId string + + recordStateMachineStartedSql string + recordStateMachineFinishedSql string + updateStateMachineRunningStatusSql string + getStateMachineInstanceByIdSql string + getStateMachineInstanceByBusinessKeySql string + queryStateMachineInstancesByParentIdSql string + + recordStateStartedSql string + recordStateFinishedSql string + updateStateExecutionStatusSql string + queryStateInstancesByMachineInstanceIdSql string + getStateInstanceByIdAndMachineInstanceIdSql string +} + +func NewStateLogStore(db *sql.DB, tablePrefix string) *StateLogStore { + r := regexp.MustCompile(TablePrefix) + + stateLogStore := &StateLogStore{ + Store: Store{db}, + + seqGenerator: sequence.NewUUIDSeqGenerator(), + paramsSerializer: serializer.ParamsSerializer{}, + errorSerializer: serializer.ErrorSerializer{}, + + tablePrefix: tablePrefix, + defaultTenantId: "000001", + + recordStateMachineStartedSql: r.ReplaceAllString(RecordStateMachineStartedSql, tablePrefix), + recordStateMachineFinishedSql: r.ReplaceAllString(RecordStateMachineFinishedSql, tablePrefix), + updateStateMachineRunningStatusSql: r.ReplaceAllString(UpdateStateMachineRunningStatusSql, tablePrefix), + getStateMachineInstanceByIdSql: r.ReplaceAllString(GetStateMachineInstanceByIdSql, tablePrefix), + getStateMachineInstanceByBusinessKeySql: r.ReplaceAllString(GetStateMachineInstanceByBusinessKeySql, tablePrefix), + queryStateMachineInstancesByParentIdSql: r.ReplaceAllString(QueryStateMachineInstancesByParentIdSql, tablePrefix), + + recordStateStartedSql: r.ReplaceAllString(RecordStateStartedSql, tablePrefix), + recordStateFinishedSql: r.ReplaceAllString(RecordStateFinishedSql, tablePrefix), + updateStateExecutionStatusSql: r.ReplaceAllString(UpdateStateExecutionStatusSql, tablePrefix), + queryStateInstancesByMachineInstanceIdSql: r.ReplaceAllString(QueryStateInstancesByMachineInstanceIdSql, tablePrefix), + getStateInstanceByIdAndMachineInstanceIdSql: r.ReplaceAllString(GetStateInstanceByIdAndMachineInstanceIdSql, tablePrefix), + } + + return stateLogStore +} + +func (s *StateLogStore) RecordStateMachineStarted(ctx context.Context, machineInstance statelang.StateMachineInstance, + context process_ctrl.ProcessContext) error { + if machineInstance == nil { + return nil + } + + var err error + defer func() { + if err != nil { + log.Errorf("record state machine start error: %v, StateMachine %s, XID: %s", + err, machineInstance.StateMachine().Name(), machineInstance.ID()) + } + }() + + //if parentId is not null, machineInstance is a SubStateMachine, do not start a new global transaction, + //use parent transaction instead. + parentId := machineInstance.ParentID() + + if parentId == "" { + //TODO begin transaction + } + + if machineInstance.ID() == "" && s.seqGenerator != nil { + machineInstance.SetID(s.seqGenerator.GenerateId(constant.SeqEntityStateMachineInst, "")) + } + + //TODO bind SAGA branch type + + serializedStartParams, err := s.paramsSerializer.Serialize(machineInstance.StartParams()) + if err != nil { + return err + } + machineInstance.SetSerializedStartParams(serializedStartParams) + affected, err := ExecuteUpdate(s.db, s.recordStateMachineStartedSql, execStateMachineInstanceStatementForInsert, machineInstance) + if err != nil { + return err + } + if affected <= 0 { + return errors.New("affected rows is smaller than 0") + } + + return nil +} + +func (s *StateLogStore) RecordStateMachineFinished(ctx context.Context, machineInstance statelang.StateMachineInstance, + context process_ctrl.ProcessContext) error { + if machineInstance == nil { + return nil + } + + endParams := machineInstance.EndParams() + + if statelang.SU == machineInstance.Status() && machineInstance.Error() != nil { + machineInstance.SetError(nil) + } + + serializedEndParams, err := s.paramsSerializer.Serialize(endParams) + if err != nil { + return err + } + machineInstance.SetSerializedEndParams(serializedEndParams) + serializedError, err := s.errorSerializer.Serialize(machineInstance.Error()) + if err != nil { + return err + } + if len(serializedError) > 0 { + machineInstance.SetSerializedError(serializedError) + } + + affected, err := ExecuteUpdate(s.db, s.recordStateMachineFinishedSql, execStateMachineInstanceStatementForUpdate, machineInstance) + if err != nil { + return err + } + if affected <= 0 { + log.Warnf("StateMachineInstance[%s] is recovered by server, skip RecordStateMachineFinished", machineInstance.ID()) + return nil + } + + //TODO check if timeout or else report transaction finished + return nil +} + +func (s *StateLogStore) RecordStateMachineRestarted(ctx context.Context, machineInstance statelang.StateMachineInstance, + context process_ctrl.ProcessContext) error { + if machineInstance == nil { + return nil + } + updated := time.Now() + affected, err := ExecuteUpdateArgs(s.db, s.updateStateMachineRunningStatusSql, + machineInstance.IsRunning(), updated, machineInstance.ID(), machineInstance.UpdatedTime()) + if err != nil { + return err + } + if affected <= 0 { + return errors.New(fmt.Sprintf("StateMachineInstance [id:%s] is recovered by another execution, restart denied", + machineInstance.ID())) + } + machineInstance.SetUpdatedTime(updated) + return nil +} + +func (s *StateLogStore) RecordStateStarted(ctx context.Context, stateInstance statelang.StateInstance, + context process_ctrl.ProcessContext) error { + if stateInstance == nil { + return nil + } + isUpdateMode := s.isUpdateMode(stateInstance, context) + + if stateInstance.StateIDRetriedFor() != "" { + if isUpdateMode { + stateInstance.SetID(stateInstance.StateIDRetriedFor()) + } else { + stateInstance.SetID(s.generateRetryStateInstanceId(stateInstance)) + } + } else if stateInstance.StateIDCompensatedFor() != "" { + stateInstance.SetID(s.generateCompensateStateInstanceId(stateInstance, isUpdateMode)) + } else { + //TODO register branch + } + + if stateInstance.ID() == "" && s.seqGenerator != nil { + stateInstance.SetID(s.seqGenerator.GenerateId(constant.SeqEntityStateInst, "")) + } + + serializedParams, err := s.paramsSerializer.Serialize(stateInstance.InputParams()) + if err != nil { + return err + } + stateInstance.SetSerializedInputParams(serializedParams) + + var affected int64 + if !isUpdateMode { + affected, err = ExecuteUpdate(s.db, s.recordStateStartedSql, execStateInstanceStatementForInsert, stateInstance) + } else { + affected, err = ExecuteUpdateArgs(s.db, s.updateStateExecutionStatusSql, + stateInstance.Status(), time.Now(), stateInstance.StateMachineInstance().ID(), stateInstance.ID()) + } + + if err != nil { + return err + } + if affected <= 0 { + return errors.New("affected rows is smaller than 0") + } + return nil +} + +func (s *StateLogStore) isUpdateMode(instance statelang.StateInstance, context process_ctrl.ProcessContext) bool { + //TODO implement me, add forward logic + return false +} + +func (s *StateLogStore) generateRetryStateInstanceId(stateInstance statelang.StateInstance) string { + originalStateInstId := stateInstance.StateIDRetriedFor() + maxIndex := 1 + machineInstance := stateInstance.StateMachineInstance() + originalStateInst := machineInstance.State(originalStateInstId) + for originalStateInst.StateIDRetriedFor() != "" { + originalStateInst = machineInstance.State(originalStateInst.StateIDRetriedFor()) + idIndex := s.getIdIndex(originalStateInst.ID(), ".") + if idIndex > maxIndex { + maxIndex = idIndex + } + originalStateInstId = originalStateInst.ID() + } + return fmt.Sprintf("%s.%d", originalStateInstId, maxIndex) +} + +func (s *StateLogStore) generateCompensateStateInstanceId(stateInstance statelang.StateInstance, isUpdateMode bool) string { + originalCompensateStateInstId := stateInstance.StateIDCompensatedFor() + maxIndex := 1 + if isUpdateMode { + return fmt.Sprintf("%s-%d", originalCompensateStateInstId, maxIndex) + } + + machineInstance := stateInstance.StateMachineInstance() + for _, aStateInstance := range machineInstance.StateList() { + if aStateInstance != stateInstance && originalCompensateStateInstId == aStateInstance.StateIDCompensatedFor() { + idIndex := s.getIdIndex(aStateInstance.ID(), "-") + if idIndex > maxIndex { + maxIndex = idIndex + } + maxIndex++ + } + } + return fmt.Sprintf("%s-%d", originalCompensateStateInstId, maxIndex) +} + +func (s *StateLogStore) getIdIndex(stateInstanceId string, separator string) int { + if stateInstanceId != "" { + start := strings.LastIndex(stateInstanceId, separator) + if start > 0 { + indexStr := stateInstanceId[start+1:] + index, err := strconv.Atoi(indexStr) + if err != nil { + log.Warnf("get stateInstance id index failed %v", err) + return -1 + } + return index + } + } + return -1 +} + +func (s *StateLogStore) RecordStateFinished(ctx context.Context, stateInstance statelang.StateInstance, + context process_ctrl.ProcessContext) error { + if stateInstance == nil { + return nil + } + + serializedOutputParams, err := s.paramsSerializer.Serialize(stateInstance.OutputParams()) + if err != nil { + return err + } + stateInstance.SetSerializedOutputParams(serializedOutputParams) + + serializedError, err := s.errorSerializer.Serialize(stateInstance.Error()) + if err != nil { + return err + } + stateInstance.SetSerializedError(serializedError) + + _, err = ExecuteUpdate(s.db, s.recordStateFinishedSql, execStateInstanceStatementForUpdate, stateInstance) + if err != nil { + return err + } + + //TODO report branch + return nil + +} + +func (s *StateLogStore) GetStateMachineInstance(stateMachineInstanceId string) (statelang.StateMachineInstance, error) { + stateMachineInstance, err := SelectOne(s.db, s.getStateMachineInstanceByIdSql, scanRowsToStateMachineInstance, + stateMachineInstanceId) + if err != nil { + return stateMachineInstance, err + } + if stateMachineInstance == nil { + return nil, nil + } + + stateInstanceList, err := s.GetStateInstanceListByMachineInstanceId(stateMachineInstanceId) + if err != nil { + return stateMachineInstance, err + } + for _, stateInstance := range stateInstanceList { + stateMachineInstance.PutState(stateInstance.ID(), stateInstance) + } + err = s.deserializeStateMachineParamsAndException(stateMachineInstance) + return stateMachineInstance, err +} + +func (s *StateLogStore) GetStateMachineInstanceByBusinessKey(businessKey string, tenantId string) (statelang.StateMachineInstance, error) { + if tenantId == "" { + tenantId = s.defaultTenantId + } + stateMachineInstance, err := SelectOne(s.db, s.getStateMachineInstanceByBusinessKeySql, + scanRowsToStateMachineInstance, businessKey, tenantId) + if err != nil || stateMachineInstance == nil { + return stateMachineInstance, err + } + stateInstanceList, err := s.GetStateInstanceListByMachineInstanceId(stateMachineInstance.ID()) + if err != nil { + return stateMachineInstance, err + } + for _, stateInstance := range stateInstanceList { + stateMachineInstance.PutState(stateInstance.ID(), stateInstance) + } + err = s.deserializeStateMachineParamsAndException(stateMachineInstance) + return stateMachineInstance, err +} + +func (s *StateLogStore) deserializeStateMachineParamsAndException(stateMachineInstance statelang.StateMachineInstance) error { + if stateMachineInstance == nil { + return nil + } + + serializedError := stateMachineInstance.SerializedError().([]byte) + if serializedError != nil && len(serializedError) > 0 { + deserializedError, err := s.errorSerializer.Deserialize(serializedError) + if err != nil { + return err + } + stateMachineInstance.SetError(deserializedError) + } + + serializedStartParams := stateMachineInstance.SerializedStartParams() + if serializedStartParams != nil && serializedStartParams != "" { + startParams, err := s.paramsSerializer.Deserialize(serializedStartParams.(string)) + if err != nil { + return err + } + stateMachineInstance.SetStartParams(startParams.(map[string]any)) + } + + serializedOutputParams := stateMachineInstance.SerializedEndParams() + if serializedOutputParams != nil && serializedOutputParams != "" { + endParams, err := s.paramsSerializer.Deserialize(serializedOutputParams.(string)) + if err != nil { + return err + } + stateMachineInstance.SetEndParams(endParams.(map[string]any)) + } + return nil +} + +func (s *StateLogStore) GetStateMachineInstanceByParentId(parentId string) ([]statelang.StateMachineInstance, error) { + return SelectList(s.db, s.queryStateMachineInstancesByParentIdSql, scanRowsToStateMachineInstance, parentId) +} + +func (s *StateLogStore) GetStateInstance(stateInstanceId string, stateMachineInstanceId string) (statelang.StateInstance, error) { + stateInstance, err := SelectOne(s.db, s.getStateInstanceByIdAndMachineInstanceIdSql, scanRowsToStateInstance, + stateMachineInstanceId, stateInstanceId) + if err != nil { + return stateInstance, err + } + err = s.deserializeStateParamsAndException(stateInstance) + return stateInstance, err +} + +func (s *StateLogStore) GetStateInstanceListByMachineInstanceId(stateMachineInstanceId string) ([]statelang.StateInstance, error) { + stateInstanceList, err := SelectList(s.db, s.queryStateInstancesByMachineInstanceIdSql, + scanRowsToStateInstance, stateMachineInstanceId) + if err != nil || len(stateInstanceList) == 0 { + return stateInstanceList, err + } + + lastStateInstance := stateInstanceList[len(stateInstanceList)-1] + if lastStateInstance.EndTime().IsZero() { + lastStateInstance.SetStatus(statelang.RU) + } + //TODO add forward and compensate logic + originStateMap := make(map[string]statelang.StateInstance) + compensatedStateMap := make(map[string]statelang.StateInstance) + retriedStateMap := make(map[string]statelang.StateInstance) + + for _, tempStateInstance := range stateInstanceList { + err := s.deserializeStateParamsAndException(tempStateInstance) + if err != nil { + return stateInstanceList, err + } + + if tempStateInstance.StateIDCompensatedFor() != "" { + s.putLastStateToMap(compensatedStateMap, tempStateInstance, tempStateInstance.StateIDCompensatedFor()) + } else { + if tempStateInstance.StateIDRetriedFor() != "" { + s.putLastStateToMap(retriedStateMap, tempStateInstance, tempStateInstance.StateIDRetriedFor()) + } + originStateMap[tempStateInstance.ID()] = tempStateInstance + } + } + + if len(compensatedStateMap) != 0 { + for _, originState := range originStateMap { + originState.SetCompensationState(compensatedStateMap[originState.ID()]) + } + } + + if len(retriedStateMap) != 0 { + for _, originState := range originStateMap { + if _, ok := retriedStateMap[originState.ID()]; ok { + originState.SetIgnoreStatus(true) + } + } + } + + return stateInstanceList, nil +} + +func (s *StateLogStore) putLastStateToMap(resultMap map[string]statelang.StateInstance, newState statelang.StateInstance, key string) { + existed, ok := resultMap[key] + if !ok { + resultMap[key] = newState + } else if newState.EndTime().After(existed.EndTime()) { + existed.SetIgnoreStatus(true) + resultMap[key] = newState + } else { + newState.SetIgnoreStatus(true) + } +} + +func (s *StateLogStore) deserializeStateParamsAndException(stateInstance statelang.StateInstance) error { + if stateInstance == nil { + return nil + } + serializedInputParams := stateInstance.SerializedInputParams() + if serializedInputParams != nil && serializedInputParams != "" { + inputParams, err := s.paramsSerializer.Deserialize(serializedInputParams.(string)) + if err != nil { + return err + } + stateInstance.SetInputParams(inputParams) + } + serializedOutputParams := stateInstance.SerializedOutputParams() + if serializedOutputParams != nil && serializedOutputParams != "" { + outputParams, err := s.paramsSerializer.Deserialize(serializedOutputParams.(string)) + if err != nil { + return err + } + stateInstance.SetOutputParams(outputParams) + } + serializedError := stateInstance.SerializedError().([]byte) + if serializedError != nil { + deserializedError, err := s.errorSerializer.Deserialize(serializedError) + if err != nil { + return err + } + stateInstance.SetError(deserializedError) + } + return nil +} + +func (s *StateLogStore) SetSeqGenerator(seqGenerator sequence.SeqGenerator) { + s.seqGenerator = seqGenerator +} + +func execStateMachineInstanceStatementForInsert(obj statelang.StateMachineInstance, stmt *sql.Stmt) (int64, error) { + result, err := stmt.Exec( + obj.ID(), + obj.MachineID(), + obj.TenantID(), + obj.ParentID(), + obj.StartedTime(), + obj.BusinessKey(), + obj.SerializedStartParams(), + obj.IsRunning(), + obj.Status(), + obj.UpdatedTime(), + ) + if err != nil { + return 0, err + } + rowsAffected, err := result.RowsAffected() + return rowsAffected, err +} + +func execStateMachineInstanceStatementForUpdate(obj statelang.StateMachineInstance, stmt *sql.Stmt) (int64, error) { + var serializedError []byte + if obj.SerializedError() != nil && len(obj.SerializedError().([]byte)) > 0 { + serializedError = obj.SerializedError().([]byte) + } + var compensationStatus sql.NullString + if obj.CompensationStatus() != "" { + compensationStatus.Valid = true + compensationStatus.String = string(obj.CompensationStatus()) + } + + result, err := stmt.Exec( + obj.EndTime(), + serializedError, + obj.SerializedEndParams(), + obj.Status(), + compensationStatus, + obj.IsRunning(), + time.Now(), + obj.ID(), + obj.UpdatedTime(), + ) + if err != nil { + return 0, err + } + rowsAffected, err := result.RowsAffected() + return rowsAffected, err +} + +func execStateInstanceStatementForInsert(obj statelang.StateInstance, stmt *sql.Stmt) (int64, error) { + result, err := stmt.Exec( + obj.ID(), + obj.StateMachineInstance().ID(), + obj.Name(), + obj.Type(), + obj.StartedTime(), + obj.ServiceName(), + obj.ServiceMethod(), + obj.ServiceType(), + obj.IsForUpdate(), + obj.SerializedInputParams(), + obj.Status(), + obj.BusinessKey(), + obj.StateIDCompensatedFor(), + obj.StateIDRetriedFor(), + obj.UpdatedTime(), + ) + if err != nil { + return 0, err + } + rowsAffected, err := result.RowsAffected() + return rowsAffected, err +} + +func execStateInstanceStatementForUpdate(obj statelang.StateInstance, stmt *sql.Stmt) (int64, error) { + var serializedError []byte + if obj.SerializedError() != nil && len(obj.SerializedError().([]byte)) > 0 { + serializedError = obj.SerializedError().([]byte) + } + + result, err := stmt.Exec( + obj.EndTime(), + serializedError, + obj.Status(), + obj.SerializedOutputParams(), + obj.EndTime(), + obj.ID(), + obj.MachineInstanceID(), + ) + if err != nil { + return 0, err + } + rowsAffected, err := result.RowsAffected() + return rowsAffected, err +} + +func scanRowsToStateMachineInstance(rows *sql.Rows) (statelang.StateMachineInstance, error) { + stateMachineInstance := statelang.NewStateMachineInstanceImpl() + var id, machineId, tenantId, parentId, businessKey, started, end, status, compensationStatus, + updated, startParams, endParams sql.NullString + var isRunning sql.NullBool + var errorBlob []byte + columns, _ := rows.Columns() + + args := []any{&id, &machineId, &tenantId, &parentId, &businessKey, &started, &end, &status, + &compensationStatus, &isRunning, &updated} + if len(columns) > 11 { + args = append(args, &startParams, &endParams, &errorBlob) + } + + err := rows.Scan(args...) + if err != nil { + return nil, err + } + + if id.Valid { + stateMachineInstance.SetID(id.String) + } + if machineId.Valid { + stateMachineInstance.SetMachineID(machineId.String) + } + if tenantId.Valid { + stateMachineInstance.SetTenantID(tenantId.String) + } + if parentId.Valid { + stateMachineInstance.SetParentID(parentId.String) + } + if businessKey.Valid { + stateMachineInstance.SetBusinessKey(businessKey.String) + } + if started.Valid { + startedTime, _ := time.Parse(TimeLayout, started.String) + stateMachineInstance.SetStartedTime(startedTime) + } + if end.Valid { + endTime, _ := time.Parse(TimeLayout, end.String) + stateMachineInstance.SetEndTime(endTime) + } + if status.Valid { + stateMachineInstance.SetStatus(statelang.ExecutionStatus(status.String)) + } + + if compensationStatus.Valid { + if compensationStatus.String != "" { + stateMachineInstance.SetCompensationStatus(statelang.ExecutionStatus(compensationStatus.String)) + } + } + if isRunning.Valid { + stateMachineInstance.SetRunning(isRunning.Bool) + } + if updated.Valid { + updatedTime, _ := time.Parse(TimeLayout, updated.String) + stateMachineInstance.SetUpdatedTime(updatedTime) + } + + if len(columns) > 11 { + if startParams.Valid { + stateMachineInstance.SetSerializedStartParams(startParams.String) + } + if endParams.Valid { + stateMachineInstance.SetSerializedEndParams(endParams.String) + } + stateMachineInstance.SetSerializedError(errorBlob) + } + return stateMachineInstance, nil +} + +func scanRowsToStateInstance(rows *sql.Rows) (statelang.StateInstance, error) { + stateInstance := statelang.NewStateInstanceImpl() + var id, machineInstId, name, t, businessKey, started, serviceName, serviceMethod, serviceType, status, + inputParams, outputParams, end, stateIdCompensatedFor, stateIdRetriedFor sql.NullString + var isForUpdate sql.NullBool + var errorBlob []byte + err := rows.Scan(&id, &machineInstId, &name, &t, &businessKey, &started, &serviceName, &serviceMethod, &serviceType, + &isForUpdate, &status, &inputParams, &outputParams, &errorBlob, &end, &stateIdCompensatedFor, &stateIdRetriedFor) + if err != nil { + return nil, err + } + + if id.Valid { + stateInstance.SetID(id.String) + } + if machineInstId.Valid { + stateInstance.SetMachineInstanceID(machineInstId.String) + } + if name.Valid { + stateInstance.SetName(name.String) + } + if t.Valid { + stateInstance.SetType(t.String) + } + if businessKey.Valid { + stateInstance.SetBusinessKey(businessKey.String) + } + if status.Valid { + stateInstance.SetStatus(statelang.ExecutionStatus(status.String)) + } + if started.Valid { + startedTime, _ := time.Parse(TimeLayout, started.String) + stateInstance.SetStartedTime(startedTime) + } + if end.Valid { + endTime, _ := time.Parse(TimeLayout, end.String) + stateInstance.SetEndTime(endTime) + } + + if serviceName.Valid { + stateInstance.SetServiceName(serviceName.String) + } + if serviceMethod.Valid { + stateInstance.SetServiceMethod(serviceMethod.String) + } + if serviceType.Valid { + stateInstance.SetServiceType(serviceType.String) + } + if isForUpdate.Valid { + stateInstance.SetForUpdate(isForUpdate.Bool) + } + if stateIdCompensatedFor.Valid { + stateInstance.SetStateIDCompensatedFor(stateIdCompensatedFor.String) + } + if stateIdRetriedFor.Valid { + stateInstance.SetStateIDRetriedFor(stateIdRetriedFor.String) + } + + if inputParams.Valid { + stateInstance.SetSerializedInputParams(inputParams.String) + } + if outputParams.Valid { + stateInstance.SetSerializedOutputParams(outputParams.String) + } + stateInstance.SetSerializedError(errorBlob) + return stateInstance, err +} diff --git a/pkg/saga/statemachine/engine/store/db/statelog_test.go b/pkg/saga/statemachine/engine/store/db/statelog_test.go new file mode 100644 index 00000000..9044e6df --- /dev/null +++ b/pkg/saga/statemachine/engine/store/db/statelog_test.go @@ -0,0 +1,279 @@ +package db + +import ( + "context" + "fmt" + "github.com/pkg/errors" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" + "github.com/seata/seata-go/pkg/saga/statemachine/engine" + "github.com/seata/seata-go/pkg/saga/statemachine/engine/process_ctrl" + "github.com/seata/seata-go/pkg/saga/statemachine/statelang" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func mockProcessContext(stateMachineName string, stateMachineInstance statelang.StateMachineInstance) process_ctrl.ProcessContext { + ctx := engine.NewProcessContextBuilder(). + WithProcessType(process_ctrl.StateLang). + WithOperationName(constant.OperationNameStart). + WithInstruction(process_ctrl.NewStateInstruction(stateMachineName, "000001")). + WithStateMachineInstance(stateMachineInstance). + Build() + return ctx +} + +func mockMachineInstance(stateMachineName string) statelang.StateMachineInstance { + stateMachine := statelang.NewStateMachineImpl() + stateMachine.SetName(stateMachineName) + stateMachine.SetComment("This is a test state machine") + stateMachine.SetCreateTime(time.Now()) + + inst := statelang.NewStateMachineInstanceImpl() + inst.SetStateMachine(stateMachine) + inst.SetMachineID(stateMachineName) + + inst.SetStartParams(map[string]any{"start": 100}) + inst.SetStatus(statelang.RU) + inst.SetStartedTime(time.Now()) + inst.SetUpdatedTime(time.Now()) + return inst +} + +func TestStateLogStore_RecordStateMachineStarted(t *testing.T) { + prepareDB() + + const stateMachineName = "stateMachine" + stateLogStore := NewStateLogStore(db, "seata_") + expected := mockMachineInstance(stateMachineName) + expected.SetBusinessKey("test_started") + ctx := mockProcessContext(stateMachineName, expected) + err := stateLogStore.RecordStateMachineStarted(context.Background(), expected, ctx) + assert.Nil(t, err) + actual, err := stateLogStore.GetStateMachineInstance(expected.ID()) + assert.Nil(t, err) + assert.Equal(t, expected.ID(), actual.ID()) + assert.Equal(t, expected.MachineID(), actual.MachineID()) + assert.Equal(t, fmt.Sprint(expected.StartParams()), fmt.Sprint(actual.StartParams())) + assert.Nil(t, actual.Error()) + assert.Nil(t, actual.SerializedError()) + assert.Equal(t, expected.Status(), actual.Status()) + assert.Equal(t, expected.StartedTime().UnixNano(), actual.StartedTime().UnixNano()) + assert.Equal(t, expected.UpdatedTime().UnixNano(), actual.UpdatedTime().UnixNano()) +} + +func TestStateLogStore_RecordStateMachineFinished(t *testing.T) { + prepareDB() + + const stateMachineName = "stateMachine" + stateLogStore := NewStateLogStore(db, "seata_") + expected := mockMachineInstance(stateMachineName) + expected.SetBusinessKey("test_finished") + ctx := mockProcessContext(stateMachineName, expected) + err := stateLogStore.RecordStateMachineStarted(context.Background(), expected, ctx) + assert.Nil(t, err) + expected.SetEndParams(map[string]any{"end": 100}) + expected.SetError(errors.New("this is a test error")) + expected.SetStatus(statelang.FA) + expected.SetEndTime(time.Now()) + expected.SetRunning(false) + err = stateLogStore.RecordStateMachineFinished(context.Background(), expected, ctx) + assert.Equal(t, "{\"end\":100}", expected.SerializedEndParams()) + assert.NotEmpty(t, expected.SerializedError()) + actual, err := stateLogStore.GetStateMachineInstance(expected.ID()) + assert.Nil(t, err) + + assert.Equal(t, expected.ID(), actual.ID()) + assert.Equal(t, expected.MachineID(), actual.MachineID()) + assert.Equal(t, fmt.Sprint(expected.StartParams()), fmt.Sprint(actual.StartParams())) + assert.Equal(t, "this is a test error", actual.Error().Error()) + assert.Equal(t, expected.Status(), actual.Status()) + assert.Equal(t, expected.IsRunning(), actual.IsRunning()) + assert.Equal(t, expected.StartedTime().UnixNano(), actual.StartedTime().UnixNano()) + assert.Greater(t, actual.UpdatedTime().UnixNano(), expected.UpdatedTime().UnixNano()) + assert.False(t, expected.EndTime().IsZero()) +} + +func TestStateLogStore_RecordStateMachineRestarted(t *testing.T) { + prepareDB() + + const stateMachineName = "stateMachine" + stateLogStore := NewStateLogStore(db, "seata_") + expected := mockMachineInstance(stateMachineName) + expected.SetBusinessKey("test_restarted") + ctx := mockProcessContext(stateMachineName, expected) + err := stateLogStore.RecordStateMachineStarted(context.Background(), expected, ctx) + assert.Nil(t, err) + expected.SetRunning(false) + err = stateLogStore.RecordStateMachineFinished(context.Background(), expected, ctx) + + actual, err := stateLogStore.GetStateMachineInstance(expected.ID()) + assert.Nil(t, err) + assert.False(t, actual.IsRunning()) + + actual.SetRunning(true) + err = stateLogStore.RecordStateMachineRestarted(context.Background(), actual, ctx) + assert.Nil(t, err) + actual, err = stateLogStore.GetStateMachineInstance(actual.ID()) + assert.Nil(t, err) + assert.True(t, actual.IsRunning()) +} + +func TestStateLogStore_RecordStateStarted(t *testing.T) { + prepareDB() + + const stateMachineName = "stateMachine" + stateLogStore := NewStateLogStore(db, "seata_") + + machineInstance := mockMachineInstance("stateMachine") + ctx := mockProcessContext(stateMachineName, machineInstance) + machineInstance.SetID("test") + + common := statelang.NewStateInstanceImpl() + common.SetStateMachineInstance(machineInstance) + common.SetMachineInstanceID(machineInstance.ID()) + common.SetName("ServiceTask1") + common.SetType("ServiceTask") + common.SetStartedTime(time.Now()) + common.SetServiceName("DemoService") + common.SetServiceMethod("foo") + common.SetServiceType("RPC") + common.SetForUpdate(false) + common.SetInputParams(map[string]string{"input": "test"}) + common.SetStatus(statelang.RU) + common.SetBusinessKey("test_state_started") + + origin := statelang.NewStateInstanceImpl() + origin.SetID("origin") + origin.SetStateMachineInstance(machineInstance) + origin.SetMachineInstanceID(machineInstance.ID()) + machineInstance.PutState("origin", origin) + + retried := statelang.NewStateInstanceImpl() + retried.SetStateMachineInstance(machineInstance) + retried.SetMachineInstanceID(machineInstance.ID()) + retried.SetID("origin.1") + retried.SetStateIDRetriedFor("origin") + + compensated := statelang.NewStateInstanceImpl() + compensated.SetStateMachineInstance(machineInstance) + compensated.SetMachineInstanceID(machineInstance.ID()) + compensated.SetID("origin-1") + compensated.SetStateIDCompensatedFor("origin") + + tests := []struct { + name string + expected statelang.StateInstance + }{ + {"common", common}, + {"retried", retried}, + {"compensated", compensated}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := stateLogStore.RecordStateStarted(context.Background(), test.expected, ctx) + assert.Nil(t, err) + actual, err := stateLogStore.GetStateInstance(test.expected.ID(), machineInstance.ID()) + assert.Nil(t, err) + assert.Equal(t, test.expected.ID(), actual.ID()) + assert.Equal(t, test.expected.StateMachineInstance().ID(), actual.MachineInstanceID()) + assert.Equal(t, test.expected.Name(), actual.Name()) + assert.Equal(t, test.expected.Type(), actual.Type()) + assert.Equal(t, test.expected.StartedTime().UnixNano(), actual.StartedTime().UnixNano()) + assert.Equal(t, test.expected.ServiceName(), actual.ServiceName()) + assert.Equal(t, test.expected.ServiceMethod(), actual.ServiceMethod()) + assert.Equal(t, test.expected.ServiceType(), actual.ServiceType()) + assert.Equal(t, test.expected.IsForUpdate(), actual.IsForUpdate()) + assert.Equal(t, test.expected.SerializedInputParams(), actual.SerializedInputParams()) + assert.Equal(t, test.expected.Status(), actual.Status()) + assert.Equal(t, test.expected.BusinessKey(), actual.BusinessKey()) + assert.Equal(t, test.expected.StateIDCompensatedFor(), actual.StateIDCompensatedFor()) + assert.Equal(t, test.expected.StateIDRetriedFor(), actual.StateIDRetriedFor()) + }) + } +} + +func TestStateLogStore_RecordStateFinished(t *testing.T) { + prepareDB() + + const stateMachineName = "stateMachine" + stateLogStore := NewStateLogStore(db, "seata_") + + machineInstance := mockMachineInstance("stateMachine") + ctx := mockProcessContext(stateMachineName, machineInstance) + machineInstance.SetID("test") + + expected := statelang.NewStateInstanceImpl() + expected.SetStateMachineInstance(machineInstance) + expected.SetMachineInstanceID(machineInstance.ID()) + + err := stateLogStore.RecordStateStarted(context.Background(), expected, ctx) + assert.Nil(t, err) + + expected.SetStatus(statelang.UN) + expected.SetError(errors.New("this is a test error")) + expected.SetOutputParams(map[string]string{"output": "test"}) + err = stateLogStore.RecordStateFinished(context.Background(), expected, ctx) + assert.Nil(t, err) + actual, err := stateLogStore.GetStateInstance(expected.ID(), machineInstance.ID()) + assert.Nil(t, err) + assert.Equal(t, expected.Status(), actual.Status()) + assert.Equal(t, expected.Error().Error(), actual.Error().Error()) + assert.NotEmpty(t, actual.OutputParams()) + assert.Equal(t, expected.SerializedOutputParams(), actual.SerializedOutputParams()) +} + +func TestStateLogStore_GetStateMachineInstanceByBusinessKey(t *testing.T) { + prepareDB() + + const stateMachineName = "stateMachine" + stateLogStore := NewStateLogStore(db, "seata_") + expected := mockMachineInstance(stateMachineName) + expected.SetBusinessKey("test_business_key") + expected.SetTenantID("000001") + ctx := mockProcessContext(stateMachineName, expected) + + err := stateLogStore.RecordStateMachineStarted(context.Background(), expected, ctx) + assert.Nil(t, err) + actual, err := stateLogStore.GetStateMachineInstanceByBusinessKey(expected.BusinessKey(), expected.TenantID()) + assert.Nil(t, err) + assert.Equal(t, expected.ID(), actual.ID()) + assert.Equal(t, expected.MachineID(), actual.MachineID()) + assert.Equal(t, fmt.Sprint(expected.StartParams()), fmt.Sprint(actual.StartParams())) + assert.Nil(t, actual.Error()) + assert.Nil(t, actual.SerializedError()) + assert.Equal(t, expected.Status(), actual.Status()) + assert.Equal(t, expected.StartedTime().UnixNano(), actual.StartedTime().UnixNano()) + assert.Equal(t, expected.UpdatedTime().UnixNano(), actual.UpdatedTime().UnixNano()) +} + +func TestStateLogStore_GetStateMachineInstanceByParentId(t *testing.T) { + prepareDB() + + const ( + stateMachineName = "stateMachine" + parentId = "parent" + ) + stateLogStore := NewStateLogStore(db, "seata_") + expected := mockMachineInstance(stateMachineName) + expected.SetBusinessKey("test_parent_id") + expected.SetParentID(parentId) + ctx := mockProcessContext(stateMachineName, expected) + + err := stateLogStore.RecordStateMachineStarted(context.Background(), expected, ctx) + assert.Nil(t, err) + actualList, err := stateLogStore.GetStateMachineInstanceByParentId(parentId) + assert.Nil(t, err) + + assert.Equal(t, 1, len(actualList)) + actual := actualList[0] + assert.Equal(t, expected.ID(), actual.ID()) + assert.Equal(t, expected.MachineID(), actual.MachineID()) + // no startParams, endParams and Error + assert.NotEqual(t, fmt.Sprint(expected.StartParams()), fmt.Sprint(actual.StartParams())) + assert.Nil(t, actual.Error()) + assert.Nil(t, actual.SerializedError()) + assert.Equal(t, expected.Status(), actual.Status()) + assert.Equal(t, expected.StartedTime().UnixNano(), actual.StartedTime().UnixNano()) + assert.Equal(t, expected.UpdatedTime().UnixNano(), actual.UpdatedTime().UnixNano()) +} diff --git a/pkg/saga/statemachine/engine/store/statemachine_store.go b/pkg/saga/statemachine/engine/store/statemachine_store.go index 8959df92..2dac06d3 100644 --- a/pkg/saga/statemachine/engine/store/statemachine_store.go +++ b/pkg/saga/statemachine/engine/store/statemachine_store.go @@ -8,7 +8,7 @@ import ( ) type StateLogRepository interface { - GetStateMachineInstance(stateMachineInstanceId string) (statelang.StateInstance, error) + GetStateMachineInstance(stateMachineInstanceId string) (statelang.StateMachineInstance, error) GetStateMachineInstanceByBusinessKey(businessKey string, tenantId string) (statelang.StateInstance, error) @@ -30,9 +30,9 @@ type StateLogStore interface { RecordStateFinished(ctx context.Context, stateInstance statelang.StateInstance, context process_ctrl.ProcessContext) error - GetStateMachineInstance(stateMachineInstanceId string) (statelang.StateInstance, error) + GetStateMachineInstance(stateMachineInstanceId string) (statelang.StateMachineInstance, error) - GetStateMachineInstanceByBusinessKey(businessKey string, tenantId string) (statelang.StateInstance, error) + GetStateMachineInstanceByBusinessKey(businessKey string, tenantId string) (statelang.StateMachineInstance, error) GetStateMachineInstanceByParentId(parentId string) ([]statelang.StateMachineInstance, error) diff --git a/pkg/saga/statemachine/engine/utils.go b/pkg/saga/statemachine/engine/utils.go index 654e319d..92b7aa10 100644 --- a/pkg/saga/statemachine/engine/utils.go +++ b/pkg/saga/statemachine/engine/utils.go @@ -12,7 +12,7 @@ type ProcessContextBuilder struct { } func NewProcessContextBuilder() *ProcessContextBuilder { - processContextImpl := &process_ctrl.ProcessContextImpl{} + processContextImpl := process_ctrl.NewProcessContextImpl() return &ProcessContextBuilder{processContextImpl} } diff --git a/pkg/saga/statemachine/statelang/state_instance.go b/pkg/saga/statemachine/statelang/state_instance.go index 826faca6..e77da058 100644 --- a/pkg/saga/statemachine/statelang/state_instance.go +++ b/pkg/saga/statemachine/statelang/state_instance.go @@ -77,6 +77,8 @@ type StateInstance interface { StateMachineInstance() StateMachineInstance + MachineInstanceID() string + SetStateMachineInstance(stateMachineInstance StateMachineInstance) IsIgnoreStatus() bool @@ -139,6 +141,14 @@ func (s *StateInstanceImpl) SetID(id string) { s.id = id } +func (s *StateInstanceImpl) MachineInstanceID() string { + return s.machineInstanceId +} + +func (s *StateInstanceImpl) SetMachineInstanceID(machineInstanceId string) { + s.machineInstanceId = machineInstanceId +} + func (s *StateInstanceImpl) Name() string { return s.name } diff --git a/pkg/saga/statemachine/statelang/statemachine_instance.go b/pkg/saga/statemachine/statelang/statemachine_instance.go index ff039024..399d222b 100644 --- a/pkg/saga/statemachine/statelang/statemachine_instance.go +++ b/pkg/saga/statemachine/statelang/statemachine_instance.go @@ -300,7 +300,7 @@ func (s *StateMachineInstanceImpl) SetSerializedStartParams(serializedStartParam } func (s *StateMachineInstanceImpl) SerializedEndParams() interface{} { - return s.endParams + return s.serializedEndParams } func (s *StateMachineInstanceImpl) SetSerializedEndParams(serializedEndParams interface{}) { diff --git a/testdata/sql/saga/sqlite_init.sql b/testdata/sql/saga/sqlite_init.sql new file mode 100644 index 00000000..f82ed793 --- /dev/null +++ b/testdata/sql/saga/sqlite_init.sql @@ -0,0 +1,81 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE TABLE IF NOT EXISTS seata_state_machine_def +( + id VARCHAR(32) NOT NULL, + name VARCHAR(128) NOT NULL, + tenant_id VARCHAR(32) NOT NULL, + app_name VARCHAR(32) NOT NULL, + type VARCHAR(20), + comment_ VARCHAR(255), + ver VARCHAR(16) NOT NULL, + gmt_create TIMESTAMP(3) NOT NULL, + status VARCHAR(2) NOT NULL, + content CLOB, + recover_strategy VARCHAR(16), + PRIMARY KEY (id) +); + +CREATE TABLE IF NOT EXISTS seata_state_machine_inst +( + id VARCHAR(128) NOT NULL, + machine_id VARCHAR(32) NOT NULL, + tenant_id VARCHAR(32) NOT NULL, + parent_id VARCHAR(128), + gmt_started TIMESTAMP(3) NOT NULL, + business_key VARCHAR(48), + uni_business_key VARCHAR(128) GENERATED ALWAYS AS ( + CASE + WHEN "BUSINESS_KEY" IS NULL + THEN "ID" + ELSE "BUSINESS_KEY" + END), + start_params CLOB, + gmt_end TIMESTAMP(3), + excep BLOB, + end_params CLOB, + status VARCHAR(2), + compensation_status VARCHAR(2), + is_running SMALLINT, + gmt_updated TIMESTAMP(3) NOT NULL, + PRIMARY KEY (id) +); +CREATE UNIQUE INDEX IF NOT EXISTS state_machine_inst_unibuzkey ON seata_state_machine_inst (uni_business_key, tenant_id); + +CREATE TABLE IF NOT EXISTS seata_state_inst +( + id VARCHAR(48) NOT NULL, + machine_inst_id VARCHAR(46) NOT NULL, + name VARCHAR(128) NOT NULL, + type VARCHAR(20), + service_name VARCHAR(128), + service_method VARCHAR(128), + service_type VARCHAR(16), + business_key VARCHAR(48), + state_id_compensated_for VARCHAR(50), + state_id_retried_for VARCHAR(50), + gmt_started TIMESTAMP(3) NOT NULL, + is_for_update SMALLINT, + input_params CLOB, + output_params CLOB, + status VARCHAR(2) NOT NULL, + excep BLOB, + gmt_updated TIMESTAMP(3), + gmt_end TIMESTAMP(3), + PRIMARY KEY (id, machine_inst_id) +); --------------------------------------------------------------------- To unsubscribe, e-mail: notifications-unsubscr...@seata.apache.org For additional commands, e-mail: notifications-h...@seata.apache.org