This is an automated email from the ASF dual-hosted git repository.
jimin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-seata-go.git
The following commit(s) were added to refs/heads/master by this push:
new afb86509 test: Improve test coverage for pkg/datasource/sql/exec (#999)
afb86509 is described below
commit afb86509eb0cca555f897a94953a61a52f5da608
Author: flypiggy <[email protected]>
AuthorDate: Tue Nov 25 20:07:25 2025 +0800
test: Improve test coverage for pkg/datasource/sql/exec (#999)
---
.github/workflows/golangci-lint.yml | 1 -
pkg/datasource/sql/exec/at/at_executor_test.go | 422 ++++++++++++++++++
pkg/datasource/sql/exec/executor_test.go | 471 +++++++++++++++++++++
pkg/datasource/sql/exec/hook_test.go | 460 ++++++++++++++++++++
.../grpc/grpc_transaction_interceptor_test.go | 7 +-
5 files changed, 1356 insertions(+), 5 deletions(-)
diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml
index ef6b6932..8a3ee2e7 100644
--- a/.github/workflows/golangci-lint.yml
+++ b/.github/workflows/golangci-lint.yml
@@ -57,6 +57,5 @@ jobs:
with:
version: v1.51.0
args: --timeout=10m
- skip-go-installation: true
skip-cache: true
skip-pkg-cache: true
\ No newline at end of file
diff --git a/pkg/datasource/sql/exec/at/at_executor_test.go b/pkg/datasource/sql/exec/at/at_executor_test.go
new file mode 100644
index 00000000..72322da0
--- /dev/null
+++ b/pkg/datasource/sql/exec/at/at_executor_test.go
@@ -0,0 +1,422 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package at
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "testing"
+
+ "github.com/agiledragon/gomonkey/v2"
+ "github.com/stretchr/testify/assert"
+
+ "seata.apache.org/seata-go/pkg/datasource/sql/exec"
+ "seata.apache.org/seata-go/pkg/datasource/sql/parser"
+ "seata.apache.org/seata-go/pkg/datasource/sql/types"
+ "seata.apache.org/seata-go/pkg/tm"
+)
+
+// TestATExecutor_Interceptors checks that Interceptors stores exactly the hooks
+// it is given on the AT executor.
+func TestATExecutor_Interceptors(t *testing.T) {
+	// hooksOf builds one mock hook per requested SQL type.
+	hooksOf := func(sqlTypes ...types.SQLType) []exec.SQLHook {
+		hooks := make([]exec.SQLHook, 0, len(sqlTypes))
+		for _, st := range sqlTypes {
+			hooks = append(hooks, &mockSQLHook{sqlType: st})
+		}
+		return hooks
+	}
+
+	cases := []struct {
+		name         string
+		interceptors []exec.SQLHook
+	}{
+		{"set empty interceptors", hooksOf()},
+		{"set single interceptor", hooksOf(types.SQLTypeInsert)},
+		{"set multiple interceptors", hooksOf(types.SQLTypeInsert, types.SQLTypeUpdate, types.SQLTypeDelete)},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			e := &ATExecutor{}
+			e.Interceptors(tc.interceptors)
+			assert.Equal(t, len(tc.interceptors), len(e.hooks), "hooks count should match")
+		})
+	}
+}
+
+func TestATExecutor_ExecWithNamedValue_NonGlobalTx(t *testing.T) {
+ tests := []struct {
+ name string
+ query string
+ sqlType types.SQLType
+ wantErr bool
+ callCount int
+ }{
+ {
+ name: "non-global transaction - INSERT",
+ query: "INSERT INTO users (name, age) VALUES (?,
?)",
+ sqlType: types.SQLTypeInsert,
+ wantErr: false,
+ callCount: 1,
+ },
+ {
+ name: "non-global transaction - UPDATE",
+ query: "UPDATE users SET age = ? WHERE name = ?",
+ sqlType: types.SQLTypeUpdate,
+ wantErr: false,
+ callCount: 1,
+ },
+ {
+ name: "non-global transaction - DELETE",
+ query: "DELETE FROM users WHERE name = ?",
+ sqlType: types.SQLTypeDelete,
+ wantErr: false,
+ callCount: 1,
+ },
+ {
+ name: "non-global transaction - SELECT",
+ query: "SELECT * FROM users WHERE name = ?",
+ sqlType: types.SQLTypeSelect,
+ wantErr: false,
+ callCount: 1,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ patches := gomonkey.ApplyFunc(tm.IsGlobalTx, func(ctx
context.Context) bool {
+ return false
+ })
+ defer patches.Reset()
+
+ patchesParser := gomonkey.ApplyFunc(parser.DoParser,
func(query string) (*types.ParseContext, error) {
+ return &types.ParseContext{
+ SQLType: tt.sqlType,
+ }, nil
+ })
+ defer patchesParser.Reset()
+
+ executor := &ATExecutor{}
+ execCtx := &types.ExecContext{
+ Query: tt.query,
+ NamedValues: []driver.NamedValue{{Value:
"test"}},
+ }
+
+ callCount := 0
+ callback := func(ctx context.Context, query string,
args []driver.NamedValue) (types.ExecResult, error) {
+ callCount++
+ return &mockExecResult{rowsAffected: 1}, nil
+ }
+
+ result, err :=
executor.ExecWithNamedValue(context.Background(), execCtx, callback)
+
+ if tt.wantErr {
+ assert.Error(t, err, "should return error")
+ } else {
+ assert.NoError(t, err, "should not return
error")
+ assert.NotNil(t, result, "result should not be
nil")
+ assert.Equal(t, tt.callCount, callCount,
"callback should be called once")
+ }
+ })
+ }
+}
+
+func TestATExecutor_ExecWithNamedValue_GlobalTx(t *testing.T) {
+ tests := []struct {
+ name string
+ query string
+ sqlType types.SQLType
+ wantErr bool
+ }{
+ {
+ name: "global transaction - INSERT",
+ query: "INSERT INTO users (name, age) VALUES (?, ?)",
+ sqlType: types.SQLTypeInsert,
+ wantErr: false,
+ },
+ {
+ name: "global transaction - UPDATE",
+ query: "UPDATE users SET age = ? WHERE name = ?",
+ sqlType: types.SQLTypeUpdate,
+ wantErr: false,
+ },
+ {
+ name: "global transaction - DELETE",
+ query: "DELETE FROM users WHERE name = ?",
+ sqlType: types.SQLTypeDelete,
+ wantErr: false,
+ },
+ {
+ name: "global transaction - SELECT FOR UPDATE",
+ query: "SELECT * FROM users WHERE name = ? FOR
UPDATE",
+ sqlType: types.SQLTypeSelectForUpdate,
+ wantErr: false,
+ },
+ {
+ name: "global transaction - INSERT ON DUPLICATE
UPDATE",
+ query: "INSERT INTO users (id, name) VALUES (?, ?) ON
DUPLICATE KEY UPDATE name = ?",
+ sqlType: types.SQLTypeInsertOnDuplicateUpdate,
+ wantErr: false,
+ },
+ {
+ name: "global transaction - MULTI",
+ query: "UPDATE users SET age = ? WHERE id IN (?, ?)",
+ sqlType: types.SQLTypeMulti,
+ wantErr: false,
+ },
+ {
+ name: "global transaction - plain SQL (SELECT
without FOR UPDATE)",
+ query: "SELECT * FROM users WHERE name = ?",
+ sqlType: types.SQLTypeSelect,
+ wantErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ patches := gomonkey.ApplyFunc(tm.IsGlobalTx, func(ctx
context.Context) bool {
+ return true
+ })
+ defer patches.Reset()
+
+ patchesParser := gomonkey.ApplyFunc(parser.DoParser,
func(query string) (*types.ParseContext, error) {
+ return &types.ParseContext{
+ SQLType: tt.sqlType,
+ }, nil
+ })
+ defer patchesParser.Reset()
+
+ executor := &ATExecutor{
+ hooks: []exec.SQLHook{
+ &mockSQLHook{sqlType: tt.sqlType},
+ },
+ }
+ execCtx := &types.ExecContext{
+ Query: tt.query,
+ NamedValues: []driver.NamedValue{{Value:
"test"}},
+ }
+
+ callCount := 0
+ callback := func(ctx context.Context, query string,
args []driver.NamedValue) (types.ExecResult, error) {
+ callCount++
+ return &mockExecResult{rowsAffected: 1}, nil
+ }
+
+ result, err :=
executor.ExecWithNamedValue(context.Background(), execCtx, callback)
+
+ if tt.wantErr {
+ assert.Error(t, err, "should return error")
+ } else {
+ assert.NoError(t, err, "should not return
error")
+ assert.NotNil(t, result, "result should not be
nil")
+ }
+ })
+ }
+}
+
+func TestATExecutor_ExecWithNamedValue_ParserError(t *testing.T) {
+ patches := gomonkey.ApplyFunc(parser.DoParser, func(query string)
(*types.ParseContext, error) {
+ return nil, fmt.Errorf("parser error")
+ })
+ defer patches.Reset()
+
+ executor := &ATExecutor{}
+ execCtx := &types.ExecContext{
+ Query: "INVALID SQL",
+ NamedValues: []driver.NamedValue{},
+ }
+
+ callback := func(ctx context.Context, query string, args
[]driver.NamedValue) (types.ExecResult, error) {
+ return &mockExecResult{rowsAffected: 1}, nil
+ }
+
+ result, err := executor.ExecWithNamedValue(context.Background(),
execCtx, callback)
+
+ assert.Error(t, err, "should return parser error")
+ assert.Nil(t, result, "result should be nil on error")
+ assert.Contains(t, err.Error(), "parser error", "error message should
contain parser error")
+}
+
+func TestATExecutor_ExecWithValue(t *testing.T) {
+ tests := []struct {
+ name string
+ query string
+ values []driver.Value
+ sqlType types.SQLType
+ wantErr bool
+ callCount int
+ }{
+ {
+ name: "execute with values - converts to
NamedValues",
+ query: "INSERT INTO users (name, age) VALUES (?,
?)",
+ values: []driver.Value{"Alice", 30},
+ sqlType: types.SQLTypeInsert,
+ wantErr: false,
+ callCount: 1,
+ },
+ {
+ name: "execute with empty values",
+ query: "INSERT INTO users (name, age) VALUES
('Bob', 25)",
+ values: []driver.Value{},
+ sqlType: types.SQLTypeInsert,
+ wantErr: false,
+ callCount: 1,
+ },
+ {
+ name: "execute with multiple values",
+ query: "UPDATE users SET name = ?, age = ?, city =
? WHERE id = ?",
+ values: []driver.Value{"Charlie", 35, "NYC", 1},
+ sqlType: types.SQLTypeUpdate,
+ wantErr: false,
+ callCount: 1,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ patches := gomonkey.ApplyFunc(tm.IsGlobalTx, func(ctx
context.Context) bool {
+ return false
+ })
+ defer patches.Reset()
+
+ patchesParser := gomonkey.ApplyFunc(parser.DoParser,
func(query string) (*types.ParseContext, error) {
+ return &types.ParseContext{
+ SQLType: tt.sqlType,
+ }, nil
+ })
+ defer patchesParser.Reset()
+
+ executor := &ATExecutor{}
+ execCtx := &types.ExecContext{
+ Query: tt.query,
+ Values: tt.values,
+ }
+
+ callCount := 0
+ callback := func(ctx context.Context, query string,
args []driver.NamedValue) (types.ExecResult, error) {
+ callCount++
+ return &mockExecResult{rowsAffected: 1}, nil
+ }
+
+ result, err :=
executor.ExecWithValue(context.Background(), execCtx, callback)
+
+ if tt.wantErr {
+ assert.Error(t, err, "should return error")
+ } else {
+ assert.NoError(t, err, "should not return
error")
+ assert.NotNil(t, result, "result should not be
nil")
+ assert.Equal(t, tt.callCount, callCount,
"callback should be called")
+ // Verify that NamedValues were populated in
execCtx
+ assert.NotNil(t, execCtx.NamedValues,
"NamedValues should be populated")
+ assert.Equal(t, len(tt.values),
len(execCtx.NamedValues), "NamedValues count should match values count")
+ }
+ })
+ }
+}
+
+// TestATExecutor_ExecWithValue_ParserError patches parser.DoParser to fail and
+// expects ExecWithValue to surface the error with a nil result.
+func TestATExecutor_ExecWithValue_ParserError(t *testing.T) {
+	parserPatch := gomonkey.ApplyFunc(parser.DoParser, func(query string) (*types.ParseContext, error) {
+		return nil, fmt.Errorf("parser error")
+	})
+	defer parserPatch.Reset()
+
+	noop := func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) {
+		return &mockExecResult{rowsAffected: 1}, nil
+	}
+	execCtx := &types.ExecContext{
+		Query:  "INVALID SQL",
+		Values: []driver.Value{"test"},
+	}
+
+	result, err := (&ATExecutor{}).ExecWithValue(context.Background(), execCtx, noop)
+
+	assert.Error(t, err, "should return parser error")
+	assert.Nil(t, result, "result should be nil on error")
+}
+
+// mockSQLHook is an exec.SQLHook test double: it reports a fixed SQL type,
+// counts how often Before and After run, and returns the configured errors.
+type mockSQLHook struct {
+	sqlType         types.SQLType
+	beforeCallCount int
+	afterCallCount  int
+	beforeError     error
+	afterError      error
+}
+
+// Type reports the SQL type this hook was built with.
+func (h *mockSQLHook) Type() types.SQLType { return h.sqlType }
+
+// Before counts the invocation and returns beforeError (nil unless set).
+func (h *mockSQLHook) Before(ctx context.Context, execCtx *types.ExecContext) error {
+	h.beforeCallCount++
+	return h.beforeError
+}
+
+// After counts the invocation and returns afterError (nil unless set).
+func (h *mockSQLHook) After(ctx context.Context, execCtx *types.ExecContext) error {
+	h.afterCallCount++
+	return h.afterError
+}
+
+// mockExecutor is a stub exposing ExecContext: it delegates to execContextFunc
+// when set and otherwise returns a result reporting one affected row.
+//
+// NOTE(review): nothing in this file references mockExecutor; if no other test
+// in package at uses it, linters such as `unused` may flag it — consider
+// removing it.
+type mockExecutor struct {
+	execContextFunc func(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error)
+}
+
+// ExecContext runs the configured override, or falls back to a canned result.
+func (m *mockExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
+	override := m.execContextFunc
+	if override == nil {
+		return &mockExecResult{rowsAffected: 1}, nil
+	}
+	return override(ctx, f)
+}
+
+// mockExecResult is a canned types.ExecResult: GetRows returns the stored rows
+// and GetResult wraps the stored counters in a mockResult.
+type mockExecResult struct {
+	lastInsertID int64
+	rowsAffected int64
+	rows         driver.Rows
+}
+
+// GetRows returns the rows the mock was built with (nil unless set).
+func (r *mockExecResult) GetRows() driver.Rows { return r.rows }
+
+// GetResult returns a fresh driver.Result carrying the configured counters.
+func (r *mockExecResult) GetResult() driver.Result {
+	res := &mockResult{}
+	res.lastInsertID = r.lastInsertID
+	res.rowsAffected = r.rowsAffected
+	return res
+}
+
+// mockResult is a minimal driver.Result whose methods never fail.
+type mockResult struct {
+	lastInsertID int64
+	rowsAffected int64
+}
+
+// LastInsertId reports the configured insert ID with a nil error.
+func (r *mockResult) LastInsertId() (int64, error) { return r.lastInsertID, nil }
+
+// RowsAffected reports the configured affected-row count with a nil error.
+func (r *mockResult) RowsAffected() (int64, error) { return r.rowsAffected, nil }
diff --git a/pkg/datasource/sql/exec/executor_test.go b/pkg/datasource/sql/exec/executor_test.go
new file mode 100644
index 00000000..3abcfd66
--- /dev/null
+++ b/pkg/datasource/sql/exec/executor_test.go
@@ -0,0 +1,471 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package exec
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "testing"
+
+ "github.com/DATA-DOG/go-sqlmock"
+ "github.com/stretchr/testify/assert"
+
+ "seata.apache.org/seata-go/pkg/datasource/sql/types"
+)
+
+// TestRegisterATExecutor tests the RegisterATExecutor function
+func TestRegisterATExecutor(t *testing.T) {
+ // Clean up before test
+ originalExecutors := atExecutors
+ defer func() { atExecutors = originalExecutors }()
+ atExecutors = make(map[types.DBType]func() SQLExecutor)
+
+ tests := []struct {
+ name string
+ dbType types.DBType
+ builder func() SQLExecutor
+ }{
+ {
+ name: "register MySQL executor",
+ dbType: types.DBTypeMySQL,
+ builder: func() SQLExecutor {
+ return &mockSQLExecutor{}
+ },
+ },
+ {
+ name: "register PostgreSQL executor",
+ dbType: types.DBTypePostgreSQL,
+ builder: func() SQLExecutor {
+ return &mockSQLExecutor{}
+ },
+ },
+ {
+ name: "register unknown type executor",
+ dbType: types.DBTypeUnknown,
+ builder: func() SQLExecutor {
+ return &mockSQLExecutor{}
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ RegisterATExecutor(tt.dbType, tt.builder)
+ assert.NotNil(t, atExecutors[tt.dbType], "executor
should be registered")
+ executor := atExecutors[tt.dbType]()
+ assert.NotNil(t, executor, "executor builder should
return an executor")
+ })
+ }
+}
+
+// TestBuildExecutor tests the BuildExecutor function
+func TestBuildExecutor(t *testing.T) {
+ // Setup: register a mock executor
+ originalExecutors := atExecutors
+ originalCommonHook := commonHook
+ originalHookSolts := hookSolts
+ defer func() {
+ atExecutors = originalExecutors
+ commonHook = originalCommonHook
+ hookSolts = originalHookSolts
+ }()
+
+ atExecutors = make(map[types.DBType]func() SQLExecutor)
+ commonHook = make([]SQLHook, 0, 4)
+ hookSolts = map[types.SQLType][]SQLHook{}
+
+ mockExecutor := &mockSQLExecutor{}
+ RegisterATExecutor(types.DBTypeMySQL, func() SQLExecutor {
+ return mockExecutor
+ })
+
+ // Register test hooks
+ RegisterCommonHook(&mockSQLHook{sqlType: types.SQLTypeSelect})
+ RegisterHook(&mockSQLHook{sqlType: types.SQLTypeInsert})
+
+ tests := []struct {
+ name string
+ dbType types.DBType
+ transactionMode types.TransactionMode
+ query string
+ wantErr bool
+ errMsg string
+ }{
+ {
+ name: "build executor for INSERT statement",
+ dbType: types.DBTypeMySQL,
+ transactionMode: types.ATMode,
+ query: "INSERT INTO users (name, age) VALUES
('Alice', 30)",
+ wantErr: false,
+ },
+ {
+ name: "build executor for UPDATE statement",
+ dbType: types.DBTypeMySQL,
+ transactionMode: types.ATMode,
+ query: "UPDATE users SET age = 31 WHERE name
= 'Alice'",
+ wantErr: false,
+ },
+ {
+ name: "build executor for DELETE statement",
+ dbType: types.DBTypeMySQL,
+ transactionMode: types.ATMode,
+ query: "DELETE FROM users WHERE name =
'Alice'",
+ wantErr: false,
+ },
+ {
+ name: "build executor for SELECT statement",
+ dbType: types.DBTypeMySQL,
+ transactionMode: types.ATMode,
+ query: "SELECT * FROM users WHERE name =
'Alice'",
+ wantErr: false,
+ },
+ {
+ name: "invalid SQL query",
+ dbType: types.DBTypeMySQL,
+ transactionMode: types.ATMode,
+ query: "INVALID SQL QUERY",
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ executor, err := BuildExecutor(tt.dbType,
tt.transactionMode, tt.query)
+ if tt.wantErr {
+ assert.Error(t, err, "should return error for
invalid query")
+ assert.Nil(t, executor, "executor should be nil
on error")
+ } else {
+ assert.NoError(t, err, "should not return error
for valid query")
+ assert.NotNil(t, executor, "executor should not
be nil")
+ // Verify that the mock executor received the
interceptors
+ assert.True(t, mockExecutor.interceptorsCalled,
"Interceptors should be called")
+ assert.NotEmpty(t, mockExecutor.hooks, "hooks
should be set")
+ }
+ })
+ }
+}
+
+// TestBaseExecutor_Interceptors checks that Interceptors stores exactly the
+// hooks it is given on a BaseExecutor.
+func TestBaseExecutor_Interceptors(t *testing.T) {
+	// hooksOf builds one mock hook per requested SQL type.
+	hooksOf := func(sqlTypes ...types.SQLType) []SQLHook {
+		hooks := make([]SQLHook, 0, len(sqlTypes))
+		for _, st := range sqlTypes {
+			hooks = append(hooks, &mockSQLHook{sqlType: st})
+		}
+		return hooks
+	}
+
+	cases := []struct {
+		name         string
+		interceptors []SQLHook
+	}{
+		{"set empty interceptors", hooksOf()},
+		{"set single interceptor", hooksOf(types.SQLTypeInsert)},
+		{"set multiple interceptors", hooksOf(types.SQLTypeInsert, types.SQLTypeUpdate, types.SQLTypeDelete)},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			e := &BaseExecutor{}
+			e.Interceptors(tc.interceptors)
+			assert.Equal(t, len(tc.interceptors), len(e.hooks), "hooks count should match")
+		})
+	}
+}
+
+// TestBaseExecutor_ExecWithNamedValue checks that BaseExecutor runs each
+// hook's Before and After once per call. It also checks that the call returns
+// the callback's (or inner executor's) result and propagates callback errors.
+func TestBaseExecutor_ExecWithNamedValue(t *testing.T) {
+	// insertCtx builds a two-argument INSERT context.
+	insertCtx := func(name string, age int) *types.ExecContext {
+		return &types.ExecContext{
+			Query:       "INSERT INTO users VALUES (?, ?)",
+			NamedValues: []driver.NamedValue{{Value: name}, {Value: age}},
+		}
+	}
+	succeed := func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) {
+		return newMockExecResult(1, 1), nil
+	}
+
+	cases := []struct {
+		name            string
+		setupHooks      []SQLHook
+		innerExecutor   SQLExecutor
+		callback        CallbackWithNamedValue
+		execCtx         *types.ExecContext
+		wantErr         bool
+		wantBeforeCount int
+		wantAfterCount  int
+	}{
+		{
+			name:     "execute without hooks and without inner executor",
+			callback: succeed,
+			execCtx:  insertCtx("Alice", 30),
+		},
+		{
+			name: "execute with hooks",
+			setupHooks: []SQLHook{
+				&mockSQLHook{sqlType: types.SQLTypeInsert},
+				&mockSQLHook{sqlType: types.SQLTypeInsert},
+			},
+			callback:        succeed,
+			execCtx:         insertCtx("Alice", 30),
+			wantBeforeCount: 2,
+			wantAfterCount:  2,
+		},
+		{
+			name:       "callback returns error",
+			setupHooks: []SQLHook{&mockSQLHook{sqlType: types.SQLTypeInsert}},
+			callback: func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) {
+				return nil, fmt.Errorf("execution error")
+			},
+			execCtx:         insertCtx("Alice", 30),
+			wantErr:         true,
+			wantBeforeCount: 1,
+			wantAfterCount:  1, // After hooks are called even on error (defer)
+		},
+		{
+			name:       "execute with inner executor",
+			setupHooks: []SQLHook{&mockSQLHook{sqlType: types.SQLTypeInsert}},
+			innerExecutor: &mockSQLExecutor{
+				execWithNamedValueFunc: func(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) {
+					return newMockExecResult(2, 2), nil
+				},
+			},
+			callback: func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) {
+				// This should not be called when inner executor is set
+				panic("callback should not be called")
+			},
+			execCtx:         insertCtx("Bob", 25),
+			wantBeforeCount: 1,
+			wantAfterCount:  1,
+		},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			executor := &BaseExecutor{hooks: tc.setupHooks, ex: tc.innerExecutor}
+			result, err := executor.ExecWithNamedValue(context.Background(), tc.execCtx, tc.callback)
+
+			switch {
+			case tc.wantErr:
+				assert.Error(t, err, "should return error")
+			default:
+				assert.NoError(t, err, "should not return error")
+				assert.NotNil(t, result, "result should not be nil")
+			}
+
+			// Sum the per-hook counters recorded by the mocks.
+			var before, after int
+			for _, h := range tc.setupHooks {
+				m := h.(*mockSQLHook)
+				before += m.beforeCallCount
+				after += m.afterCallCount
+			}
+			assert.Equal(t, tc.wantBeforeCount, before, "before hook call count should match")
+			assert.Equal(t, tc.wantAfterCount, after, "after hook call count should match")
+		})
+	}
+}
+
+// TestBaseExecutor_ExecWithValue checks that ExecWithValue hands the callback
+// non-empty arguments derived from ExecContext.Values. It also checks that
+// hooks run once per call and that callback errors are propagated.
+func TestBaseExecutor_ExecWithValue(t *testing.T) {
+	tests := []struct {
+		name       string
+		setupHooks []SQLHook
+		callback   CallbackWithNamedValue
+		execCtx    *types.ExecContext
+		wantErr    bool
+		// wantArgs asks the subtest to assert that the callback received a
+		// non-empty argument list.
+		wantArgs        bool
+		wantBeforeCount int
+		wantAfterCount  int
+	}{
+		{
+			name: "execute with values - converts to NamedValues",
+			callback: func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) {
+				return newMockExecResult(1, 1), nil
+			},
+			execCtx: &types.ExecContext{
+				Query:  "UPDATE users SET age = ? WHERE name = ?",
+				Values: []driver.Value{31, "Alice"},
+			},
+			wantErr:         false,
+			wantArgs:        true,
+			wantBeforeCount: 0,
+			wantAfterCount:  0,
+		},
+		{
+			name: "execute with values and hooks",
+			setupHooks: []SQLHook{
+				&mockSQLHook{sqlType: types.SQLTypeUpdate},
+			},
+			callback: func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) {
+				return newMockExecResult(1, 1), nil
+			},
+			execCtx: &types.ExecContext{
+				Query:  "UPDATE users SET age = ? WHERE name = ?",
+				Values: []driver.Value{31, "Alice"},
+			},
+			wantErr:         false,
+			wantBeforeCount: 1,
+			wantAfterCount:  1,
+		},
+		{
+			name: "callback returns error with values",
+			setupHooks: []SQLHook{
+				&mockSQLHook{sqlType: types.SQLTypeUpdate},
+			},
+			callback: func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) {
+				return nil, fmt.Errorf("execution error")
+			},
+			execCtx: &types.ExecContext{
+				Query:  "UPDATE users SET age = ? WHERE name = ?",
+				Values: []driver.Value{31, "Alice"},
+			},
+			wantErr:         true,
+			wantBeforeCount: 1,
+			wantAfterCount:  1,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			executor := &BaseExecutor{
+				hooks: tt.setupHooks,
+			}
+
+			// Record what the executor passes to the callback so it can be
+			// asserted with this subtest's *testing.T. Asserting inside the table
+			// callback would report against the parent test instead.
+			var gotArgs []driver.NamedValue
+			callback := func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) {
+				gotArgs = args
+				return tt.callback(ctx, query, args)
+			}
+
+			result, err := executor.ExecWithValue(context.Background(), tt.execCtx, callback)
+
+			if tt.wantErr {
+				assert.Error(t, err, "should return error")
+			} else {
+				assert.NoError(t, err, "should not return error")
+				assert.NotNil(t, result, "result should not be nil")
+			}
+			if tt.wantArgs {
+				assert.NotEmpty(t, gotArgs, "args should not be empty")
+			}
+
+			// Verify hooks were called
+			beforeCount := 0
+			afterCount := 0
+			for _, hook := range tt.setupHooks {
+				mockHook := hook.(*mockSQLHook)
+				beforeCount += mockHook.beforeCallCount
+				afterCount += mockHook.afterCallCount
+			}
+			assert.Equal(t, tt.wantBeforeCount, beforeCount, "before hook call count should match")
+			assert.Equal(t, tt.wantAfterCount, afterCount, "after hook call count should match")
+		})
+	}
+}
+
+// Mock implementations for testing
+
+// mockSQLExecutor is a SQLExecutor test double. It records the hooks it
+// receives. Unless an override func is set, it forwards straight to the
+// callback with the context's query and named values.
+type mockSQLExecutor struct {
+	hooks                  []SQLHook
+	interceptorsCalled     bool
+	execWithNamedValueFunc func(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error)
+	execWithValueFunc      func(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error)
+}
+
+// Interceptors stores the hooks and flags that it was called.
+func (e *mockSQLExecutor) Interceptors(interceptors []SQLHook) {
+	e.hooks = interceptors
+	e.interceptorsCalled = true
+}
+
+// ExecWithNamedValue delegates to execWithNamedValueFunc when set, otherwise
+// invokes f directly.
+func (e *mockSQLExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) {
+	if override := e.execWithNamedValueFunc; override != nil {
+		return override(ctx, execCtx, f)
+	}
+	return f(ctx, execCtx.Query, execCtx.NamedValues)
+}
+
+// ExecWithValue delegates to execWithValueFunc when set, otherwise invokes f
+// with the context's NamedValues (Values are not converted here).
+func (e *mockSQLExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) {
+	if override := e.execWithValueFunc; override != nil {
+		return override(ctx, execCtx, f)
+	}
+	return f(ctx, execCtx.Query, execCtx.NamedValues)
+}
+
+// mockSQLHook is a SQLHook test double: it reports a fixed SQL type, counts
+// Before/After invocations, and returns the configured errors.
+type mockSQLHook struct {
+	sqlType         types.SQLType
+	beforeCallCount int
+	afterCallCount  int
+	beforeError     error
+	afterError      error
+}
+
+// Type reports the SQL type this hook was built with.
+func (h *mockSQLHook) Type() types.SQLType { return h.sqlType }
+
+// Before counts the invocation and returns beforeError (nil unless set).
+func (h *mockSQLHook) Before(ctx context.Context, execCtx *types.ExecContext) error {
+	h.beforeCallCount++
+	return h.beforeError
+}
+
+// After counts the invocation and returns afterError (nil unless set).
+func (h *mockSQLHook) After(ctx context.Context, execCtx *types.ExecContext) error {
+	h.afterCallCount++
+	return h.afterError
+}
+
+// mockExecResult is a write-only types.ExecResult whose driver.Result comes
+// from sqlmock.NewResult.
+type mockExecResult struct {
+	lastInsertID int64
+	rowsAffected int64
+}
+
+// newMockExecResult builds a mockExecResult with the given insert ID and
+// affected-row count.
+func newMockExecResult(lastInsertID, rowsAffected int64) types.ExecResult {
+	r := new(mockExecResult)
+	r.lastInsertID, r.rowsAffected = lastInsertID, rowsAffected
+	return r
+}
+
+// GetRows is unsupported for write results and panics if called.
+func (r *mockExecResult) GetRows() driver.Rows {
+	panic("not implemented for write result")
+}
+
+// GetResult returns a sqlmock driver.Result carrying the configured counters.
+func (r *mockExecResult) GetResult() driver.Result {
+	return sqlmock.NewResult(r.lastInsertID, r.rowsAffected)
+}
diff --git a/pkg/datasource/sql/exec/hook_test.go b/pkg/datasource/sql/exec/hook_test.go
new file mode 100644
index 00000000..8ca84653
--- /dev/null
+++ b/pkg/datasource/sql/exec/hook_test.go
@@ -0,0 +1,460 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package exec
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+
+ "seata.apache.org/seata-go/pkg/datasource/sql/types"
+)
+
+// TestRegisterCommonHook verifies that RegisterCommonHook appends every hook,
+// whatever its SQL type, to the package-level commonHook list after the hooks
+// already present and in registration order.
+func TestRegisterCommonHook(t *testing.T) {
+	// Save original state and restore after test
+	originalCommonHook := commonHook
+	defer func() { commonHook = originalCommonHook }()
+
+	tests := []struct {
+		name          string
+		initialHooks  []SQLHook
+		hooksToAdd    []SQLHook
+		expectedCount int
+	}{
+		{
+			name:         "register single common hook",
+			initialHooks: []SQLHook{},
+			hooksToAdd: []SQLHook{
+				&mockSQLHook{sqlType: types.SQLTypeInsert},
+			},
+			expectedCount: 1,
+		},
+		{
+			name: "register multiple common hooks",
+			initialHooks: []SQLHook{
+				&mockSQLHook{sqlType: types.SQLTypeInsert},
+			},
+			hooksToAdd: []SQLHook{
+				&mockSQLHook{sqlType: types.SQLTypeUpdate},
+				&mockSQLHook{sqlType: types.SQLTypeDelete},
+			},
+			expectedCount: 3,
+		},
+		{
+			name:         "register hook with unknown type",
+			initialHooks: []SQLHook{},
+			hooksToAdd: []SQLHook{
+				&mockSQLHook{sqlType: types.SQLTypeUnknown},
+			},
+			expectedCount: 1, // Common hooks accept any type
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Setup initial state
+			commonHook = make([]SQLHook, len(tt.initialHooks))
+			copy(commonHook, tt.initialHooks)
+
+			// Register hooks
+			for _, hook := range tt.hooksToAdd {
+				RegisterCommonHook(hook)
+			}
+
+			// Verify the list is exactly the initial hooks followed by the newly
+			// registered ones. assert.Same compares pointer identity, so two
+			// identically configured mocks cannot stand in for each other (a
+			// deep-equality check could not tell them apart).
+			assert.Equal(t, tt.expectedCount, len(commonHook), "common hook count should match")
+			want := make([]SQLHook, 0, len(tt.initialHooks)+len(tt.hooksToAdd))
+			want = append(want, tt.initialHooks...)
+			want = append(want, tt.hooksToAdd...)
+			if assert.Len(t, commonHook, len(want), "common hooks should be initial hooks plus registered hooks") {
+				for i := range want {
+					assert.Same(t, want[i], commonHook[i], "common hook %d should keep registration order", i)
+				}
+			}
+		})
+	}
+}
+
+// TestCleanCommonHook tests the CleanCommonHook function
+func TestCleanCommonHook(t *testing.T) {
+ // Save original state and restore after test
+ originalCommonHook := commonHook
+ defer func() { commonHook = originalCommonHook }()
+
+ tests := []struct {
+ name string
+ initialHooks []SQLHook
+ }{
+ {
+ name: "clean empty common hooks",
+ initialHooks: []SQLHook{},
+ },
+ {
+ name: "clean single common hook",
+ initialHooks: []SQLHook{
+ &mockSQLHook{sqlType: types.SQLTypeInsert},
+ },
+ },
+ {
+ name: "clean multiple common hooks",
+ initialHooks: []SQLHook{
+ &mockSQLHook{sqlType: types.SQLTypeInsert},
+ &mockSQLHook{sqlType: types.SQLTypeUpdate},
+ &mockSQLHook{sqlType: types.SQLTypeDelete},
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Setup
+ commonHook = make([]SQLHook, len(tt.initialHooks))
+ copy(commonHook, tt.initialHooks)
+
+ // Execute
+ CleanCommonHook()
+
+ // Verify
+ assert.Equal(t, 0, len(commonHook), "common hooks
should be empty")
+ assert.Equal(t, 4, cap(commonHook), "capacity should be
reset to 4")
+ })
+ }
+}
+
+// TestRegisterHook tests the RegisterHook function
+func TestRegisterHook(t *testing.T) {
+ // Save original state and restore after test
+ originalHookSolts := hookSolts
+ defer func() { hookSolts = originalHookSolts }()
+
+ tests := []struct {
+ name string
+ hooksToRegister []SQLHook
+ expectedSQLTypes []types.SQLType
+ expectedCounts map[types.SQLType]int
+ skipUnknownType bool
+ }{
+ {
+ name: "register single hook",
+ hooksToRegister: []SQLHook{
+ &mockSQLHook{sqlType: types.SQLTypeInsert},
+ },
+ expectedSQLTypes: []types.SQLType{types.SQLTypeInsert},
+ expectedCounts: map[types.SQLType]int{
+ types.SQLTypeInsert: 1,
+ },
+ },
+ {
+ name: "register multiple hooks of same type",
+ hooksToRegister: []SQLHook{
+ &mockSQLHook{sqlType: types.SQLTypeUpdate},
+ &mockSQLHook{sqlType: types.SQLTypeUpdate},
+ &mockSQLHook{sqlType: types.SQLTypeUpdate},
+ },
+ expectedSQLTypes: []types.SQLType{types.SQLTypeUpdate},
+ expectedCounts: map[types.SQLType]int{
+ types.SQLTypeUpdate: 3,
+ },
+ },
+ {
+ name: "register hooks of different types",
+ hooksToRegister: []SQLHook{
+ &mockSQLHook{sqlType: types.SQLTypeInsert},
+ &mockSQLHook{sqlType: types.SQLTypeUpdate},
+ &mockSQLHook{sqlType: types.SQLTypeDelete},
+ &mockSQLHook{sqlType: types.SQLTypeSelect},
+ },
+ expectedSQLTypes: []types.SQLType{
+ types.SQLTypeInsert,
+ types.SQLTypeUpdate,
+ types.SQLTypeDelete,
+ types.SQLTypeSelect,
+ },
+ expectedCounts: map[types.SQLType]int{
+ types.SQLTypeInsert: 1,
+ types.SQLTypeUpdate: 1,
+ types.SQLTypeDelete: 1,
+ types.SQLTypeSelect: 1,
+ },
+ },
+ {
+ name: "register hook with unknown type should be
skipped",
+ hooksToRegister: []SQLHook{
+ &mockSQLHook{sqlType: types.SQLTypeUnknown},
+ &mockSQLHook{sqlType: types.SQLTypeInsert},
+ },
+ expectedSQLTypes: []types.SQLType{types.SQLTypeInsert},
+ expectedCounts: map[types.SQLType]int{
+ types.SQLTypeInsert: 1,
+ },
+ skipUnknownType: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Setup
+ hookSolts = map[types.SQLType][]SQLHook{}
+
+ // Register hooks
+ for _, hook := range tt.hooksToRegister {
+ RegisterHook(hook)
+ }
+
+ // Verify expected SQL types are registered
+ for _, sqlType := range tt.expectedSQLTypes {
+ hooks, exists := hookSolts[sqlType]
+ assert.True(t, exists, "hook slot should exist
for SQL type %v", sqlType)
+ expectedCount := tt.expectedCounts[sqlType]
+ assert.Equal(t, expectedCount, len(hooks),
"hook count should match for SQL type %v", sqlType)
+ }
+
+ // Verify unknown type is skipped
+ if tt.skipUnknownType {
+ _, exists := hookSolts[types.SQLTypeUnknown]
+ assert.False(t, exists, "unknown type should
not be registered")
+ }
+ })
+ }
+}
+
+// TestHookExecution tests the execution order and behavior of hooks
+func TestHookExecution(t *testing.T) {
+ tests := []struct {
+ name string
+ hooks []SQLHook
+ wantBeforeCount int
+ wantAfterCount int
+ }{
+ {
+ name: "hooks execute in order",
+ hooks: []SQLHook{
+ &mockSQLHook{sqlType: types.SQLTypeInsert},
+ &mockSQLHook{sqlType: types.SQLTypeInsert},
+ &mockSQLHook{sqlType: types.SQLTypeInsert},
+ },
+ wantBeforeCount: 3,
+ wantAfterCount: 3,
+ },
+ {
+ name: "before hook error does not prevent execution",
+ hooks: []SQLHook{
+ &mockSQLHook{
+ sqlType: types.SQLTypeInsert,
+ beforeError: fmt.Errorf("before hook
error"),
+ },
+ },
+ wantBeforeCount: 1,
+ wantAfterCount: 1, // After hooks run even if before
fails (error is ignored)
+ },
+ {
+ name: "after hook error is logged but doesn't fail",
+ hooks: []SQLHook{
+ &mockSQLHook{
+ sqlType: types.SQLTypeInsert,
+ afterError: fmt.Errorf("after hook
error"),
+ },
+ },
+ wantBeforeCount: 1,
+ wantAfterCount: 1, // After hooks always run
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ executor := &BaseExecutor{
+ hooks: tt.hooks,
+ }
+
+ execCtx := &types.ExecContext{
+ Query: "INSERT INTO users VALUES (?, ?)",
+ NamedValues: []driver.NamedValue{{Value:
"Alice"}, {Value: 30}},
+ }
+
+ callback := func(ctx context.Context, query string,
args []driver.NamedValue) (types.ExecResult, error) {
+ return newMockExecResult(1, 1), nil
+ }
+
+ _, err :=
executor.ExecWithNamedValue(context.Background(), execCtx, callback)
+ assert.NoError(t, err, "should not return error")
+
+ // Verify hook call counts
+ beforeCount := 0
+ afterCount := 0
+ for _, hook := range tt.hooks {
+ mockHook := hook.(*mockSQLHook)
+ beforeCount += mockHook.beforeCallCount
+ afterCount += mockHook.afterCallCount
+ }
+ assert.Equal(t, tt.wantBeforeCount, beforeCount,
"before hook call count should match")
+ assert.Equal(t, tt.wantAfterCount, afterCount, "after
hook call count should match")
+ })
+ }
+}
+
+// TestHookIntegration tests the integration between hook registration and
execution
+func TestHookIntegration(t *testing.T) {
+ // Save original state and restore after test
+ originalCommonHook := commonHook
+ originalHookSolts := hookSolts
+ defer func() {
+ commonHook = originalCommonHook
+ hookSolts = originalHookSolts
+ }()
+
+ // Clean state
+ commonHook = make([]SQLHook, 0, 4)
+ hookSolts = map[types.SQLType][]SQLHook{}
+
+ // Register common hooks
+ commonHook1 := &mockSQLHook{sqlType: types.SQLTypeSelect}
+ commonHook2 := &mockSQLHook{sqlType: types.SQLTypeUpdate}
+ RegisterCommonHook(commonHook1)
+ RegisterCommonHook(commonHook2)
+
+ // Register type-specific hooks
+ insertHook1 := &mockSQLHook{sqlType: types.SQLTypeInsert}
+ insertHook2 := &mockSQLHook{sqlType: types.SQLTypeInsert}
+ RegisterHook(insertHook1)
+ RegisterHook(insertHook2)
+
+ updateHook := &mockSQLHook{sqlType: types.SQLTypeUpdate}
+ RegisterHook(updateHook)
+
+ // Verify registration
+ assert.Equal(t, 2, len(commonHook), "should have 2 common hooks")
+ assert.Equal(t, 2, len(hookSolts[types.SQLTypeInsert]), "should have 2
insert hooks")
+ assert.Equal(t, 1, len(hookSolts[types.SQLTypeUpdate]), "should have 1
update hook")
+
+ // Test hook execution with BaseExecutor
+ executor := &BaseExecutor{
+ hooks: []SQLHook{commonHook1, insertHook1, insertHook2},
+ }
+
+ execCtx := &types.ExecContext{
+ Query: "INSERT INTO users VALUES (?, ?)",
+ NamedValues: []driver.NamedValue{{Value: "Alice"}, {Value: 30}},
+ }
+
+ callback := func(ctx context.Context, query string, args
[]driver.NamedValue) (types.ExecResult, error) {
+ return newMockExecResult(1, 1), nil
+ }
+
+ result, err := executor.ExecWithNamedValue(context.Background(),
execCtx, callback)
+
+ assert.NoError(t, err, "execution should succeed")
+ assert.NotNil(t, result, "result should not be nil")
+
+ // Verify all hooks were called
+ assert.Equal(t, 1, commonHook1.beforeCallCount, "common hook 1 before
should be called")
+ assert.Equal(t, 1, commonHook1.afterCallCount, "common hook 1 after
should be called")
+ assert.Equal(t, 1, insertHook1.beforeCallCount, "insert hook 1 before
should be called")
+ assert.Equal(t, 1, insertHook1.afterCallCount, "insert hook 1 after
should be called")
+ assert.Equal(t, 1, insertHook2.beforeCallCount, "insert hook 2 before
should be called")
+ assert.Equal(t, 1, insertHook2.afterCallCount, "insert hook 2 after
should be called")
+
+ // Clean hooks and verify
+ CleanCommonHook()
+ assert.Equal(t, 0, len(commonHook), "common hooks should be cleaned")
+}
+
+// TestHookChainExecution tests that hooks execute in a chain and respect the
execution order
+func TestHookChainExecution(t *testing.T) {
+ executionOrder := []string{}
+
+ hook1 := &trackingHook{
+ sqlType: types.SQLTypeInsert,
+ onBefore: func() {
+ executionOrder = append(executionOrder, "hook1-before")
+ },
+ onAfter: func() {
+ executionOrder = append(executionOrder, "hook1-after")
+ },
+ }
+
+ hook2 := &trackingHook{
+ sqlType: types.SQLTypeInsert,
+ onBefore: func() {
+ executionOrder = append(executionOrder, "hook2-before")
+ },
+ onAfter: func() {
+ executionOrder = append(executionOrder, "hook2-after")
+ },
+ }
+
+ hook3 := &trackingHook{
+ sqlType: types.SQLTypeInsert,
+ onBefore: func() {
+ executionOrder = append(executionOrder, "hook3-before")
+ },
+ onAfter: func() {
+ executionOrder = append(executionOrder, "hook3-after")
+ },
+ }
+
+ executor := &BaseExecutor{
+ hooks: []SQLHook{hook1, hook2, hook3},
+ }
+
+ execCtx := &types.ExecContext{
+ Query: "INSERT INTO users VALUES (?, ?)",
+ NamedValues: []driver.NamedValue{{Value: "Alice"}, {Value: 30}},
+ }
+
+ callback := func(ctx context.Context, query string, args
[]driver.NamedValue) (types.ExecResult, error) {
+ executionOrder = append(executionOrder, "callback")
+ return newMockExecResult(1, 1), nil
+ }
+
+ _, err := executor.ExecWithNamedValue(context.Background(), execCtx,
callback)
+ assert.NoError(t, err, "execution should succeed")
+
+ // Verify execution order: all before hooks, then callback, then all
after hooks (in same order)
+ expectedOrder := []string{
+ "hook1-before",
+ "hook2-before",
+ "hook3-before",
+ "callback",
+ "hook1-after",
+ "hook2-after",
+ "hook3-after",
+ }
+ assert.Equal(t, expectedOrder, executionOrder, "hooks should execute in
correct order")
+}
+
+// trackingHook is an SQLHook whose Before and After run the optional
+// onBefore/onAfter callbacks (when non-nil) and always return nil; tests use it
+// to observe the order in which hooks fire.
+type trackingHook struct {
+	sqlType           types.SQLType
+	onBefore, onAfter func()
+}
+
+// Type returns the SQL type the hook was constructed with.
+func (h *trackingHook) Type() types.SQLType { return h.sqlType }
+
+// Before invokes onBefore, if set, and always returns nil.
+func (h *trackingHook) Before(ctx context.Context, execCtx *types.ExecContext) error {
+	if fn := h.onBefore; fn != nil {
+		fn()
+	}
+	return nil
+}
+
+// After invokes onAfter, if set, and always returns nil.
+func (h *trackingHook) After(ctx context.Context, execCtx *types.ExecContext) error {
+	if fn := h.onAfter; fn != nil {
+		fn()
+	}
+	return nil
+}
diff --git a/pkg/integration/grpc/grpc_transaction_interceptor_test.go
b/pkg/integration/grpc/grpc_transaction_interceptor_test.go
index a4579d53..04cae4a9 100644
--- a/pkg/integration/grpc/grpc_transaction_interceptor_test.go
+++ b/pkg/integration/grpc/grpc_transaction_interceptor_test.go
@@ -170,10 +170,9 @@ func
TestServerTransactionInterceptor_WithXidKeyLowercase(t *testing.T) {
}
func TestServerTransactionInterceptor_XidKeyPrecedence(t *testing.T) {
- md := metadata.New(map[string]string{
- constant.XidKey: "primary-xid",
- constant.XidKeyLowercase: "secondary-xid",
- })
+ md := metadata.MD{
+ "tx_xid": []string{"primary-xid", "secondary-xid"},
+ }
ctx := metadata.NewIncomingContext(context.Background(), md)
var handlerCtx context.Context
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]