This is an automated email from the ASF dual-hosted git repository.

tew pushed a commit to branch feature/support-postgresql-in-xa-mode
in repository https://gitbox.apache.org/repos/asf/incubator-seata-go.git


The following commit(s) were added to 
refs/heads/feature/support-postgresql-in-xa-mode by this push:
     new a8116ce5 fix: fix RegisterTableCache type errors and optimize code patterns (#1029)
a8116ce5 is described below

commit a8116ce50a35eb5885efe6a600e13926a9fbe465
Author: flypiggy <[email protected]>
AuthorDate: Sat Dec 20 16:40:40 2025 +0800

    fix: fix RegisterTableCache type errors and optimize code patterns (#1029)
    
    Fixes #1028.
---
 pkg/datasource/sql/exec/at/base_executor.go        | 49 +++++-----------------
 pkg/datasource/sql/exec/at/escape.go               |  2 +-
 pkg/datasource/sql/exec/at/insert_executor.go      |  6 +--
 pkg/datasource/sql/exec/at/insert_executor_test.go | 17 ++++++--
 .../sql/exec/at/multi_update_excutor_test.go       |  9 +++-
 pkg/datasource/sql/exec/at/update_executor_test.go |  5 ++-
 6 files changed, 39 insertions(+), 49 deletions(-)

diff --git a/pkg/datasource/sql/exec/at/base_executor.go 
b/pkg/datasource/sql/exec/at/base_executor.go
index 13ac4fe8..0259cd66 100644
--- a/pkg/datasource/sql/exec/at/base_executor.go
+++ b/pkg/datasource/sql/exec/at/base_executor.go
@@ -140,73 +140,50 @@ func (b *baseExecutor) traversalArgs(node ast.Node, 
argsIndex *[]int32) {
        if node == nil {
                return
        }
-       switch node.(type) {
+       switch expr := node.(type) {
        case *ast.BinaryOperationExpr:
-               expr := node.(*ast.BinaryOperationExpr)
                b.traversalArgs(expr.L, argsIndex)
                b.traversalArgs(expr.R, argsIndex)
-               break
        case *ast.BetweenExpr:
-               expr := node.(*ast.BetweenExpr)
                b.traversalArgs(expr.Left, argsIndex)
                b.traversalArgs(expr.Right, argsIndex)
-               break
        case *ast.PatternInExpr:
-               exprs := node.(*ast.PatternInExpr).List
-               for i := 0; i < len(exprs); i++ {
-                       b.traversalArgs(exprs[i], argsIndex)
+               for i := 0; i < len(expr.List); i++ {
+                       b.traversalArgs(expr.List[i], argsIndex)
                }
-               break
        case *ast.Join:
-               exprs := node.(*ast.Join)
-               b.traversalArgs(exprs.Left, argsIndex)
-               if exprs.Right != nil {
-                       b.traversalArgs(exprs.Right, argsIndex)
+               b.traversalArgs(expr.Left, argsIndex)
+               if expr.Right != nil {
+                       b.traversalArgs(expr.Right, argsIndex)
                }
-               if exprs.On != nil {
-                       b.traversalArgs(exprs.On.Expr, argsIndex)
+               if expr.On != nil {
+                       b.traversalArgs(expr.On.Expr, argsIndex)
                }
-               break
        case *ast.UnaryOperationExpr:
-               expr := node.(*ast.UnaryOperationExpr)
                b.traversalArgs(expr.V, argsIndex)
-               break
        case *ast.FuncCallExpr:
-               expr := node.(*ast.FuncCallExpr)
                for _, arg := range expr.Args {
                        b.traversalArgs(arg, argsIndex)
                }
-               break
        case *ast.SubqueryExpr:
-               expr := node.(*ast.SubqueryExpr)
                if expr.Query != nil {
                        b.traversalArgs(expr.Query, argsIndex)
                }
-               break
        case *ast.ExistsSubqueryExpr:
-               expr := node.(*ast.ExistsSubqueryExpr)
                if expr.Sel != nil {
                        b.traversalArgs(expr.Sel, argsIndex)
                }
-               break
        case *ast.CompareSubqueryExpr:
-               expr := node.(*ast.CompareSubqueryExpr)
                b.traversalArgs(expr.L, argsIndex)
                if expr.R != nil {
                        b.traversalArgs(expr.R, argsIndex)
                }
-               break
        case *ast.PatternLikeExpr:
-               expr := node.(*ast.PatternLikeExpr)
                b.traversalArgs(expr.Expr, argsIndex)
                b.traversalArgs(expr.Pattern, argsIndex)
-               break
        case *ast.IsNullExpr:
-               expr := node.(*ast.IsNullExpr)
                b.traversalArgs(expr.Expr, argsIndex)
-               break
        case *ast.CaseExpr:
-               expr := node.(*ast.CaseExpr)
                if expr.Value != nil {
                        b.traversalArgs(expr.Value, argsIndex)
                }
@@ -217,10 +194,8 @@ func (b *baseExecutor) traversalArgs(node ast.Node, 
argsIndex *[]int32) {
                if expr.ElseClause != nil {
                        b.traversalArgs(expr.ElseClause, argsIndex)
                }
-               break
        case *test_driver.ParamMarkerExpr:
-               *argsIndex = append(*argsIndex, 
int32(node.(*test_driver.ParamMarkerExpr).Order))
-               break
+               *argsIndex = append(*argsIndex, int32(expr.Order))
        }
 }
 
@@ -269,10 +244,8 @@ func (b *baseExecutor) getNeedColumns(meta 
*types.TableMeta, columns []string, d
                needUpdateColumns = columns
                if !b.containsPKByName(meta, columns) {
                        pkNames := meta.GetPrimaryKeyOnlyName()
-                       if pkNames != nil && len(pkNames) > 0 {
-                               for _, name := range pkNames {
-                                       needUpdateColumns = 
append(needUpdateColumns, name)
-                               }
+                       if len(pkNames) > 0 {
+                               needUpdateColumns = append(needUpdateColumns, 
pkNames...)
                        }
                }
                // todo If it contains onUpdate columns, add onUpdate columns
diff --git a/pkg/datasource/sql/exec/at/escape.go 
b/pkg/datasource/sql/exec/at/escape.go
index bd8d1c6a..05905770 100644
--- a/pkg/datasource/sql/exec/at/escape.go
+++ b/pkg/datasource/sql/exec/at/escape.go
@@ -200,7 +200,7 @@ func GetOrderedPkList(image *types.RecordImage, row 
types.RowImage, dbType types
 
        for _, pkName := range pkColumnNameListByOrder {
                for _, col := range pkColumnNameListNoOrder {
-                       if strings.Index(col.ColumnName, pkName) > -1 {
+                       if strings.Contains(col.ColumnName, pkName) {
                                pkFields = append(pkFields, col)
                        }
                }
diff --git a/pkg/datasource/sql/exec/at/insert_executor.go 
b/pkg/datasource/sql/exec/at/insert_executor.go
index ae7bac71..b5bb3304 100644
--- a/pkg/datasource/sql/exec/at/insert_executor.go
+++ b/pkg/datasource/sql/exec/at/insert_executor.go
@@ -328,7 +328,7 @@ func (i *insertExecutor) 
parsePkValuesFromStatement(insertStmt *ast.InsertStmt,
                return nil, nil
        }
        pkIndexMap := i.getPkIndex(insertStmt, meta)
-       if pkIndexMap == nil || len(pkIndexMap) == 0 {
+       if len(pkIndexMap) == 0 {
                return nil, fmt.Errorf("pkIndex is not found")
        }
        var pkIndexArray []int
@@ -343,13 +343,13 @@ func (i *insertExecutor) 
parsePkValuesFromStatement(insertStmt *ast.InsertStmt,
 
        pkValuesMap := make(map[string][]interface{})
 
-       if nameValues != nil && len(nameValues) > 0 {
+       if len(nameValues) > 0 {
                // use prepared statements
                insertRows, err := getInsertRows(insertStmt, pkIndexArray)
                if err != nil {
                        return nil, err
                }
-               if insertRows == nil || len(insertRows) == 0 {
+               if len(insertRows) == 0 {
                        return nil, err
                }
                totalPlaceholderNum := -1
diff --git a/pkg/datasource/sql/exec/at/insert_executor_test.go 
b/pkg/datasource/sql/exec/at/insert_executor_test.go
index ecf37392..2ceedf05 100644
--- a/pkg/datasource/sql/exec/at/insert_executor_test.go
+++ b/pkg/datasource/sql/exec/at/insert_executor_test.go
@@ -19,6 +19,7 @@ package at
 
 import (
        "context"
+       "database/sql"
        "database/sql/driver"
        "reflect"
        "testing"
@@ -114,7 +115,9 @@ func TestBuildSelectSQLByInsert(t *testing.T) {
 
        for _, test := range tests {
                t.Run(test.name, func(t *testing.T) {
-                       datasource.RegisterTableCache(types.DBTypeMySQL, 
mysql.NewTableMetaInstance(nil, nil))
+                       datasource.RegisterTableCache(types.DBTypeMySQL, 
func(db *sql.DB, cfg interface{}) datasource.TableMetaCache {
+                               return mysql.NewTableMetaInstance(db, nil)
+                       })
                        stub := 
gomonkey.ApplyMethod(reflect.TypeOf(datasource.GetTableCache(types.DBTypeMySQL)),
 "GetTableMeta",
                                func(_ *mysql.TableMetaCache, ctx 
context.Context, dbName, tableName string) (*types.TableMeta, error) {
                                        return &test.metaData, nil
@@ -629,7 +632,9 @@ func TestMySQLInsertUndoLogBuilder_getPkValuesByColumn(t 
*testing.T) {
        }
        for _, tt := range tests {
                t.Run(tt.name, func(t *testing.T) {
-                       datasource.RegisterTableCache(types.DBTypeMySQL, 
mysql.NewTableMetaInstance(nil, nil))
+                       datasource.RegisterTableCache(types.DBTypeMySQL, 
func(db *sql.DB, cfg interface{}) datasource.TableMetaCache {
+                               return mysql.NewTableMetaInstance(db, nil)
+                       })
                        stub := 
gomonkey.ApplyMethod(reflect.TypeOf(datasource.GetTableCache(types.DBTypeMySQL)),
 "GetTableMeta",
                                func(_ *mysql.TableMetaCache, ctx 
context.Context, dbName, tableName string) (*types.TableMeta, error) {
                                        return &tt.args.meta, nil
@@ -731,7 +736,9 @@ func TestMySQLInsertUndoLogBuilder_getPkValuesByAuto(t 
*testing.T) {
        }
        for _, tt := range tests {
                t.Run(tt.name, func(t *testing.T) {
-                       datasource.RegisterTableCache(types.DBTypeMySQL, 
mysql.NewTableMetaInstance(nil, nil))
+                       datasource.RegisterTableCache(types.DBTypeMySQL, 
func(db *sql.DB, cfg interface{}) datasource.TableMetaCache {
+                               return mysql.NewTableMetaInstance(db, nil)
+                       })
                        stub := 
gomonkey.ApplyMethod(reflect.TypeOf(datasource.GetTableCache(types.DBTypeMySQL)),
 "GetTableMeta",
                                func(_ *mysql.TableMetaCache, ctx 
context.Context, dbName, tableName string) (*types.TableMeta, error) {
                                        return &tt.args.meta, nil
@@ -824,7 +831,9 @@ func TestMySQLInsertUndoLogBuilder_autoGeneratePks(t 
*testing.T) {
        }
        for _, tt := range tests {
                t.Run(tt.name, func(t *testing.T) {
-                       datasource.RegisterTableCache(types.DBTypeMySQL, 
mysql.NewTableMetaInstance(nil, nil))
+                       datasource.RegisterTableCache(types.DBTypeMySQL, 
func(db *sql.DB, cfg interface{}) datasource.TableMetaCache {
+                               return mysql.NewTableMetaInstance(db, nil)
+                       })
                        stub := 
gomonkey.ApplyMethod(reflect.TypeOf(datasource.GetTableCache(types.DBTypeMySQL)),
 "GetTableMeta",
                                func(_ *mysql.TableMetaCache, ctx 
context.Context, dbName, tableName string) (*types.TableMeta, error) {
                                        return &tt.args.meta, nil
diff --git a/pkg/datasource/sql/exec/at/multi_update_excutor_test.go 
b/pkg/datasource/sql/exec/at/multi_update_excutor_test.go
index 21ef2e30..b0c9581e 100644
--- a/pkg/datasource/sql/exec/at/multi_update_excutor_test.go
+++ b/pkg/datasource/sql/exec/at/multi_update_excutor_test.go
@@ -18,6 +18,7 @@
 package at
 
 import (
+       "database/sql"
        "database/sql/driver"
        "testing"
 
@@ -34,7 +35,9 @@ import (
 
 func TestBuildSelectSQLByMultiUpdate(t *testing.T) {
        undo.InitUndoConfig(undo.Config{OnlyCareUpdateColumns: true})
-       datasource.RegisterTableCache(types.DBTypeMySQL, 
mysql.NewTableMetaInstance(nil, nil))
+       datasource.RegisterTableCache(types.DBTypeMySQL, func(db *sql.DB, cfg 
interface{}) datasource.TableMetaCache {
+               return mysql.NewTableMetaInstance(db, nil)
+       })
 
        tests := []struct {
                name            string
@@ -101,7 +104,9 @@ func TestBuildSelectSQLByMultiUpdate(t *testing.T) {
 
 func TestBuildSelectSQLByMultiUpdateAllColumns(t *testing.T) {
        undo.InitUndoConfig(undo.Config{OnlyCareUpdateColumns: false})
-       datasource.RegisterTableCache(types.DBTypeMySQL, 
mysql.NewTableMetaInstance(nil, nil))
+       datasource.RegisterTableCache(types.DBTypeMySQL, func(db *sql.DB, cfg 
interface{}) datasource.TableMetaCache {
+               return mysql.NewTableMetaInstance(db, nil)
+       })
 
        tests := []struct {
                name            string
diff --git a/pkg/datasource/sql/exec/at/update_executor_test.go 
b/pkg/datasource/sql/exec/at/update_executor_test.go
index a6ffc9be..01c17ac2 100644
--- a/pkg/datasource/sql/exec/at/update_executor_test.go
+++ b/pkg/datasource/sql/exec/at/update_executor_test.go
@@ -19,6 +19,7 @@ package at
 
 import (
        "context"
+       "database/sql"
        "database/sql/driver"
        "reflect"
        "testing"
@@ -38,7 +39,9 @@ import (
 
 func TestBuildSelectSQLByUpdate(t *testing.T) {
        undo.InitUndoConfig(undo.Config{OnlyCareUpdateColumns: true})
-       datasource.RegisterTableCache(types.DBTypeMySQL, 
mysql.NewTableMetaInstance(nil, nil))
+       datasource.RegisterTableCache(types.DBTypeMySQL, func(db *sql.DB, cfg 
interface{}) datasource.TableMetaCache {
+               return mysql.NewTableMetaInstance(db, nil)
+       })
        stub := 
gomonkey.ApplyMethod(reflect.TypeOf(datasource.GetTableCache(types.DBTypeMySQL)),
 "GetTableMeta",
                func(_ *mysql.TableMetaCache, ctx context.Context, dbName, 
tableName string) (*types.TableMeta, error) {
                        return &types.TableMeta{


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to