This is an automated email from the ASF dual-hosted git repository.

zfeng pushed a commit to branch develop-tmp
in repository https://gitbox.apache.org/repos/asf/incubator-seata-go.git


The following commit(s) were added to refs/heads/develop-tmp by this push:
     new 8e56d22d bugfix: error image when use null value as image query condition in insert on duplicate #704 (#725)
8e56d22d is described below

commit 8e56d22d7d9513f41ac65fc18ecd61e563a8bd2b
Author: Aster Zephyr <2046084...@qq.com>
AuthorDate: Sun Jun 8 21:13:34 2025 +0800

    bugfix: error image when use null value as image query condition in insert on duplicate #704 (#725)
    
    * bugfix #704
    
    * bugfix #704
    
    * bugfix704-2
    
    * bugfix-test-2
    
    * bugfix-test-2
    
    * pr725 bugfix
    
    ---------
    
    Co-authored-by: JayLiu <38887641+luky...@users.noreply.github.com>
    Co-authored-by: FengZhang <zfc...@qq.com>
---
 pkg/datasource/sql/types/image.go                  |  11 +-
 ...ql_insertonduplicate_update_undo_log_builder.go | 220 ++++++++++++++-------
 ...sertonduplicate_update_undo_log_builder_test.go |  63 ++++++
 3 files changed, 224 insertions(+), 70 deletions(-)

diff --git a/pkg/datasource/sql/types/image.go 
b/pkg/datasource/sql/types/image.go
index 3244cef9..62d836a7 100644
--- a/pkg/datasource/sql/types/image.go
+++ b/pkg/datasource/sql/types/image.go
@@ -18,6 +18,7 @@
 package types
 
 import (
+       "database/sql/driver"
        "encoding/base64"
        "encoding/json"
        "reflect"
@@ -117,14 +118,16 @@ type RecordImage struct {
        // Rows data row
        Rows []RowImage `json:"rows"`
        // TableMeta table information schema
-       TableMeta *TableMeta `json:"-"`
+       TableMeta     *TableMeta                `json:"-"`
+       PrimaryKeyMap map[string][]driver.Value `json:"primaryKeyMap,omitempty"`
 }
 
 func NewEmptyRecordImage(tableMeta *TableMeta, sqlType SQLType) *RecordImage {
        return &RecordImage{
-               TableName: tableMeta.TableName,
-               TableMeta: tableMeta,
-               SQLType:   sqlType,
+               TableName:     tableMeta.TableName,
+               TableMeta:     tableMeta,
+               SQLType:       sqlType,
+               PrimaryKeyMap: make(map[string][]driver.Value),
        }
 }
 
diff --git 
a/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go
 
b/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go
index 6b82d537..c40d4d9c 100644
--- 
a/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go
+++ 
b/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go
@@ -97,68 +97,108 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) 
buildBeforeImageSQL(insertStmt *a
        if err := checkDuplicateKeyUpdate(insertStmt, metaData); err != nil {
                return "", nil, err
        }
-       var selectArgs []driver.Value
+       u.BeforeImageSqlPrimaryKeys = make(map[string]bool, 
len(metaData.Indexs))
        pkIndexMap := u.getPkIndex(insertStmt, metaData)
        var pkIndexArray []int
        for _, val := range pkIndexMap {
-               tmpVal := val
-               pkIndexArray = append(pkIndexArray, tmpVal)
+               pkIndexArray = append(pkIndexArray, val)
        }
        insertRows, err := getInsertRows(insertStmt, pkIndexArray)
        if err != nil {
                return "", nil, err
        }
-       insertNum := len(insertRows)
        paramMap, err := u.buildImageParameters(insertStmt, args, insertRows)
        if err != nil {
                return "", nil, err
        }
-
-       sql := strings.Builder{}
-       sql.WriteString("SELECT * FROM " + metaData.TableName + " ")
+       if len(paramMap) == 0 || len(metaData.Indexs) == 0 {
+               return "", nil, nil
+       }
+       hasPK := false
+       for _, index := range metaData.Indexs {
+               if strings.EqualFold("PRIMARY", index.Name) {
+                       allPKColumnsHaveValue := true
+                       for _, col := range index.Columns {
+                               if params, ok := paramMap[col.ColumnName]; !ok 
|| len(params) == 0 || params[0] == nil {
+                                       allPKColumnsHaveValue = false
+                                       break
+                               }
+                       }
+                       hasPK = allPKColumnsHaveValue
+                       break
+               }
+       }
+       if !hasPK {
+               hasValidUniqueIndex := false
+               for _, index := range metaData.Indexs {
+                       if !index.NonUnique && !strings.EqualFold("PRIMARY", 
index.Name) {
+                               if _, _, valid := validateIndexPrefix(index, 
paramMap, 0); valid {
+                                       hasValidUniqueIndex = true
+                                       break
+                               }
+                       }
+               }
+               if !hasValidUniqueIndex {
+                       return "", nil, nil
+               }
+       }
+       var sql strings.Builder
+       sql.WriteString("SELECT * FROM " + metaData.TableName + "  ")
+       var selectArgs []driver.Value
        isContainWhere := false
-       for i := 0; i < insertNum; i++ {
-               finalI := i
-               paramAppenderTempList := make([]driver.Value, 0)
+       hasConditions := false
+       for i := 0; i < len(insertRows); i++ {
+               var rowConditions []string
+               var rowArgs []driver.Value
+               usedParams := make(map[string]bool)
+
+               // First try unique indexes
                for _, index := range metaData.Indexs {
-                       //unique index
-                       if index.NonUnique || isIndexValueNotNull(index, 
paramMap, finalI) == false {
+                       if index.NonUnique || strings.EqualFold("PRIMARY", 
index.Name) {
                                continue
                        }
-                       columnIsNull := true
-                       uniqueList := make([]string, 0)
-                       for _, columnMeta := range index.Columns {
-                               columnName := 
strings.ToLower(columnMeta.ColumnName)
-                               imageParameters, ok := paramMap[columnName]
-                               if !ok && columnMeta.ColumnDef != nil {
-                                       if strings.EqualFold("PRIMARY", 
index.Name) {
-                                               
u.BeforeImageSqlPrimaryKeys[columnName] = true
-                                       }
-                                       uniqueList = append(uniqueList, 
columnName+" = DEFAULT("+columnName+") ")
-                                       columnIsNull = false
-                                       continue
-                               }
-                               if strings.EqualFold("PRIMARY", index.Name) {
-                                       u.BeforeImageSqlPrimaryKeys[columnName] 
= true
+                       if conditions, args, valid := 
validateIndexPrefix(index, paramMap, i); valid {
+                               rowConditions = append(rowConditions, 
"("+strings.Join(conditions, " and ")+")")
+                               rowArgs = append(rowArgs, args...)
+                               hasConditions = true
+                               for _, colMeta := range index.Columns {
+                                       usedParams[colMeta.ColumnName] = true
                                }
-                               columnIsNull = false
-                               uniqueList = append(uniqueList, columnName+" = 
? ")
-                               paramAppenderTempList = 
append(paramAppenderTempList, imageParameters[finalI])
                        }
+               }
 
-                       if !columnIsNull {
-                               if isContainWhere {
-                                       sql.WriteString(" OR (" + 
strings.Join(uniqueList, " and ") + ") ")
-                               } else {
-                                       sql.WriteString(" WHERE (" + 
strings.Join(uniqueList, " and ") + ") ")
-                                       isContainWhere = true
+               // Then try primary key
+               for _, index := range metaData.Indexs {
+                       if !strings.EqualFold("PRIMARY", index.Name) {
+                               continue
+                       }
+                       if conditions, args, valid := 
validateIndexPrefix(index, paramMap, i); valid {
+                               rowConditions = append(rowConditions, 
"("+strings.Join(conditions, " and ")+")")
+                               rowArgs = append(rowArgs, args...)
+                               hasConditions = true
+                               for _, colMeta := range index.Columns {
+                                       usedParams[colMeta.ColumnName] = true
                                }
                        }
                }
-               selectArgs = append(selectArgs, paramAppenderTempList...)
+
+               if len(rowConditions) > 0 {
+                       if !isContainWhere {
+                               sql.WriteString("WHERE ")
+                               isContainWhere = true
+                       } else {
+                               sql.WriteString(" OR ")
+                       }
+                       sql.WriteString(strings.Join(rowConditions, "  OR ") + 
" ")
+                       selectArgs = append(selectArgs, rowArgs...)
+               }
+       }
+       if !hasConditions {
+               return "", nil, nil
        }
-       log.Infof("build select sql by insert on duplicate sourceQuery, sql 
{}", sql.String())
-       return sql.String(), selectArgs, nil
+       sqlStr := sql.String()
+       log.Infof("build select sql by insert on duplicate sourceQuery, sql: 
%s", sqlStr)
+       return sqlStr, selectArgs, nil
 }
 
 func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, 
execCtx *types.ExecContext, beforeImages []*types.RecordImage) 
([]*types.RecordImage, error) {
@@ -168,14 +208,14 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) 
AfterImage(ctx context.Context, e
                log.Errorf("build prepare stmt: %+v", err)
                return nil, err
        }
-
+       defer stmt.Close()
+       tableName := 
execCtx.ParseContext.InsertStmt.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O
+       metaData := execCtx.MetaDataMap[tableName]
        rows, err := stmt.Query(selectArgs)
        if err != nil {
-               log.Errorf("stmt query: %+v", err)
                return nil, err
        }
-       tableName := 
execCtx.ParseContext.InsertStmt.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O
-       metaData := execCtx.MetaDataMap[tableName]
+       defer rows.Close()
        image, err := u.buildRecordImages(rows, &metaData)
        if err != nil {
                return nil, err
@@ -185,11 +225,13 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) 
AfterImage(ctx context.Context, e
 
 func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildAfterImageSQL(ctx 
context.Context, beforeImages []*types.RecordImage) (string, []driver.Value) {
        selectSQL, selectArgs := u.BeforeSelectSql, u.Args
-
        var beforeImage *types.RecordImage
        if len(beforeImages) > 0 {
                beforeImage = beforeImages[0]
        }
+       if beforeImage == nil || len(beforeImage.Rows) == 0 {
+               return selectSQL, selectArgs
+       }
        primaryValueMap := make(map[string][]interface{})
        for _, row := range beforeImage.Rows {
                for _, col := range row.Columns {
@@ -198,25 +240,46 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) 
buildAfterImageSQL(ctx context.Co
                        }
                }
        }
-
        var afterImageSql strings.Builder
-       var primaryValues []driver.Value
        afterImageSql.WriteString(selectSQL)
-       for i := 0; i < len(beforeImage.Rows); i++ {
-               wherePrimaryList := make([]string, 0)
-               for name, value := range primaryValueMap {
-                       if !u.BeforeImageSqlPrimaryKeys[name] {
-                               wherePrimaryList = append(wherePrimaryList, 
name+" = ? ")
-                               primaryValues = append(primaryValues, value[i])
+       if len(primaryValueMap) == 0 || len(selectArgs) == 
len(beforeImage.Rows)*len(primaryValueMap) {
+               return selectSQL, selectArgs
+       }
+       var primaryValues []driver.Value
+       usedPrimaryKeys := make(map[string]bool)
+       for name := range primaryValueMap {
+               if !u.BeforeImageSqlPrimaryKeys[name] {
+                       usedPrimaryKeys[name] = true
+                       for i := 0; i < len(beforeImage.Rows); i++ {
+                               if value := primaryValueMap[name][i]; value != 
nil {
+                                       if dv, ok := value.(driver.Value); ok {
+                                               primaryValues = 
append(primaryValues, dv)
+                                       } else {
+                                               primaryValues = 
append(primaryValues, value)
+                                       }
+                               }
                        }
                }
-               if len(wherePrimaryList) != 0 {
-                       afterImageSql.WriteString(" OR (" + 
strings.Join(wherePrimaryList, " and ") + ") ")
+       }
+       if len(primaryValues) > 0 {
+               afterImageSql.WriteString(" OR (" + 
strings.Join(u.buildPrimaryKeyConditions(primaryValueMap, usedPrimaryKeys), " 
and  ") + ") ")
+       }
+       finalArgs := make([]driver.Value, len(selectArgs)+len(primaryValues))
+       copy(finalArgs, selectArgs)
+       copy(finalArgs[len(selectArgs):], primaryValues)
+       sqlStr := afterImageSql.String()
+       log.Infof("build after select sql by insert on duplicate sourceQuery, 
sql %s", sqlStr)
+       return sqlStr, finalArgs
+}
+
+func (u *MySQLInsertOnDuplicateUndoLogBuilder) 
buildPrimaryKeyConditions(primaryValueMap map[string][]interface{}, 
usedPrimaryKeys map[string]bool) []string {
+       var conditions []string
+       for name := range primaryValueMap {
+               if !usedPrimaryKeys[name] {
+                       conditions = append(conditions, name+" = ? ")
                }
        }
-       selectArgs = append(selectArgs, primaryValues...)
-       log.Infof("build after select sql by insert on duplicate sourceQuery, 
sql {}", afterImageSql.String())
-       return afterImageSql.String(), selectArgs
+       return conditions
 }
 
 func checkDuplicateKeyUpdate(insert *ast.InsertStmt, metaData types.TableMeta) 
error {
@@ -243,11 +306,10 @@ func checkDuplicateKeyUpdate(insert *ast.InsertStmt, 
metaData types.TableMeta) e
 
 // build sql params
 func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildImageParameters(insert 
*ast.InsertStmt, args []driver.Value, insertRows [][]interface{}) 
(map[string][]driver.Value, error) {
-       var (
-               parameterMap = make(map[string][]driver.Value)
-       )
+       parameterMap := make(map[string][]driver.Value)
        insertColumns := getInsertColumns(insert)
-       var placeHolderIndex = 0
+       placeHolderIndex := 0
+
        for _, row := range insertRows {
                if len(row) != len(insertColumns) {
                        log.Errorf("insert row's column size not equal to 
insert column size")
@@ -256,13 +318,14 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) 
buildImageParameters(insert *ast.
                for i, col := range insertColumns {
                        columnName := strings.ToLower(executor.DelEscape(col, 
types.DBTypeMySQL))
                        val := row[i]
-                       rStr, ok := val.(string)
-                       if ok && strings.EqualFold(rStr, SqlPlaceholder) {
-                               objects := args[placeHolderIndex]
-                               parameterMap[columnName] = 
append(parameterMap[col], objects)
+                       if str, ok := val.(string); ok && 
strings.EqualFold(str, SqlPlaceholder) {
+                               if placeHolderIndex >= len(args) {
+                                       return nil, fmt.Errorf("not enough 
parameters for placeholders")
+                               }
+                               parameterMap[columnName] = 
append(parameterMap[columnName], args[placeHolderIndex])
                                placeHolderIndex++
                        } else {
-                               parameterMap[columnName] = 
append(parameterMap[col], val)
+                               parameterMap[columnName] = 
append(parameterMap[columnName], val)
                        }
                }
        }
@@ -296,3 +359,28 @@ func isIndexValueNotNull(indexMeta types.IndexMeta, 
imageParameterMap map[string
        }
        return true
 }
+
+func validateIndexPrefix(index types.IndexMeta, paramMap 
map[string][]driver.Value, rowIndex int) ([]string, []driver.Value, bool) {
+       var indexConditions []string
+       var indexArgs []driver.Value
+       if len(index.Columns) > 1 {
+               for _, colMeta := range index.Columns {
+                       params, ok := paramMap[colMeta.ColumnName]
+                       if !ok || len(params) <= rowIndex || params[rowIndex] 
== nil {
+                               return nil, nil, false
+                       }
+               }
+       }
+       for _, colMeta := range index.Columns {
+               columnName := colMeta.ColumnName
+               params, ok := paramMap[columnName]
+               if ok && len(params) > rowIndex && params[rowIndex] != nil {
+                       indexConditions = append(indexConditions, columnName+" 
= ? ")
+                       indexArgs = append(indexArgs, params[rowIndex])
+               }
+       }
+       if len(indexConditions) != len(index.Columns) {
+               return nil, nil, false
+       }
+       return indexConditions, indexArgs, true
+}
diff --git 
a/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder_test.go
 
b/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder_test.go
index 59e673f7..f6e5b7cc 100644
--- 
a/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder_test.go
+++ 
b/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder_test.go
@@ -143,6 +143,69 @@ func TestInsertOnDuplicateBuildBeforeImageSQL(t 
*testing.T) {
                        expectQuery1:     "SELECT * FROM t_user  WHERE (name = 
?  and age = ? )  OR (name = ?  and age = ? ) ",
                        expectQueryArgs1: []driver.Value{"Jack1", 81, "Michal", 
int64(35)},
                },
+               // Test case for null unique index
+               {
+                       execCtx: &types.ExecContext{
+                               Query:       "insert into t_user(id, name, age) 
values(?, ?, ?) on duplicate key update age = ?",
+                               MetaDataMap: 
map[string]types.TableMeta{"t_user": tableMeta1},
+                       },
+                       sourceQueryArgs:  []driver.Value{1, nil, 2, 5},
+                       expectQuery1:     "SELECT * FROM t_user  WHERE (id = ? 
) ",
+                       expectQueryArgs1: []driver.Value{1},
+               },
+               // Test case for null primary key
+               {
+                       execCtx: &types.ExecContext{
+                               Query:       "insert into t_user(id, age) 
values(?, ?) on duplicate key update age = ?",
+                               MetaDataMap: 
map[string]types.TableMeta{"t_user": tableMeta1},
+                       },
+                       sourceQueryArgs:  []driver.Value{nil, 2, 5},
+                       expectQuery1:     "SELECT * FROM t_user WHERE (age = ? 
)",
+                       expectQueryArgs1: []driver.Value{2},
+               },
+               // Test case for null unique index with no primary key
+               {
+                       execCtx: &types.ExecContext{
+                               Query:       "insert into t_user(name, age) 
values(?, ?) on duplicate key update age = ?",
+                               MetaDataMap: 
map[string]types.TableMeta{"t_user": tableMeta2},
+                       },
+                       sourceQueryArgs:  []driver.Value{nil, 2, 5},
+                       expectQuery1:     "",
+                       expectQueryArgs1: nil,
+               },
+               // Test case for composite index with all columns
+               {
+                       name: "composite_index_full",
+                       execCtx: &types.ExecContext{
+                               Query:       "insert into t_user(id, name, age) 
values(?,?,?) on duplicate key update other = ?",
+                               MetaDataMap: 
map[string]types.TableMeta{"t_user": tableMeta1},
+                       },
+                       sourceQueryArgs:  []driver.Value{1, "Jack", 25, 
"other"},
+                       expectQuery1:     "SELECT * FROM t_user  WHERE (name = 
?  and age = ? )  OR (id = ? ) ",
+                       expectQueryArgs1: []driver.Value{"Jack", 25, 1},
+               },
+               // Test case for composite index with null value
+               {
+                       name: "composite_index_with_null",
+                       execCtx: &types.ExecContext{
+                               Query:       "insert into t_user(id, name, age) 
values(?,?,?) on duplicate key update other = ?",
+                               MetaDataMap: 
map[string]types.TableMeta{"t_user": tableMeta1},
+                       },
+                       sourceQueryArgs:  []driver.Value{1, "Jack", nil, 
"other"},
+                       expectQuery1:     "SELECT * FROM t_user  WHERE (id = ? 
) ",
+                       expectQueryArgs1: []driver.Value{1},
+               },
+               // Test case for composite index with leftmost prefix only
+               {
+                       name: "composite_index_leftmost_prefix",
+                       execCtx: &types.ExecContext{
+                               Query:       "insert into t_user(id, name) 
values(?,?) on duplicate key update other = ?",
+                               MetaDataMap: 
map[string]types.TableMeta{"t_user": tableMeta1},
+                       },
+                       sourceQueryArgs:  []driver.Value{1, "Jack", "other"},
+                       expectQuery1:     "SELECT * FROM t_user  WHERE (id = ? 
) ",
+                       expectQueryArgs1: []driver.Value{1},
+               },
        }
        for _, tt := range tests {
                t.Run(tt.name, func(t *testing.T) {


---------------------------------------------------------------------
To unsubscribe, e-mail: notifications-unsubscr...@seata.apache.org
For additional commands, e-mail: notifications-h...@seata.apache.org

Reply via email to