This is an automated email from the ASF dual-hosted git repository.

zhangliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 4e66b577ef5 Add more test cases on ProxySQLExecutorTest (#38049)
4e66b577ef5 is described below

commit 4e66b577ef55552675c9c4f604906c8e4d9714e0
Author: Liang Zhang <[email protected]>
AuthorDate: Sun Feb 15 17:19:16 2026 +0800

    Add more test cases on ProxySQLExecutorTest (#38049)
    
    * Add more test cases on ProxySQLExecutorTest
    
    * Add more test cases on ProxySQLExecutorTest
    
    * Add more test cases on ProxySQLExecutorTest
    
    * Add more test cases on ProxySQLExecutorTest
---
 .../backend/connector/ProxySQLExecutorTest.java    | 466 ++++++++++++++-------
 1 file changed, 311 insertions(+), 155 deletions(-)

diff --git a/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/connector/ProxySQLExecutorTest.java b/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/connector/ProxySQLExecutorTest.java
index f50978cf926..64004998879 100644
--- a/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/connector/ProxySQLExecutorTest.java
+++ b/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/connector/ProxySQLExecutorTest.java
@@ -17,34 +17,50 @@
 
 package org.apache.shardingsphere.proxy.backend.connector;
 
+import com.google.common.collect.LinkedHashMultimap;
+import lombok.SneakyThrows;
+import 
org.apache.shardingsphere.database.connector.core.metadata.database.metadata.DialectDatabaseMetaData;
+import 
org.apache.shardingsphere.database.connector.core.metadata.database.metadata.option.transaction.DialectTransactionOption;
+import 
org.apache.shardingsphere.database.connector.core.spi.DatabaseTypedSPILoader;
 import org.apache.shardingsphere.database.connector.core.type.DatabaseType;
 import 
org.apache.shardingsphere.database.exception.core.exception.transaction.TableModifyInTransactionException;
+import 
org.apache.shardingsphere.infra.binder.context.segment.table.TablesContext;
 import 
org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
-import 
org.apache.shardingsphere.infra.binder.context.statement.type.CommonSQLStatementContext;
-import 
org.apache.shardingsphere.infra.binder.context.statement.type.ddl.CursorStatementContext;
-import 
org.apache.shardingsphere.infra.binder.context.statement.type.dml.InsertStatementContext;
-import org.apache.shardingsphere.infra.config.mode.ModeConfiguration;
 import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
+import org.apache.shardingsphere.infra.config.rule.RuleConfiguration;
+import 
org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupContext;
+import 
org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupReportContext;
 import org.apache.shardingsphere.infra.executor.sql.context.ExecutionContext;
+import 
org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit;
+import 
org.apache.shardingsphere.infra.executor.sql.execute.engine.raw.RawExecutor;
+import 
org.apache.shardingsphere.infra.executor.sql.execute.engine.raw.RawSQLExecutionUnit;
+import 
org.apache.shardingsphere.infra.executor.sql.execute.engine.raw.callback.RawSQLExecutorCallback;
+import 
org.apache.shardingsphere.infra.executor.sql.execute.result.ExecuteResult;
+import 
org.apache.shardingsphere.infra.executor.sql.prepare.driver.DriverExecutionPrepareEngine;
 import 
org.apache.shardingsphere.infra.executor.sql.prepare.driver.jdbc.JDBCDriverType;
-import org.apache.shardingsphere.infra.hint.HintValueContext;
-import org.apache.shardingsphere.infra.instance.ComputeNodeInstanceContext;
+import 
org.apache.shardingsphere.infra.executor.sql.prepare.raw.RawExecutionPrepareEngine;
 import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
 import 
org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
 import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
-import 
org.apache.shardingsphere.infra.metadata.statistics.ShardingSphereStatistics;
-import 
org.apache.shardingsphere.infra.metadata.statistics.builder.ShardingSphereStatisticsFactory;
-import org.apache.shardingsphere.infra.route.context.RouteContext;
-import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
-import org.apache.shardingsphere.infra.session.query.QueryContext;
+import org.apache.shardingsphere.infra.rule.ShardingSphereRule;
+import org.apache.shardingsphere.infra.rule.attribute.RuleAttributes;
+import 
org.apache.shardingsphere.infra.rule.attribute.raw.RawExecutionRuleAttribute;
+import 
org.apache.shardingsphere.infra.session.connection.transaction.TransactionConnectionContext;
 import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
 import org.apache.shardingsphere.mode.manager.ContextManager;
-import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
-import org.apache.shardingsphere.mode.spi.repository.PersistRepository;
+import 
org.apache.shardingsphere.proxy.backend.connector.jdbc.executor.ProxyJDBCExecutor;
+import 
org.apache.shardingsphere.proxy.backend.connector.jdbc.statement.JDBCBackendStatement;
+import 
org.apache.shardingsphere.proxy.backend.connector.sane.DialectSaneQueryResultEngine;
 import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
 import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
+import 
org.apache.shardingsphere.sql.parser.statement.core.enums.TransactionIsolationLevel;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableNameSegment;
+import 
org.apache.shardingsphere.sql.parser.statement.core.statement.SQLStatement;
+import 
org.apache.shardingsphere.sql.parser.statement.core.statement.type.ddl.CloseStatement;
+import 
org.apache.shardingsphere.sql.parser.statement.core.statement.type.ddl.CursorStatement;
+import 
org.apache.shardingsphere.sql.parser.statement.core.statement.type.ddl.FetchStatement;
+import 
org.apache.shardingsphere.sql.parser.statement.core.statement.type.ddl.MoveStatement;
 import 
org.apache.shardingsphere.sql.parser.statement.core.statement.type.ddl.TruncateStatement;
 import 
org.apache.shardingsphere.sql.parser.statement.core.statement.type.ddl.table.CreateTableStatement;
 import 
org.apache.shardingsphere.sql.parser.statement.core.statement.type.dml.InsertStatement;
@@ -54,228 +70,368 @@ import 
org.apache.shardingsphere.test.infra.framework.extension.mock.AutoMockExt
 import 
org.apache.shardingsphere.test.infra.framework.extension.mock.StaticMockSettings;
 import org.apache.shardingsphere.transaction.api.TransactionType;
 import org.apache.shardingsphere.transaction.rule.TransactionRule;
+import org.apache.shardingsphere.transaction.spi.TransactionHook;
 import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInstance;
 import org.junit.jupiter.api.extension.ExtendWith;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
 import org.mockito.Answers;
 import org.mockito.Mock;
+import org.mockito.MockedConstruction;
+import org.mockito.MockedStatic;
+import org.mockito.internal.configuration.plugins.Plugins;
 import org.mockito.junit.jupiter.MockitoSettings;
 import org.mockito.quality.Strictness;
 
+import java.sql.Connection;
+import java.sql.SQLException;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
+import java.util.List;
 import java.util.Optional;
+import java.util.stream.Stream;
 
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.is;
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
 import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.Mockito.CALLS_REAL_METHODS;
 import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.anyBoolean;
+import static org.mockito.Mockito.anyCollection;
+import static org.mockito.Mockito.anyString;
+import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.mockConstruction;
+import static org.mockito.Mockito.mockStatic;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 @ExtendWith(AutoMockExtension.class)
 @StaticMockSettings(ProxyContext.class)
 @MockitoSettings(strictness = Strictness.LENIENT)
+@TestInstance(TestInstance.Lifecycle.PER_CLASS)
 class ProxySQLExecutorTest {
     
-    private final DatabaseType databaseType = 
TypedSPILoader.getService(DatabaseType.class, "FIXTURE");
+    private final DatabaseType fixtureDatabaseType = 
TypedSPILoader.getService(DatabaseType.class, "FIXTURE");
     
     private final DatabaseType mysqlDatabaseType = 
TypedSPILoader.getService(DatabaseType.class, "MySQL");
     
     private final DatabaseType postgresqlDatabaseType = 
TypedSPILoader.getService(DatabaseType.class, "PostgreSQL");
     
+    @Mock
+    private TransactionConnectionContext transactionConnectionContext;
+    
     @Mock(answer = Answers.RETURNS_DEEP_STUBS)
     private ConnectionSession connectionSession;
     
     @Mock(answer = Answers.RETURNS_DEEP_STUBS)
     private ProxyDatabaseConnectionManager databaseConnectionManager;
     
+    @Mock
+    private DatabaseProxyConnector databaseProxyConnector;
+    
+    @Mock
+    private ProxyJDBCExecutor regularExecutor;
+    
+    @Mock
+    private RawExecutor rawExecutor;
+    
+    @Mock
+    private TransactionHook transactionHook;
+    
     @Mock
     private TransactionRule transactionRule;
     
+    @Mock
+    private ShardingSphereRule shardingSphereRule;
+    
+    @Mock
+    private DialectSaneQueryResultEngine saneQueryResultEngine;
+    
+    @Mock(answer = Answers.RETURNS_DEEP_STUBS)
+    private ShardingSphereMetaData metaData;
+    
+    @Mock(answer = Answers.RETURNS_DEEP_STUBS)
+    private ShardingSphereDatabase database;
+    
     @BeforeEach
     void setUp() {
-        
when(connectionSession.getTransactionStatus().isInTransaction()).thenReturn(true);
-        
when(connectionSession.getDatabaseConnectionManager()).thenReturn(databaseConnectionManager);
+        
when(connectionSession.getConnectionContext().getTransactionContext()).thenReturn(transactionConnectionContext);
         
when(databaseConnectionManager.getConnectionSession()).thenReturn(connectionSession);
-        
when(databaseConnectionManager.getConnectionSession().getUsedDatabaseName()).thenReturn("foo_db");
-        ShardingSphereMetaData metaData = mock(ShardingSphereMetaData.class, 
RETURNS_DEEP_STUBS);
-        
when(metaData.getDatabase("foo_db")).thenReturn(mock(ShardingSphereDatabase.class,
 RETURNS_DEEP_STUBS));
-        
when(metaData.getAllDatabases()).thenReturn(Collections.singleton(mock(ShardingSphereDatabase.class,
 RETURNS_DEEP_STUBS)));
-        
when(metaData.getAllDatabases().iterator().next().getProtocolType()).thenReturn(databaseType);
+        
when(databaseConnectionManager.getCachedConnections()).thenReturn(LinkedHashMultimap.create());
+        
when(connectionSession.getDatabaseConnectionManager()).thenReturn(databaseConnectionManager);
+        
when(connectionSession.getIsolationLevel()).thenReturn(Optional.empty());
+        
when(connectionSession.getStatementManager()).thenReturn(mock(JDBCBackendStatement.class));
+        when(database.getName()).thenReturn("foo_db");
+        when(database.getProtocolType()).thenReturn(fixtureDatabaseType);
+        
when(database.getRuleMetaData().getRules()).thenReturn(Collections.emptyList());
+        when(metaData.getDatabase("foo_db")).thenReturn(database);
+        
when(metaData.getAllDatabases()).thenReturn(Collections.singleton(database));
         
when(metaData.getProps().<Integer>getValue(ConfigurationPropertyKey.KERNEL_EXECUTOR_SIZE)).thenReturn(0);
+        
when(metaData.getProps().<Integer>getValue(ConfigurationPropertyKey.MAX_CONNECTIONS_SIZE_PER_QUERY)).thenReturn(1);
         
when(metaData.getProps().<Boolean>getValue(ConfigurationPropertyKey.PERSIST_SCHEMAS_TO_REPOSITORY_ENABLED)).thenReturn(true);
         when(transactionRule.getDefaultType()).thenReturn(TransactionType.XA);
         when(metaData.getGlobalRuleMetaData()).thenReturn(new 
RuleMetaData(Arrays.asList(mock(SQLFederationRule.class), transactionRule)));
-        ComputeNodeInstanceContext computeNodeInstanceContext = 
mock(ComputeNodeInstanceContext.class);
-        
when(computeNodeInstanceContext.getModeConfiguration()).thenReturn(mock(ModeConfiguration.class));
-        ContextManager contextManager = new ContextManager(new 
MetaDataContexts(metaData,
-                ShardingSphereStatisticsFactory.create(metaData, new 
ShardingSphereStatistics())), computeNodeInstanceContext, mock(), 
mock(PersistRepository.class, RETURNS_DEEP_STUBS));
+        ContextManager contextManager = mock(ContextManager.class, 
RETURNS_DEEP_STUBS);
+        
when(contextManager.getMetaDataContexts().getMetaData()).thenReturn(metaData);
+        when(contextManager.getDatabase("foo_db")).thenReturn(database);
+        when(contextManager.getDatabaseType()).thenReturn(fixtureDatabaseType);
         
when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
     }
     
-    @Test
-    void assertCheckExecutePrerequisitesWhenExecuteDDLInXATransaction() {
-        ExecutionContext executionContext = new ExecutionContext(
-                new 
QueryContext(createCreateTableStatementContext(mysqlDatabaseType), "", 
Collections.emptyList(), new HintValueContext(), mockConnectionContext(),
-                        mock(ShardingSphereMetaData.class)),
-                Collections.emptyList(), mock(RouteContext.class));
-        assertThrows(TableModifyInTransactionException.class, () -> new 
ProxySQLExecutor(JDBCDriverType.STATEMENT,
-                databaseConnectionManager, mock(DatabaseProxyConnector.class), 
mockSQLStatementContext()).checkExecutePrerequisites(executionContext.getSqlStatementContext()));
+    @ParameterizedTest(name = "{0}")
+    @MethodSource("constructorScenarios")
+    void assertConstructor(final String name, final String 
currentDatabaseName, final String schemaName, final boolean hasSchemaName) {
+        
when(connectionSession.getCurrentDatabaseName()).thenReturn(currentDatabaseName);
+        assertNotNull(createProxySQLExecutor(schemaName, 
hasSchemaName).getSqlFederationEngine());
     }
     
-    private ConnectionContext mockConnectionContext() {
-        ConnectionContext result = mock(ConnectionContext.class);
-        
when(result.getCurrentDatabaseName()).thenReturn(Optional.of("foo_db"));
-        return result;
+    private Stream<Arguments> constructorScenarios() {
+        return Stream.of(
+                
Arguments.of("constructor-use-used-database-when-current-empty", "", 
"foo_schema", true),
+                
Arguments.of("constructor-use-default-schema-when-schema-missing", "foo_db", 
"foo_schema", false));
+    }
+    
+    @ParameterizedTest(name = "{0}")
+    @MethodSource("checkExecutePrerequisitesScenarios")
+    void assertCheckExecutePrerequisites(final String name, final SQLStatement 
sqlStatement,
+                                         final TransactionType 
transactionType, final boolean inTransaction, final boolean hasTable, final 
boolean expectedThrowException) {
+        when(transactionRule.getDefaultType()).thenReturn(transactionType);
+        
when(connectionSession.getTransactionStatus().isInTransaction()).thenReturn(inTransaction);
+        ProxySQLExecutor proxySQLExecutor = 
createProxySQLExecutor("foo_schema", true);
+        SQLStatementContext sqlStatementContext = 
createCheckStatementContext(sqlStatement, hasTable);
+        if (expectedThrowException) {
+            assertThrows(TableModifyInTransactionException.class, () -> 
proxySQLExecutor.checkExecutePrerequisites(sqlStatementContext));
+        } else {
+            assertDoesNotThrow(() -> 
proxySQLExecutor.checkExecutePrerequisites(sqlStatementContext));
+        }
     }
     
-    @Test
-    void 
assertCheckExecutePrerequisitesWhenExecuteTruncateInMySQLXATransaction() {
-        ExecutionContext executionContext = new ExecutionContext(
-                new 
QueryContext(createTruncateStatementContext(mysqlDatabaseType), "", 
Collections.emptyList(), new HintValueContext(), mockConnectionContext(), 
mock(ShardingSphereMetaData.class)),
-                Collections.emptyList(), mock(RouteContext.class));
-        assertThrows(TableModifyInTransactionException.class, () -> new 
ProxySQLExecutor(JDBCDriverType.STATEMENT,
-                databaseConnectionManager, mock(DatabaseProxyConnector.class), 
mockSQLStatementContext()).checkExecutePrerequisites(executionContext.getSqlStatementContext()));
+    private Stream<Arguments> checkExecutePrerequisitesScenarios() {
+        return Stream.of(
+                Arguments.of("ddl-create-mysql-xa-throws", 
createCreateTableStatement(mysqlDatabaseType), TransactionType.XA, true, true, 
true),
+                Arguments.of("ddl-truncate-mysql-xa-throws", 
createTruncateStatement(mysqlDatabaseType), TransactionType.XA, true, true, 
true),
+                Arguments.of("ddl-create-postgresql-local-throws", 
createCreateTableStatement(postgresqlDatabaseType), TransactionType.LOCAL, 
true, true, true),
+                Arguments.of("ddl-create-postgresql-xa-throws", 
createCreateTableStatement(postgresqlDatabaseType), TransactionType.XA, true, 
true, true),
+                Arguments.of("ddl-create-postgresql-local-empty-table-throws", 
createCreateTableStatement(postgresqlDatabaseType), TransactionType.LOCAL, 
true, false, true),
+                Arguments.of("ddl-create-mysql-local-pass", 
createCreateTableStatement(mysqlDatabaseType), TransactionType.LOCAL, true, 
true, false),
+                Arguments.of("ddl-truncate-mysql-local-pass", 
createTruncateStatement(mysqlDatabaseType), TransactionType.LOCAL, true, true, 
false),
+                Arguments.of("ddl-create-base-transaction-pass", 
createCreateTableStatement(mysqlDatabaseType), TransactionType.BASE, true, 
true, false),
+                Arguments.of("ddl-create-mysql-not-in-transaction-pass", 
createCreateTableStatement(mysqlDatabaseType), TransactionType.XA, false, true, 
false),
+                Arguments.of("ddl-create-postgresql-not-in-transaction-pass", 
createCreateTableStatement(postgresqlDatabaseType), TransactionType.LOCAL, 
false, true, false),
+                Arguments.of("ddl-truncate-postgresql-local-pass", 
createTruncateStatement(postgresqlDatabaseType), TransactionType.LOCAL, true, 
true, false),
+                Arguments.of("ddl-cursor-postgresql-local-pass", new 
CursorStatement(postgresqlDatabaseType, null, null), TransactionType.LOCAL, 
true, true, false),
+                Arguments.of("ddl-close-postgresql-local-pass", new 
CloseStatement(postgresqlDatabaseType, null, false), TransactionType.LOCAL, 
true, true, false),
+                Arguments.of("ddl-move-postgresql-local-pass", new 
MoveStatement(postgresqlDatabaseType, null, null), TransactionType.LOCAL, true, 
true, false),
+                Arguments.of("ddl-fetch-postgresql-local-pass", new 
FetchStatement(postgresqlDatabaseType, null, null), TransactionType.LOCAL, 
true, true, false),
+                Arguments.of("ddl-truncate-postgresql-xa-pass", 
createTruncateStatement(postgresqlDatabaseType), TransactionType.XA, true, 
true, false),
+                Arguments.of("dml-insert-mysql-xa-pass", 
createInsertStatement(mysqlDatabaseType), TransactionType.XA, true, true, 
false));
     }
     
-    @Test
-    void 
assertCheckExecutePrerequisitesWhenExecuteTruncateInMySQLLocalTransaction() {
-        
when(transactionRule.getDefaultType()).thenReturn(TransactionType.LOCAL);
-        ExecutionContext executionContext = new ExecutionContext(
-                new 
QueryContext(createTruncateStatementContext(mysqlDatabaseType), "", 
Collections.emptyList(), new HintValueContext(), mockConnectionContext(), 
mock(ShardingSphereMetaData.class)),
-                Collections.emptyList(), mock(RouteContext.class));
-        new ProxySQLExecutor(JDBCDriverType.STATEMENT,
-                databaseConnectionManager, mock(DatabaseProxyConnector.class), 
mockSQLStatementContext()).checkExecutePrerequisites(executionContext.getSqlStatementContext());
+    @ParameterizedTest(name = "{0}")
+    @MethodSource("executeScenarios")
+    void assertExecute(final String name, final boolean hasRawExecutionRule, 
final SQLStatement sqlStatement, final boolean inTransaction,
+                       final boolean expectedHookInvoked, final boolean 
isReturnGeneratedKeys) throws SQLException {
+        when(connectionSession.getUsedDatabaseName()).thenReturn("foo_db");
+        
when(transactionConnectionContext.isInTransaction()).thenReturn(inTransaction);
+        
when(database.getRuleMetaData().getRules()).thenReturn(createRules(hasRawExecutionRule));
+        ProxySQLExecutor proxySQLExecutor = 
createProxySQLExecutor("foo_schema", true);
+        setExecutorField(proxySQLExecutor, "rawExecutor", rawExecutor);
+        setExecutorField(proxySQLExecutor, "regularExecutor", regularExecutor);
+        setExecutorField(proxySQLExecutor, "transactionHooks", 
Collections.singletonMap(shardingSphereRule, transactionHook));
+        ExecutionContext executionContext = 
createExecutionContext(sqlStatement);
+        ExecuteResult expectedExecuteResult = mock(ExecuteResult.class);
+        List<ExecuteResult> expected = 
Collections.singletonList(expectedExecuteResult);
+        if (hasRawExecutionRule) {
+            ExecutionGroupContext<RawSQLExecutionUnit> 
rawExecutionGroupContext = mock(ExecutionGroupContext.class);
+            try (
+                    MockedConstruction<RawExecutionPrepareEngine> ignored = 
mockConstruction(RawExecutionPrepareEngine.class,
+                            (mock, context) -> when(mock.prepare(anyString(), 
eq(executionContext), anyCollection(), any(ExecutionGroupReportContext.class)))
+                                    .thenReturn(rawExecutionGroupContext))) {
+                try (MockedConstruction<RawSQLExecutorCallback> 
ignoredCallback = mockConstruction(RawSQLExecutorCallback.class)) {
+                    when(rawExecutor.execute(eq(rawExecutionGroupContext), 
any(), any(RawSQLExecutorCallback.class))).thenReturn(expected);
+                    assertThat(proxySQLExecutor.execute(executionContext), 
is(expected));
+                }
+            }
+            return;
+        }
+        ExecutionGroupContext<JDBCExecutionUnit> jdbcExecutionGroupContext = 
mock(ExecutionGroupContext.class);
+        try (
+                MockedConstruction<DriverExecutionPrepareEngine> ignored = 
mockConstruction(DriverExecutionPrepareEngine.class,
+                        (mock, context) -> when(mock.prepare(anyString(), 
eq(executionContext), anyCollection(), any(ExecutionGroupReportContext.class)))
+                                .thenReturn(jdbcExecutionGroupContext))) {
+            when(regularExecutor.execute(any(), eq(jdbcExecutionGroupContext), 
eq(isReturnGeneratedKeys), anyBoolean())).thenReturn(expected);
+            assertThat(proxySQLExecutor.execute(executionContext), 
is(expected));
+        }
+        verify(regularExecutor).execute(any(), eq(jdbcExecutionGroupContext), 
eq(isReturnGeneratedKeys), anyBoolean());
+        if (expectedHookInvoked) {
+            verify(transactionHook).beforeExecuteSQL(eq(shardingSphereRule), 
eq(fixtureDatabaseType), anyCollection(),
+                    eq(transactionConnectionContext), 
eq(TransactionIsolationLevel.READ_COMMITTED));
+            return;
+        }
+        verify(transactionHook, never()).beforeExecuteSQL(any(), any(), 
anyCollection(), any(), any());
     }
     
-    @Test
-    void assertCheckExecutePrerequisitesWhenExecuteDMLInXATransaction() {
-        ExecutionContext executionContext = new ExecutionContext(
-                new 
QueryContext(mockInsertStatementContext(mysqlDatabaseType), "", 
Collections.emptyList(), new HintValueContext(), mockConnectionContext(), 
mock(ShardingSphereMetaData.class)),
-                Collections.emptyList(), mock(RouteContext.class));
-        new ProxySQLExecutor(JDBCDriverType.STATEMENT,
-                databaseConnectionManager, mock(DatabaseProxyConnector.class), 
mockSQLStatementContext()).checkExecutePrerequisites(executionContext.getSqlStatementContext());
+    private Stream<Arguments> executeScenarios() {
+        return Stream.of(
+                Arguments.of("execute-with-raw-rule", true, 
createCreateTableStatement(mysqlDatabaseType), true, false, false),
+                Arguments.of("execute-with-driver-and-generated-keys", false, 
createInsertStatement(mysqlDatabaseType), true, true, true),
+                Arguments.of("execute-with-driver-and-no-transaction", false, 
createInsertStatement(postgresqlDatabaseType), false, false, false));
     }
     
-    @Test
-    void assertCheckExecutePrerequisitesWhenExecuteDDLInBaseTransaction() {
-        
when(transactionRule.getDefaultType()).thenReturn(TransactionType.BASE);
-        ExecutionContext executionContext = new ExecutionContext(
-                new 
QueryContext(createCreateTableStatementContext(mysqlDatabaseType), "", 
Collections.emptyList(), new HintValueContext(), mockConnectionContext(),
-                        mock(ShardingSphereMetaData.class)),
-                Collections.emptyList(), mock(RouteContext.class));
-        new ProxySQLExecutor(JDBCDriverType.STATEMENT,
-                databaseConnectionManager, mock(DatabaseProxyConnector.class), 
mockSQLStatementContext()).checkExecutePrerequisites(executionContext.getSqlStatementContext());
+    @ParameterizedTest(name = "{0}")
+    @MethodSource("executeFallbackScenarios")
+    void assertExecuteFallback(final String name, final boolean 
hasRawExecutionRule, final SQLStatement sqlStatement, final boolean 
hasSaneResult) throws SQLException {
+        when(connectionSession.getUsedDatabaseName()).thenReturn("foo_db");
+        
when(database.getRuleMetaData().getRules()).thenReturn(createRules(hasRawExecutionRule));
+        ProxySQLExecutor proxySQLExecutor = 
createProxySQLExecutor("foo_schema", true);
+        setExecutorField(proxySQLExecutor, "rawExecutor", rawExecutor);
+        setExecutorField(proxySQLExecutor, "regularExecutor", regularExecutor);
+        setExecutorField(proxySQLExecutor, "transactionHooks", 
Collections.singletonMap(shardingSphereRule, transactionHook));
+        ExecutionContext executionContext = 
createExecutionContext(sqlStatement);
+        SQLException expectedException = new SQLException("mock prepare 
failure");
+        try (MockedStatic<DatabaseTypedSPILoader> mockedDatabaseTypedSPILoader 
= mockStatic(DatabaseTypedSPILoader.class, CALLS_REAL_METHODS)) {
+            mockedDatabaseTypedSPILoader.when(() -> 
DatabaseTypedSPILoader.findService(DialectSaneQueryResultEngine.class, 
fixtureDatabaseType)).thenReturn(Optional.of(saneQueryResultEngine));
+            if (hasSaneResult) {
+                ExecuteResult saneExecuteResult = mock(ExecuteResult.class);
+                when(saneQueryResultEngine.getSaneQueryResult(sqlStatement, 
expectedException)).thenReturn(Optional.of(saneExecuteResult));
+                if (hasRawExecutionRule) {
+                    try (
+                            MockedConstruction<RawExecutionPrepareEngine> 
ignored = mockConstruction(RawExecutionPrepareEngine.class,
+                                    (mock, context) -> 
when(mock.prepare(anyString(), eq(executionContext), anyCollection(), 
any(ExecutionGroupReportContext.class)))
+                                            .thenThrow(expectedException))) {
+                        assertThat(proxySQLExecutor.execute(executionContext), 
is(Collections.singletonList(saneExecuteResult)));
+                    }
+                    return;
+                }
+                try (
+                        MockedConstruction<DriverExecutionPrepareEngine> 
ignored = mockConstruction(DriverExecutionPrepareEngine.class,
+                                (mock, context) -> 
when(mock.prepare(anyString(), eq(executionContext), anyCollection(), 
any(ExecutionGroupReportContext.class)))
+                                        .thenThrow(expectedException))) {
+                    assertThat(proxySQLExecutor.execute(executionContext), 
is(Collections.singletonList(saneExecuteResult)));
+                }
+                return;
+            }
+            when(saneQueryResultEngine.getSaneQueryResult(sqlStatement, 
expectedException)).thenReturn(Optional.empty());
+            if (hasRawExecutionRule) {
+                try (
+                        MockedConstruction<RawExecutionPrepareEngine> ignored 
= mockConstruction(RawExecutionPrepareEngine.class,
+                                (mock, context) -> 
when(mock.prepare(anyString(), eq(executionContext), anyCollection(), 
any(ExecutionGroupReportContext.class)))
+                                        .thenThrow(expectedException))) {
+                    SQLException actual = assertThrows(SQLException.class, () 
-> proxySQLExecutor.execute(executionContext));
+                    assertThat(actual, is(expectedException));
+                }
+                return;
+            }
+            try (
+                    MockedConstruction<DriverExecutionPrepareEngine> ignored = 
mockConstruction(DriverExecutionPrepareEngine.class,
+                            (mock, context) -> when(mock.prepare(anyString(), 
eq(executionContext), anyCollection(), any(ExecutionGroupReportContext.class)))
+                                    .thenThrow(expectedException))) {
+                assertThat(assertThrows(SQLException.class, () -> 
proxySQLExecutor.execute(executionContext)), is(expectedException));
+            }
+        }
     }
     
-    @Test
-    void assertCheckExecutePrerequisitesWhenExecuteDDLNotInXATransaction() {
-        
when(connectionSession.getTransactionStatus().isInTransaction()).thenReturn(false);
-        ExecutionContext executionContext = new ExecutionContext(
-                new 
QueryContext(createCreateTableStatementContext(mysqlDatabaseType), "", 
Collections.emptyList(), new HintValueContext(), mockConnectionContext(),
-                        mock(ShardingSphereMetaData.class)),
-                Collections.emptyList(), mock(RouteContext.class));
-        new ProxySQLExecutor(JDBCDriverType.STATEMENT,
-                databaseConnectionManager, mock(DatabaseProxyConnector.class), 
mockSQLStatementContext()).checkExecutePrerequisites(executionContext.getSqlStatementContext());
+    private Stream<Arguments> executeFallbackScenarios() {
+        return Stream.of(
+                Arguments.of("raw-prepare-failed-with-sane-result", true, 
createCreateTableStatement(mysqlDatabaseType), true),
+                Arguments.of("driver-prepare-failed-with-sane-result", false, 
createInsertStatement(postgresqlDatabaseType), true),
+                Arguments.of("driver-prepare-failed-throws-original", false, 
createInsertStatement(postgresqlDatabaseType), false));
     }
     
-    @Test
-    void assertCheckExecutePrerequisitesWhenExecuteCreateTableInPostgreSQLTransaction() {
-        when(transactionRule.getDefaultType()).thenReturn(TransactionType.LOCAL);
-        ExecutionContext executionContext = new ExecutionContext(
-                new QueryContext(createCreateTableStatementContext(postgresqlDatabaseType), "", Collections.emptyList(), new HintValueContext(), mockConnectionContext(),
-                        mock(ShardingSphereMetaData.class)),
-                Collections.emptyList(), mock(RouteContext.class));
-        assertThrows(TableModifyInTransactionException.class, () -> new ProxySQLExecutor(JDBCDriverType.STATEMENT,
-                databaseConnectionManager, mock(DatabaseProxyConnector.class), mockSQLStatementContext()).checkExecutePrerequisites(executionContext.getSqlStatementContext()));
+    @ParameterizedTest(name = "{0}")
+    @MethodSource("xaMetaDataRefreshScenarios")
+    void assertCheckExecutePrerequisitesWithMetaDataRefreshInXATransaction(final String name) {
+        DatabaseType databaseType = mock(DatabaseType.class);
+        DialectDatabaseMetaData dialectDatabaseMetaData = mock(DialectDatabaseMetaData.class);
+        when(dialectDatabaseMetaData.getTransactionOption()).thenReturn(new DialectTransactionOption(false, false, false, true, true,
+                Connection.TRANSACTION_READ_COMMITTED, false, false, Collections.emptyList()));
+        when(transactionRule.getDefaultType()).thenReturn(TransactionType.XA);
+        when(connectionSession.getTransactionStatus().isInTransaction()).thenReturn(true);
+        try (MockedStatic<DatabaseTypedSPILoader> mockedDatabaseTypedSPILoader = mockStatic(DatabaseTypedSPILoader.class, CALLS_REAL_METHODS)) {
+            mockedDatabaseTypedSPILoader.when(() -> DatabaseTypedSPILoader.getService(DialectDatabaseMetaData.class, databaseType)).thenReturn(dialectDatabaseMetaData);
+            ProxySQLExecutor proxySQLExecutor = createProxySQLExecutor("foo_schema", true);
+            assertDoesNotThrow(() -> proxySQLExecutor.checkExecutePrerequisites(createCheckStatementContext(createCreateTableStatement(databaseType), true)));
+        }
     }
     
-    @Test
-    void assertCheckExecutePrerequisitesWhenExecuteTruncateInPostgreSQLTransaction() {
-        when(transactionRule.getDefaultType()).thenReturn(TransactionType.LOCAL);
-        ExecutionContext executionContext = new ExecutionContext(
-                new QueryContext(createTruncateStatementContext(postgresqlDatabaseType), "", Collections.emptyList(), new HintValueContext(), mockConnectionContext(),
-                        mock(ShardingSphereMetaData.class)),
-                Collections.emptyList(), mock(RouteContext.class));
-        new ProxySQLExecutor(JDBCDriverType.STATEMENT,
-                databaseConnectionManager, mock(DatabaseProxyConnector.class), mockSQLStatementContext()).checkExecutePrerequisites(executionContext.getSqlStatementContext());
+    private Stream<Arguments> xaMetaDataRefreshScenarios() {
+        return Stream.of(Arguments.of("ddl-create-xa-with-metadata-refresh-supported"));
     }
     
-    @Test
-    void assertCheckExecutePrerequisitesWhenExecuteCursorInPostgreSQLTransaction() {
-        when(transactionRule.getDefaultType()).thenReturn(TransactionType.LOCAL);
-        ExecutionContext executionContext = new ExecutionContext(
-                new QueryContext(mockCursorStatementContext(), "", Collections.emptyList(), new HintValueContext(), mockConnectionContext(), mock(ShardingSphereMetaData.class)),
-                Collections.emptyList(), mock(RouteContext.class));
-        new ProxySQLExecutor(JDBCDriverType.STATEMENT,
-                databaseConnectionManager, mock(DatabaseProxyConnector.class), mockSQLStatementContext()).checkExecutePrerequisites(executionContext.getSqlStatementContext());
+    private ProxySQLExecutor createProxySQLExecutor(final String schemaName, final boolean hasSchemaName) {
+        return new ProxySQLExecutor(JDBCDriverType.STATEMENT, databaseConnectionManager, databaseProxyConnector, createConstructorStatementContext(schemaName, hasSchemaName));
     }
     
-    @Test
-    void assertCheckExecutePrerequisitesWhenExecuteDMLInPostgreSQLTransaction() {
-        when(transactionRule.getDefaultType()).thenReturn(TransactionType.LOCAL);
-        ExecutionContext executionContext = new ExecutionContext(
-                new QueryContext(mockInsertStatementContext(postgresqlDatabaseType), "", Collections.emptyList(), new HintValueContext(), mockConnectionContext(),
-                        mock(ShardingSphereMetaData.class)),
-                Collections.emptyList(), mock(RouteContext.class));
-        new ProxySQLExecutor(JDBCDriverType.STATEMENT,
-                databaseConnectionManager, mock(DatabaseProxyConnector.class), mockSQLStatementContext()).checkExecutePrerequisites(executionContext.getSqlStatementContext());
+    private SQLStatementContext createConstructorStatementContext(final String schemaName, final boolean hasSchemaName) {
+        SQLStatementContext result = mock(SQLStatementContext.class, RETURNS_DEEP_STUBS);
+        when(result.getSqlStatement().getDatabaseType()).thenReturn(fixtureDatabaseType);
+        when(result.getTablesContext().getSchemaName()).thenReturn(hasSchemaName ? Optional.of(schemaName) : Optional.empty());
+        return result;
     }
     
-    @Test
-    void assertCheckExecutePrerequisitesWhenExecuteDDLInMySQLTransaction() {
-        when(transactionRule.getDefaultType()).thenReturn(TransactionType.LOCAL);
-        ExecutionContext executionContext = new ExecutionContext(
-                new QueryContext(createCreateTableStatementContext(mysqlDatabaseType), "", Collections.emptyList(), new HintValueContext(), mockConnectionContext(),
-                        mock(ShardingSphereMetaData.class)),
-                Collections.emptyList(), mock(RouteContext.class));
-        new ProxySQLExecutor(JDBCDriverType.STATEMENT,
-                databaseConnectionManager, mock(DatabaseProxyConnector.class), mockSQLStatementContext()).checkExecutePrerequisites(executionContext.getSqlStatementContext());
+    private SQLStatementContext createCheckStatementContext(final SQLStatement sqlStatement, final boolean hasTable) {
+        SQLStatementContext result = mock(SQLStatementContext.class);
+        when(result.getSqlStatement()).thenReturn(sqlStatement);
+        when(result.getTablesContext()).thenReturn(new TablesContext(hasTable ? new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))) : null));
+        return result;
     }
     
-    private SQLStatementContext mockSQLStatementContext() {
-        SQLStatementContext result = mock(SQLStatementContext.class, RETURNS_DEEP_STUBS);
-        when(result.getSqlStatement().getDatabaseType()).thenReturn(databaseType);
-        when(result.getTablesContext().getSchemaName()).thenReturn(Optional.of("foo_db"));
+    private ExecutionContext createExecutionContext(final SQLStatement sqlStatement) {
+        SQLStatementContext sqlStatementContext = mock(SQLStatementContext.class);
+        when(sqlStatementContext.getSqlStatement()).thenReturn(sqlStatement);
+        ExecutionContext result = mock(ExecutionContext.class);
+        when(result.getSqlStatementContext()).thenReturn(sqlStatementContext);
         return result;
     }
     
-    @Test
-    void assertCheckExecutePrerequisitesWhenExecuteDDLNotInPostgreSQLTransaction() {
-        when(transactionRule.getDefaultType()).thenReturn(TransactionType.LOCAL);
-        when(connectionSession.getTransactionStatus().isInTransaction()).thenReturn(false);
-        ExecutionContext executionContext = new ExecutionContext(
-                new QueryContext(createCreateTableStatementContext(postgresqlDatabaseType), "", Collections.emptyList(), new HintValueContext(), mockConnectionContext(),
-                        mock(ShardingSphereMetaData.class)),
-                Collections.emptyList(), mock(RouteContext.class));
-        new ProxySQLExecutor(JDBCDriverType.STATEMENT,
-                databaseConnectionManager, mock(DatabaseProxyConnector.class), mockSQLStatementContext()).checkExecutePrerequisites(executionContext.getSqlStatementContext());
+    private Collection<ShardingSphereRule> createRules(final boolean hasRawExecutionRule) {
+        RuleAttributes attributes = hasRawExecutionRule ? new RuleAttributes(mock(RawExecutionRuleAttribute.class)) : new RuleAttributes();
+        ShardingSphereRule result = new ShardingSphereRule() {
+            
+            @Override
+            public RuleConfiguration getConfiguration() {
+                return null;
+            }
+            
+            @Override
+            public RuleAttributes getAttributes() {
+                return attributes;
+            }
+            
+            @Override
+            public int getOrder() {
+                return 0;
+            }
+        };
+        return Collections.singletonList(result);
     }
     
-    private CommonSQLStatementContext createCreateTableStatementContext(final DatabaseType databaseType) {
-        CreateTableStatement sqlStatement = new CreateTableStatement(databaseType);
-        sqlStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
-        sqlStatement.buildAttributes();
-        return new CommonSQLStatementContext(sqlStatement);
+    private CreateTableStatement createCreateTableStatement(final DatabaseType databaseType) {
+        CreateTableStatement result = new CreateTableStatement(databaseType);
+        result.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
+        return result;
     }
     
-    private SQLStatementContext createTruncateStatementContext(final DatabaseType databaseType) {
-        TruncateStatement sqlStatement = new TruncateStatement(databaseType, Collections.singleton(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order")))));
-        sqlStatement.buildAttributes();
-        return new CommonSQLStatementContext(sqlStatement);
+    private TruncateStatement createTruncateStatement(final DatabaseType databaseType) {
+        return new TruncateStatement(databaseType, Collections.singleton(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order")))));
     }
     
-    private CursorStatementContext mockCursorStatementContext() {
-        CursorStatementContext result = mock(CursorStatementContext.class, RETURNS_DEEP_STUBS);
-        when(result.getTablesContext().getDatabaseName()).thenReturn(Optional.empty());
-        when(result.getSqlStatement().getDatabaseType()).thenReturn(databaseType);
+    private InsertStatement createInsertStatement(final DatabaseType databaseType) {
+        InsertStatement result = new InsertStatement(databaseType);
+        result.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
         return result;
     }
     
-    private InsertStatementContext mockInsertStatementContext(final DatabaseType databaseType) {
-        InsertStatement sqlStatement = new InsertStatement(databaseType);
-        sqlStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
-        ShardingSphereDatabase database = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS);
-        when(database.getName()).thenReturn("foo_db");
-        return new InsertStatementContext(sqlStatement, new ShardingSphereMetaData(Collections.singleton(database), mock(), mock(), mock()), "foo_db");
+    @SneakyThrows(ReflectiveOperationException.class)
+    private void setExecutorField(final ProxySQLExecutor target, final String fieldName, final Object value) {
+        Plugins.getMemberAccessor().set(ProxySQLExecutor.class.getDeclaredField(fieldName), target, value);
     }
 }


Reply via email to