This is an automated email from the ASF dual-hosted git repository.
tew pushed a commit to branch feature/support-postgresql-in-xa-mode
in repository https://gitbox.apache.org/repos/asf/incubator-seata-go.git
The following commit(s) were added to
refs/heads/feature/support-postgresql-in-xa-mode by this push:
new a8116ce5 fix: fix RegisterTableCache type errors and optimize code
patterns (#1029)
a8116ce5 is described below
commit a8116ce50a35eb5885efe6a600e13926a9fbe465
Author: flypiggy <[email protected]>
AuthorDate: Sat Dec 20 16:40:40 2025 +0800
fix: fix RegisterTableCache type errors and optimize code patterns (#1029)
fix issue #1028
---
pkg/datasource/sql/exec/at/base_executor.go | 49 +++++-----------------
pkg/datasource/sql/exec/at/escape.go | 2 +-
pkg/datasource/sql/exec/at/insert_executor.go | 6 +--
pkg/datasource/sql/exec/at/insert_executor_test.go | 17 ++++++--
.../sql/exec/at/multi_update_excutor_test.go | 9 +++-
pkg/datasource/sql/exec/at/update_executor_test.go | 5 ++-
6 files changed, 39 insertions(+), 49 deletions(-)
diff --git a/pkg/datasource/sql/exec/at/base_executor.go
b/pkg/datasource/sql/exec/at/base_executor.go
index 13ac4fe8..0259cd66 100644
--- a/pkg/datasource/sql/exec/at/base_executor.go
+++ b/pkg/datasource/sql/exec/at/base_executor.go
@@ -140,73 +140,50 @@ func (b *baseExecutor) traversalArgs(node ast.Node,
argsIndex *[]int32) {
if node == nil {
return
}
- switch node.(type) {
+ switch expr := node.(type) {
case *ast.BinaryOperationExpr:
- expr := node.(*ast.BinaryOperationExpr)
b.traversalArgs(expr.L, argsIndex)
b.traversalArgs(expr.R, argsIndex)
- break
case *ast.BetweenExpr:
- expr := node.(*ast.BetweenExpr)
b.traversalArgs(expr.Left, argsIndex)
b.traversalArgs(expr.Right, argsIndex)
- break
case *ast.PatternInExpr:
- exprs := node.(*ast.PatternInExpr).List
- for i := 0; i < len(exprs); i++ {
- b.traversalArgs(exprs[i], argsIndex)
+ for i := 0; i < len(expr.List); i++ {
+ b.traversalArgs(expr.List[i], argsIndex)
}
- break
case *ast.Join:
- exprs := node.(*ast.Join)
- b.traversalArgs(exprs.Left, argsIndex)
- if exprs.Right != nil {
- b.traversalArgs(exprs.Right, argsIndex)
+ b.traversalArgs(expr.Left, argsIndex)
+ if expr.Right != nil {
+ b.traversalArgs(expr.Right, argsIndex)
}
- if exprs.On != nil {
- b.traversalArgs(exprs.On.Expr, argsIndex)
+ if expr.On != nil {
+ b.traversalArgs(expr.On.Expr, argsIndex)
}
- break
case *ast.UnaryOperationExpr:
- expr := node.(*ast.UnaryOperationExpr)
b.traversalArgs(expr.V, argsIndex)
- break
case *ast.FuncCallExpr:
- expr := node.(*ast.FuncCallExpr)
for _, arg := range expr.Args {
b.traversalArgs(arg, argsIndex)
}
- break
case *ast.SubqueryExpr:
- expr := node.(*ast.SubqueryExpr)
if expr.Query != nil {
b.traversalArgs(expr.Query, argsIndex)
}
- break
case *ast.ExistsSubqueryExpr:
- expr := node.(*ast.ExistsSubqueryExpr)
if expr.Sel != nil {
b.traversalArgs(expr.Sel, argsIndex)
}
- break
case *ast.CompareSubqueryExpr:
- expr := node.(*ast.CompareSubqueryExpr)
b.traversalArgs(expr.L, argsIndex)
if expr.R != nil {
b.traversalArgs(expr.R, argsIndex)
}
- break
case *ast.PatternLikeExpr:
- expr := node.(*ast.PatternLikeExpr)
b.traversalArgs(expr.Expr, argsIndex)
b.traversalArgs(expr.Pattern, argsIndex)
- break
case *ast.IsNullExpr:
- expr := node.(*ast.IsNullExpr)
b.traversalArgs(expr.Expr, argsIndex)
- break
case *ast.CaseExpr:
- expr := node.(*ast.CaseExpr)
if expr.Value != nil {
b.traversalArgs(expr.Value, argsIndex)
}
@@ -217,10 +194,8 @@ func (b *baseExecutor) traversalArgs(node ast.Node,
argsIndex *[]int32) {
if expr.ElseClause != nil {
b.traversalArgs(expr.ElseClause, argsIndex)
}
- break
case *test_driver.ParamMarkerExpr:
- *argsIndex = append(*argsIndex,
int32(node.(*test_driver.ParamMarkerExpr).Order))
- break
+ *argsIndex = append(*argsIndex, int32(expr.Order))
}
}
@@ -269,10 +244,8 @@ func (b *baseExecutor) getNeedColumns(meta
*types.TableMeta, columns []string, d
needUpdateColumns = columns
if !b.containsPKByName(meta, columns) {
pkNames := meta.GetPrimaryKeyOnlyName()
- if pkNames != nil && len(pkNames) > 0 {
- for _, name := range pkNames {
- needUpdateColumns =
append(needUpdateColumns, name)
- }
+ if len(pkNames) > 0 {
+ needUpdateColumns = append(needUpdateColumns,
pkNames...)
}
}
// todo If it contains onUpdate columns, add onUpdate columns
diff --git a/pkg/datasource/sql/exec/at/escape.go
b/pkg/datasource/sql/exec/at/escape.go
index bd8d1c6a..05905770 100644
--- a/pkg/datasource/sql/exec/at/escape.go
+++ b/pkg/datasource/sql/exec/at/escape.go
@@ -200,7 +200,7 @@ func GetOrderedPkList(image *types.RecordImage, row
types.RowImage, dbType types
for _, pkName := range pkColumnNameListByOrder {
for _, col := range pkColumnNameListNoOrder {
- if strings.Index(col.ColumnName, pkName) > -1 {
+ if strings.Contains(col.ColumnName, pkName) {
pkFields = append(pkFields, col)
}
}
diff --git a/pkg/datasource/sql/exec/at/insert_executor.go
b/pkg/datasource/sql/exec/at/insert_executor.go
index ae7bac71..b5bb3304 100644
--- a/pkg/datasource/sql/exec/at/insert_executor.go
+++ b/pkg/datasource/sql/exec/at/insert_executor.go
@@ -328,7 +328,7 @@ func (i *insertExecutor)
parsePkValuesFromStatement(insertStmt *ast.InsertStmt,
return nil, nil
}
pkIndexMap := i.getPkIndex(insertStmt, meta)
- if pkIndexMap == nil || len(pkIndexMap) == 0 {
+ if len(pkIndexMap) == 0 {
return nil, fmt.Errorf("pkIndex is not found")
}
var pkIndexArray []int
@@ -343,13 +343,13 @@ func (i *insertExecutor)
parsePkValuesFromStatement(insertStmt *ast.InsertStmt,
pkValuesMap := make(map[string][]interface{})
- if nameValues != nil && len(nameValues) > 0 {
+ if len(nameValues) > 0 {
// use prepared statements
insertRows, err := getInsertRows(insertStmt, pkIndexArray)
if err != nil {
return nil, err
}
- if insertRows == nil || len(insertRows) == 0 {
+ if len(insertRows) == 0 {
return nil, err
}
totalPlaceholderNum := -1
diff --git a/pkg/datasource/sql/exec/at/insert_executor_test.go
b/pkg/datasource/sql/exec/at/insert_executor_test.go
index ecf37392..2ceedf05 100644
--- a/pkg/datasource/sql/exec/at/insert_executor_test.go
+++ b/pkg/datasource/sql/exec/at/insert_executor_test.go
@@ -19,6 +19,7 @@ package at
import (
"context"
+ "database/sql"
"database/sql/driver"
"reflect"
"testing"
@@ -114,7 +115,9 @@ func TestBuildSelectSQLByInsert(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- datasource.RegisterTableCache(types.DBTypeMySQL,
mysql.NewTableMetaInstance(nil, nil))
+ datasource.RegisterTableCache(types.DBTypeMySQL,
func(db *sql.DB, cfg interface{}) datasource.TableMetaCache {
+ return mysql.NewTableMetaInstance(db, nil)
+ })
stub :=
gomonkey.ApplyMethod(reflect.TypeOf(datasource.GetTableCache(types.DBTypeMySQL)),
"GetTableMeta",
func(_ *mysql.TableMetaCache, ctx
context.Context, dbName, tableName string) (*types.TableMeta, error) {
return &test.metaData, nil
@@ -629,7 +632,9 @@ func TestMySQLInsertUndoLogBuilder_getPkValuesByColumn(t
*testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- datasource.RegisterTableCache(types.DBTypeMySQL,
mysql.NewTableMetaInstance(nil, nil))
+ datasource.RegisterTableCache(types.DBTypeMySQL,
func(db *sql.DB, cfg interface{}) datasource.TableMetaCache {
+ return mysql.NewTableMetaInstance(db, nil)
+ })
stub :=
gomonkey.ApplyMethod(reflect.TypeOf(datasource.GetTableCache(types.DBTypeMySQL)),
"GetTableMeta",
func(_ *mysql.TableMetaCache, ctx
context.Context, dbName, tableName string) (*types.TableMeta, error) {
return &tt.args.meta, nil
@@ -731,7 +736,9 @@ func TestMySQLInsertUndoLogBuilder_getPkValuesByAuto(t
*testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- datasource.RegisterTableCache(types.DBTypeMySQL,
mysql.NewTableMetaInstance(nil, nil))
+ datasource.RegisterTableCache(types.DBTypeMySQL,
func(db *sql.DB, cfg interface{}) datasource.TableMetaCache {
+ return mysql.NewTableMetaInstance(db, nil)
+ })
stub :=
gomonkey.ApplyMethod(reflect.TypeOf(datasource.GetTableCache(types.DBTypeMySQL)),
"GetTableMeta",
func(_ *mysql.TableMetaCache, ctx
context.Context, dbName, tableName string) (*types.TableMeta, error) {
return &tt.args.meta, nil
@@ -824,7 +831,9 @@ func TestMySQLInsertUndoLogBuilder_autoGeneratePks(t
*testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- datasource.RegisterTableCache(types.DBTypeMySQL,
mysql.NewTableMetaInstance(nil, nil))
+ datasource.RegisterTableCache(types.DBTypeMySQL,
func(db *sql.DB, cfg interface{}) datasource.TableMetaCache {
+ return mysql.NewTableMetaInstance(db, nil)
+ })
stub :=
gomonkey.ApplyMethod(reflect.TypeOf(datasource.GetTableCache(types.DBTypeMySQL)),
"GetTableMeta",
func(_ *mysql.TableMetaCache, ctx
context.Context, dbName, tableName string) (*types.TableMeta, error) {
return &tt.args.meta, nil
diff --git a/pkg/datasource/sql/exec/at/multi_update_excutor_test.go
b/pkg/datasource/sql/exec/at/multi_update_excutor_test.go
index 21ef2e30..b0c9581e 100644
--- a/pkg/datasource/sql/exec/at/multi_update_excutor_test.go
+++ b/pkg/datasource/sql/exec/at/multi_update_excutor_test.go
@@ -18,6 +18,7 @@
package at
import (
+ "database/sql"
"database/sql/driver"
"testing"
@@ -34,7 +35,9 @@ import (
func TestBuildSelectSQLByMultiUpdate(t *testing.T) {
undo.InitUndoConfig(undo.Config{OnlyCareUpdateColumns: true})
- datasource.RegisterTableCache(types.DBTypeMySQL,
mysql.NewTableMetaInstance(nil, nil))
+ datasource.RegisterTableCache(types.DBTypeMySQL, func(db *sql.DB, cfg
interface{}) datasource.TableMetaCache {
+ return mysql.NewTableMetaInstance(db, nil)
+ })
tests := []struct {
name string
@@ -101,7 +104,9 @@ func TestBuildSelectSQLByMultiUpdate(t *testing.T) {
func TestBuildSelectSQLByMultiUpdateAllColumns(t *testing.T) {
undo.InitUndoConfig(undo.Config{OnlyCareUpdateColumns: false})
- datasource.RegisterTableCache(types.DBTypeMySQL,
mysql.NewTableMetaInstance(nil, nil))
+ datasource.RegisterTableCache(types.DBTypeMySQL, func(db *sql.DB, cfg
interface{}) datasource.TableMetaCache {
+ return mysql.NewTableMetaInstance(db, nil)
+ })
tests := []struct {
name string
diff --git a/pkg/datasource/sql/exec/at/update_executor_test.go
b/pkg/datasource/sql/exec/at/update_executor_test.go
index a6ffc9be..01c17ac2 100644
--- a/pkg/datasource/sql/exec/at/update_executor_test.go
+++ b/pkg/datasource/sql/exec/at/update_executor_test.go
@@ -19,6 +19,7 @@ package at
import (
"context"
+ "database/sql"
"database/sql/driver"
"reflect"
"testing"
@@ -38,7 +39,9 @@ import (
func TestBuildSelectSQLByUpdate(t *testing.T) {
undo.InitUndoConfig(undo.Config{OnlyCareUpdateColumns: true})
- datasource.RegisterTableCache(types.DBTypeMySQL,
mysql.NewTableMetaInstance(nil, nil))
+ datasource.RegisterTableCache(types.DBTypeMySQL, func(db *sql.DB, cfg
interface{}) datasource.TableMetaCache {
+ return mysql.NewTableMetaInstance(db, nil)
+ })
stub :=
gomonkey.ApplyMethod(reflect.TypeOf(datasource.GetTableCache(types.DBTypeMySQL)),
"GetTableMeta",
func(_ *mysql.TableMetaCache, ctx context.Context, dbName,
tableName string) (*types.TableMeta, error) {
return &types.TableMeta{
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]