This is an automated email from the ASF dual-hosted git repository.
zhangliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 2c6652b0561 Add more test cases on PostgreSQLComDescribeExecutorTest (#37925)
2c6652b0561 is described below
commit 2c6652b056194d03dfa3d74e0bde601d6a96bdbe
Author: Liang Zhang <[email protected]>
AuthorDate: Sun Feb 1 19:16:58 2026 +0800
Add more test cases on PostgreSQLComDescribeExecutorTest (#37925)
* Add more test cases on PostgreSQLComDescribeExecutorTest
* Add more test cases on PostgreSQLComDescribeExecutorTest
* Add more test cases on PostgreSQLComDescribeExecutorTest
* Add more test cases on PostgreSQLComDescribeExecutorTest
* Add more test cases on PostgreSQLComDescribeExecutorTest
* Add more test cases on PostgreSQLComDescribeExecutorTest
---
.../PostgreSQLComDescribeExecutorTest.java | 443 ++++++++++++++++-----
1 file changed, 347 insertions(+), 96 deletions(-)
diff --git a/proxy/frontend/dialect/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java b/proxy/frontend/dialect/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java
index 71d5b317f3c..8ebb9d9d2d1 100644
--- a/proxy/frontend/dialect/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java
+++ b/proxy/frontend/dialect/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java
@@ -61,6 +61,9 @@ import org.apache.shardingsphere.test.infra.framework.extension.mock.AutoMockExt
import
org.apache.shardingsphere.test.infra.framework.extension.mock.StaticMockSettings;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.internal.configuration.plugins.Plugins;
@@ -69,6 +72,7 @@ import org.mockito.quality.Strictness;
import java.sql.Connection;
import java.sql.ParameterMetaData;
+import java.sql.PreparedStatement;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Types;
@@ -78,10 +82,12 @@ import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
+import java.util.Map;
import java.util.Optional;
import java.util.Properties;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
+import java.util.stream.Stream;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
@@ -135,11 +141,35 @@ class PostgreSQLComDescribeExecutorTest {
}
@Test
- void assertDescribePreparedStatementInsertWithoutColumns() throws
SQLException {
+ void assertDescribePreparedStatementWithExistingRowDescription() throws
SQLException {
+ when(packet.getType()).thenReturn('S');
+ String statementId = "S_exist";
+ when(packet.getName()).thenReturn(statementId);
+ String sql = "SELECT 1";
+ SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
+ SQLStatementContext sqlStatementContext =
mock(SelectStatementContext.class);
+ when(sqlStatementContext.getSqlStatement()).thenReturn(sqlStatement);
+ ServerPreparedStatementRegistry serverPreparedStatementRegistry = new
ServerPreparedStatementRegistry();
+
when(connectionSession.getServerPreparedStatementRegistry()).thenReturn(serverPreparedStatementRegistry);
+ PostgreSQLServerPreparedStatement preparedStatement = new
PostgreSQLServerPreparedStatement(
+ sql, sqlStatementContext, new HintValueContext(), new
ArrayList<>(), Collections.emptyList());
+
preparedStatement.setRowDescription(PostgreSQLNoDataPacket.getInstance());
+ serverPreparedStatementRegistry.addPreparedStatement(statementId,
preparedStatement);
+ Collection<DatabasePacket> actual = executor.execute();
+ Iterator<DatabasePacket> actualIterator = actual.iterator();
+ PostgreSQLParameterDescriptionPacket parameterDescription =
(PostgreSQLParameterDescriptionPacket) actualIterator.next();
+ PostgreSQLPacketPayload mockPayload =
mock(PostgreSQLPacketPayload.class);
+ parameterDescription.write(mockPayload);
+ verify(mockPayload).writeInt2(0);
+ assertThat(actualIterator.next(),
is(PostgreSQLNoDataPacket.getInstance()));
+ }
+
+ @ParameterizedTest(name = "{0}")
+ @MethodSource("provideInsertMetaDataCases")
+ void assertDescribePreparedStatementInsertByMetaData(final String
testName, final String statementId, final String sql,
+ final int
expectedParamCount, final int expectedInt4Count, final int expectedCharCount)
throws SQLException {
when(packet.getType()).thenReturn('S');
- final String statementId = "S_1";
when(packet.getName()).thenReturn(statementId);
- String sql = "INSERT INTO t_order VALUES (?, 0, 'char', ?), (2, ?, ?,
'')";
SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
List<PostgreSQLColumnType> parameterTypes = new
ArrayList<>(sqlStatement.getParameterCount());
for (int i = 0; i < sqlStatement.getParameterCount(); i++) {
@@ -153,45 +183,76 @@ class PostgreSQLComDescribeExecutorTest {
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(statementId,
new PostgreSQLServerPreparedStatement(sql, sqlStatementContext, new
HintValueContext(), parameterTypes,
parameterIndexes));
Collection<DatabasePacket> actualPackets = executor.execute();
- assertThat(actualPackets.size(), is(2));
Iterator<DatabasePacket> actualPacketsIterator =
actualPackets.iterator();
PostgreSQLParameterDescriptionPacket actualParameterDescription =
(PostgreSQLParameterDescriptionPacket) actualPacketsIterator.next();
PostgreSQLPacketPayload mockPayload =
mock(PostgreSQLPacketPayload.class);
actualParameterDescription.write(mockPayload);
- verify(mockPayload).writeInt2(4);
- verify(mockPayload, times(2)).writeInt4(23);
- verify(mockPayload, times(2)).writeInt4(18);
+ verify(mockPayload).writeInt2(expectedParamCount);
+ verify(mockPayload,
times(expectedInt4Count)).writeInt4(PostgreSQLColumnType.INT4.getValue());
+ verify(mockPayload,
times(expectedCharCount)).writeInt4(PostgreSQLColumnType.CHAR.getValue());
assertThat(actualPacketsIterator.next(),
is(PostgreSQLNoDataPacket.getInstance()));
}
@Test
- void assertDescribePreparedStatementInsertWithColumns() throws
SQLException {
+ void assertDescribePreparedStatementInsertWithoutParameters() throws
SQLException {
when(packet.getType()).thenReturn('S');
- final String statementId = "S_2";
+ String statementId = "S_early_return";
when(packet.getName()).thenReturn(statementId);
- String sql = "INSERT INTO t_order (id, k, c, pad) VALUES (1, ?, ?, ?),
(?, 2, ?, '')";
+ String sql = "INSERT INTO t_order VALUES (1)";
SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
- List<PostgreSQLColumnType> parameterTypes = new
ArrayList<>(sqlStatement.getParameterCount());
- for (int i = 0; i < sqlStatement.getParameterCount(); i++) {
- parameterTypes.add(PostgreSQLColumnType.UNSPECIFIED);
- }
SQLStatementContext sqlStatementContext =
mock(InsertStatementContext.class);
when(sqlStatementContext.getSqlStatement()).thenReturn(sqlStatement);
- ContextManager contextManager = mockContextManager();
+ ServerPreparedStatementRegistry serverPreparedStatementRegistry = new
ServerPreparedStatementRegistry();
+
when(connectionSession.getServerPreparedStatementRegistry()).thenReturn(serverPreparedStatementRegistry);
+ serverPreparedStatementRegistry.addPreparedStatement(statementId,
+ new PostgreSQLServerPreparedStatement(sql,
sqlStatementContext, new HintValueContext(), new ArrayList<>(),
Collections.emptyList()));
+ Collection<DatabasePacket> actualPackets = executor.execute();
+ Iterator<DatabasePacket> actualIterator = actualPackets.iterator();
+ PostgreSQLParameterDescriptionPacket parameterDescription =
(PostgreSQLParameterDescriptionPacket) actualIterator.next();
+ PostgreSQLPacketPayload mockPayload =
mock(PostgreSQLPacketPayload.class);
+ parameterDescription.write(mockPayload);
+ verify(mockPayload).writeInt2(0);
+ assertThat(actualIterator.hasNext(), is(false));
+ }
+
+ @Test
+ void
assertDescribePreparedStatementInsertWithSchemaAndMixedParameterTypes() throws
SQLException {
+ when(packet.getType()).thenReturn('S');
+ String statementId = "S_schema";
+ when(packet.getName()).thenReturn(statementId);
+ String sql = "INSERT INTO public.t_small (col1, col2) VALUES (?, ?)
RETURNING *, col1 + col2 expr_sum";
+ SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
+ List<PostgreSQLColumnType> parameterTypes = new
ArrayList<>(Arrays.asList(PostgreSQLColumnType.INT4,
PostgreSQLColumnType.UNSPECIFIED));
+ SQLStatementContext sqlStatementContext =
mock(InsertStatementContext.class);
+ when(sqlStatementContext.getSqlStatement()).thenReturn(sqlStatement);
+ ShardingSphereTable table = new ShardingSphereTable("t_small",
+ Arrays.asList(
+ new ShardingSphereColumn("col1", Types.INTEGER, true,
false, false, true, false, false),
+ new ShardingSphereColumn("col2", Types.SMALLINT, true,
false, false, true, false, false)),
+ Collections.emptyList(), Collections.emptyList());
+ ContextManager contextManager = mockContextManager(table);
when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
List<Integer> parameterIndexes = IntStream.range(0,
sqlStatement.getParameterCount()).boxed().collect(Collectors.toList());
-
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(statementId,
new PostgreSQLServerPreparedStatement(sql, sqlStatementContext, new
HintValueContext(), parameterTypes,
- parameterIndexes));
+
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(
+ statementId, new PostgreSQLServerPreparedStatement(sql,
sqlStatementContext, new HintValueContext(), parameterTypes, parameterIndexes));
Collection<DatabasePacket> actualPackets = executor.execute();
- assertThat(actualPackets.size(), is(2));
- Iterator<DatabasePacket> actualPacketsIterator =
actualPackets.iterator();
- PostgreSQLParameterDescriptionPacket actualParameterDescription =
(PostgreSQLParameterDescriptionPacket) actualPacketsIterator.next();
+ Iterator<DatabasePacket> actualIterator = actualPackets.iterator();
+ PostgreSQLParameterDescriptionPacket parameterDescription =
(PostgreSQLParameterDescriptionPacket) actualIterator.next();
PostgreSQLPacketPayload mockPayload =
mock(PostgreSQLPacketPayload.class);
- actualParameterDescription.write(mockPayload);
- verify(mockPayload).writeInt2(5);
- verify(mockPayload, times(2)).writeInt4(23);
- verify(mockPayload, times(3)).writeInt4(18);
- assertThat(actualPacketsIterator.next(),
is(PostgreSQLNoDataPacket.getInstance()));
+ parameterDescription.write(mockPayload);
+ verify(mockPayload).writeInt2(2);
+ verify(mockPayload).writeInt4(PostgreSQLColumnType.INT4.getValue());
+ verify(mockPayload).writeInt4(PostgreSQLColumnType.INT2.getValue());
+ PostgreSQLRowDescriptionPacket rowDescriptionPacket =
(PostgreSQLRowDescriptionPacket) actualIterator.next();
+ List<PostgreSQLColumnDescription> columnDescriptions =
getColumnDescriptions(rowDescriptionPacket);
+ assertThat(columnDescriptions.size(), is(3));
+ assertThat(columnDescriptions.get(0).getColumnName(), is("col1"));
+ assertThat(columnDescriptions.get(0).getColumnLength(), is(4));
+ assertThat(columnDescriptions.get(1).getColumnName(), is("col2"));
+ assertThat(columnDescriptions.get(1).getColumnLength(), is(2));
+ assertThat(columnDescriptions.get(2).getColumnName(), is("expr_sum"));
+ assertThat(columnDescriptions.get(2).getTypeOID(),
is(PostgreSQLColumnType.VARCHAR.getValue()));
+ assertThat(columnDescriptions.get(2).getColumnLength(), is(-1));
}
@Test
@@ -246,12 +307,35 @@ class PostgreSQLComDescribeExecutorTest {
}
@Test
- void assertDescribePreparedStatementInsertWithReturningClause() throws
SQLException {
+ void
assertDescribePreparedStatementInsertWithUnspecifiedTypesAndNoMarkers() throws
SQLException {
+ when(packet.getType()).thenReturn('S');
+ String statementId = "S_mismatch";
+ when(packet.getName()).thenReturn(statementId);
+ String sql = "INSERT INTO t_order VALUES (1)";
+ SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
+ SQLStatementContext sqlStatementContext =
mock(InsertStatementContext.class);
+ when(sqlStatementContext.getSqlStatement()).thenReturn(sqlStatement);
+ ContextManager contextManager = mockContextManager();
+
when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
+ List<PostgreSQLColumnType> parameterTypes = new
ArrayList<>(Collections.singletonList(PostgreSQLColumnType.UNSPECIFIED));
+
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(
+ statementId, new PostgreSQLServerPreparedStatement(sql,
sqlStatementContext, new HintValueContext(), parameterTypes,
Collections.emptyList()));
+ Collection<DatabasePacket> actualPackets = executor.execute();
+ Iterator<DatabasePacket> actualIterator = actualPackets.iterator();
+ PostgreSQLParameterDescriptionPacket parameterDescription =
(PostgreSQLParameterDescriptionPacket) actualIterator.next();
+ PostgreSQLPacketPayload mockPayload =
mock(PostgreSQLPacketPayload.class);
+ parameterDescription.write(mockPayload);
+ verify(mockPayload).writeInt2(1);
+
verify(mockPayload).writeInt4(PostgreSQLColumnType.UNSPECIFIED.getValue());
+ assertThat(actualIterator.next(),
is(PostgreSQLNoDataPacket.getInstance()));
+ }
+
+ @ParameterizedTest(name = "{0}")
+ @MethodSource("provideReturningCases")
+ void assertDescribePreparedStatementInsertWithReturning(final String
testName, final String statementId, final String sql,
+ final
List<PostgreSQLColumnType> expectedParamTypes, final
List<PostgreSQLColumnDescription> expectedColumns) throws SQLException {
when(packet.getType()).thenReturn('S');
- final String statementId = "S_2";
when(packet.getName()).thenReturn(statementId);
- String sql = "INSERT INTO t_order (k, c, pad) VALUES (?, ?, ?) "
- + "RETURNING id, id alias_id, 'anonymous', 'OK'
literal_string, 1 literal_int, 4294967296 literal_bigint, 1.1 literal_numeric,
t_order.*, t_order, t_order alias_t_order";
SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
List<PostgreSQLColumnType> parameterTypes = new
ArrayList<>(sqlStatement.getParameterCount());
for (int i = 0; i < sqlStatement.getParameterCount(); i++) {
@@ -262,70 +346,29 @@ class PostgreSQLComDescribeExecutorTest {
ContextManager contextManager = mockContextManager();
when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
List<Integer> parameterIndexes = IntStream.range(0,
sqlStatement.getParameterCount()).boxed().collect(Collectors.toList());
-
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(statementId,
new PostgreSQLServerPreparedStatement(sql, sqlStatementContext, new
HintValueContext(), parameterTypes,
- parameterIndexes));
+
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(
+ statementId, new PostgreSQLServerPreparedStatement(sql,
sqlStatementContext, new HintValueContext(), parameterTypes, parameterIndexes));
Collection<DatabasePacket> actualPackets = executor.execute();
- assertThat(actualPackets.size(), is(2));
- Iterator<DatabasePacket> actualPacketsIterator =
actualPackets.iterator();
- PostgreSQLParameterDescriptionPacket actualParameterDescription =
(PostgreSQLParameterDescriptionPacket) actualPacketsIterator.next();
+ Iterator<DatabasePacket> actualIterator = actualPackets.iterator();
+ PostgreSQLParameterDescriptionPacket parameterDescription =
(PostgreSQLParameterDescriptionPacket) actualIterator.next();
PostgreSQLPacketPayload mockPayload =
mock(PostgreSQLPacketPayload.class);
- actualParameterDescription.write(mockPayload);
- verify(mockPayload).writeInt2(3);
- verify(mockPayload).writeInt4(23);
- verify(mockPayload, times(2)).writeInt4(18);
- DatabasePacket actualRowDescriptionPacket =
actualPacketsIterator.next();
- assertThat(actualRowDescriptionPacket,
is(isA(PostgreSQLRowDescriptionPacket.class)));
- assertRowDescriptions((PostgreSQLRowDescriptionPacket)
actualRowDescriptionPacket);
- }
-
- private void assertRowDescriptions(final PostgreSQLRowDescriptionPacket
actualRowDescriptionPacket) {
- List<PostgreSQLColumnDescription> actualColumnDescriptions = new
ArrayList<>(getColumnDescriptionsFromPacket(actualRowDescriptionPacket));
- assertThat(actualColumnDescriptions.size(), is(13));
- assertThat(actualColumnDescriptions.get(0).getColumnName(), is("id"));
- assertThat(actualColumnDescriptions.get(0).getTypeOID(),
is(PostgreSQLColumnType.INT4.getValue()));
- assertThat(actualColumnDescriptions.get(0).getColumnLength(), is(4));
- assertThat(actualColumnDescriptions.get(1).getColumnName(),
is("alias_id"));
- assertThat(actualColumnDescriptions.get(1).getTypeOID(),
is(PostgreSQLColumnType.INT4.getValue()));
- assertThat(actualColumnDescriptions.get(1).getColumnLength(), is(4));
- assertThat(actualColumnDescriptions.get(2).getColumnName(),
is("?column?"));
- assertThat(actualColumnDescriptions.get(2).getTypeOID(),
is(PostgreSQLColumnType.VARCHAR.getValue()));
- assertThat(actualColumnDescriptions.get(2).getColumnLength(), is(-1));
- assertThat(actualColumnDescriptions.get(3).getColumnName(),
is("literal_string"));
- assertThat(actualColumnDescriptions.get(3).getTypeOID(),
is(PostgreSQLColumnType.VARCHAR.getValue()));
- assertThat(actualColumnDescriptions.get(3).getColumnLength(), is(-1));
- assertThat(actualColumnDescriptions.get(4).getColumnName(),
is("literal_int"));
- assertThat(actualColumnDescriptions.get(4).getTypeOID(),
is(PostgreSQLColumnType.INT4.getValue()));
- assertThat(actualColumnDescriptions.get(4).getColumnLength(), is(4));
- assertThat(actualColumnDescriptions.get(5).getColumnName(),
is("literal_bigint"));
- assertThat(actualColumnDescriptions.get(5).getTypeOID(),
is(PostgreSQLColumnType.INT8.getValue()));
- assertThat(actualColumnDescriptions.get(5).getColumnLength(), is(8));
- assertThat(actualColumnDescriptions.get(6).getColumnName(),
is("literal_numeric"));
- assertThat(actualColumnDescriptions.get(6).getTypeOID(),
is(PostgreSQLColumnType.NUMERIC.getValue()));
- assertThat(actualColumnDescriptions.get(6).getColumnLength(), is(-1));
- assertThat(actualColumnDescriptions.get(7).getColumnName(), is("id"));
- assertThat(actualColumnDescriptions.get(7).getTypeOID(),
is(PostgreSQLColumnType.INT4.getValue()));
- assertThat(actualColumnDescriptions.get(7).getColumnLength(), is(4));
- assertThat(actualColumnDescriptions.get(8).getColumnName(), is("k"));
- assertThat(actualColumnDescriptions.get(8).getTypeOID(),
is(PostgreSQLColumnType.INT4.getValue()));
- assertThat(actualColumnDescriptions.get(8).getColumnLength(), is(4));
- assertThat(actualColumnDescriptions.get(9).getColumnName(), is("c"));
- assertThat(actualColumnDescriptions.get(9).getTypeOID(),
is(PostgreSQLColumnType.CHAR.getValue()));
- assertThat(actualColumnDescriptions.get(9).getColumnLength(), is(-1));
- assertThat(actualColumnDescriptions.get(10).getColumnName(),
is("pad"));
- assertThat(actualColumnDescriptions.get(10).getTypeOID(),
is(PostgreSQLColumnType.CHAR.getValue()));
- assertThat(actualColumnDescriptions.get(10).getColumnLength(), is(-1));
- assertThat(actualColumnDescriptions.get(11).getColumnName(),
is("t_order"));
- assertThat(actualColumnDescriptions.get(11).getTypeOID(),
is(PostgreSQLColumnType.VARCHAR.getValue()));
- assertThat(actualColumnDescriptions.get(11).getColumnLength(), is(-1));
- assertThat(actualColumnDescriptions.get(12).getColumnName(),
is("alias_t_order"));
- assertThat(actualColumnDescriptions.get(12).getTypeOID(),
is(PostgreSQLColumnType.VARCHAR.getValue()));
- assertThat(actualColumnDescriptions.get(12).getColumnLength(), is(-1));
- }
-
- @SuppressWarnings("unchecked")
- @SneakyThrows(ReflectiveOperationException.class)
- private Collection<PostgreSQLColumnDescription>
getColumnDescriptionsFromPacket(final PostgreSQLRowDescriptionPacket packet) {
- return (Collection<PostgreSQLColumnDescription>)
Plugins.getMemberAccessor().get(PostgreSQLRowDescriptionPacket.class.getDeclaredField("columnDescriptions"),
packet);
+ parameterDescription.write(mockPayload);
+ verify(mockPayload).writeInt2(expectedParamTypes.size());
+ Map<Integer, Long> expectedTypeCounts = expectedParamTypes.stream()
+ .collect(Collectors.groupingBy(PostgreSQLColumnType::getValue,
Collectors.counting()));
+ for (Map.Entry<Integer, Long> entry : expectedTypeCounts.entrySet()) {
+ verify(mockPayload,
times(entry.getValue().intValue())).writeInt4(entry.getKey());
+ }
+ PostgreSQLRowDescriptionPacket rowDescriptionPacket =
(PostgreSQLRowDescriptionPacket) actualIterator.next();
+ List<PostgreSQLColumnDescription> actualColumnDescriptions =
getColumnDescriptions(rowDescriptionPacket);
+ assertThat(actualColumnDescriptions.size(),
is(expectedColumns.size()));
+ for (int i = 0; i < expectedColumns.size(); i++) {
+ PostgreSQLColumnDescription expectedColumn =
expectedColumns.get(i);
+ PostgreSQLColumnDescription actualColumn =
actualColumnDescriptions.get(i);
+ assertThat(actualColumn.getColumnName(),
is(expectedColumn.getColumnName()));
+ assertThat(actualColumn.getTypeOID(),
is(expectedColumn.getTypeOID()));
+ assertThat(actualColumn.getColumnLength(),
is(expectedColumn.getColumnLength()));
+ }
}
@Test
@@ -342,7 +385,8 @@ class PostgreSQLComDescribeExecutorTest {
ContextManager contextManager = mockContextManager();
when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
List<Integer> parameterIndexes = IntStream.range(0,
sqlStatement.getParameterCount()).boxed().collect(Collectors.toList());
- ConnectionContext connectionContext = mockConnectionContext();
+ ConnectionContext connectionContext = mock(ConnectionContext.class);
+
when(connectionContext.getCurrentDatabaseName()).thenReturn(Optional.of(DATABASE_NAME));
when(connectionSession.getConnectionContext()).thenReturn(connectionContext);
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(
statementId, new PostgreSQLServerPreparedStatement(sql,
sqlStatementContext, new HintValueContext(), parameterTypes, parameterIndexes));
@@ -372,10 +416,98 @@ class PostgreSQLComDescribeExecutorTest {
}
}
- private ConnectionContext mockConnectionContext() {
- ConnectionContext result = mock(ConnectionContext.class);
-
when(result.getCurrentDatabaseName()).thenReturn(Optional.of(DATABASE_NAME));
- return result;
+ @Test
+ void assertDescribeSelectPreparedStatementWithNullMetaData() throws
SQLException {
+ when(packet.getType()).thenReturn('S');
+ String statementId = "S_null_metadata";
+ when(packet.getName()).thenReturn(statementId);
+ String sql = "SELECT id, k FROM t_order";
+ SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
+ SQLStatementContext sqlStatementContext =
mock(SelectStatementContext.class);
+ when(sqlStatementContext.getSqlStatement()).thenReturn(sqlStatement);
+ prepareJDBCBackendConnectionWithNullMetaData(sql);
+ ContextManager contextManager = mockContextManager();
+
when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
+ List<PostgreSQLColumnType> parameterTypes = new ArrayList<>();
+ List<Integer> parameterIndexes = Collections.emptyList();
+ ConnectionContext connectionContext = mock(ConnectionContext.class);
+
when(connectionContext.getCurrentDatabaseName()).thenReturn(Optional.of(DATABASE_NAME));
+
when(connectionSession.getConnectionContext()).thenReturn(connectionContext);
+
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(
+ statementId, new PostgreSQLServerPreparedStatement(sql,
sqlStatementContext, new HintValueContext(), parameterTypes, parameterIndexes));
+ Collection<DatabasePacket> actual = executor.execute();
+ Iterator<DatabasePacket> actualIterator = actual.iterator();
+ PostgreSQLParameterDescriptionPacket parameterDescription =
(PostgreSQLParameterDescriptionPacket) actualIterator.next();
+ PostgreSQLPacketPayload mockPayload =
mock(PostgreSQLPacketPayload.class);
+ parameterDescription.write(mockPayload);
+ verify(mockPayload).writeInt2(0);
+ assertThat(actualIterator.next(),
is(PostgreSQLNoDataPacket.getInstance()));
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
+ void assertDescribeSelectPreparedStatementWithPresetRowDescription()
throws SQLException {
+ when(packet.getType()).thenReturn('S');
+ String statementId = "S_pre_described";
+ when(packet.getName()).thenReturn(statementId);
+ String sql = "SELECT id FROM t_order WHERE id = ?";
+ SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
+ SQLStatementContext sqlStatementContext =
mock(SelectStatementContext.class);
+ when(sqlStatementContext.getSqlStatement()).thenReturn(sqlStatement);
+ List<PostgreSQLColumnType> parameterTypes = new
ArrayList<>(Collections.singleton(PostgreSQLColumnType.INT4));
+ List<Integer> parameterIndexes = IntStream.range(0,
sqlStatement.getParameterCount()).boxed().collect(Collectors.toList());
+ PostgreSQLServerPreparedStatement preparedStatement =
mock(PostgreSQLServerPreparedStatement.class);
+ when(preparedStatement.describeRows()).thenReturn(Optional.empty(),
Optional.of(PostgreSQLNoDataPacket.getInstance()));
+ when(preparedStatement.describeParameters()).thenReturn(new
PostgreSQLParameterDescriptionPacket(parameterTypes));
+ when(preparedStatement.getSql()).thenReturn(sql);
+
when(preparedStatement.getSqlStatementContext()).thenReturn(sqlStatementContext);
+ when(preparedStatement.getHintValueContext()).thenReturn(new
HintValueContext());
+ when(preparedStatement.getParameterTypes()).thenReturn(parameterTypes);
+
when(preparedStatement.getActualParameterMarkerIndexes()).thenReturn(parameterIndexes);
+ ContextManager contextManager = mockContextManager();
+
when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
+ ConnectionContext connectionContext = mock(ConnectionContext.class);
+
when(connectionContext.getCurrentDatabaseName()).thenReturn(Optional.of(DATABASE_NAME));
+
when(connectionSession.getConnectionContext()).thenReturn(connectionContext);
+ prepareJDBCBackendConnectionWithPreparedStatement(sql);
+
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(statementId,
preparedStatement);
+ Collection<DatabasePacket> actual = executor.execute();
+ Iterator<DatabasePacket> actualIterator = actual.iterator();
+ PostgreSQLParameterDescriptionPacket parameterDescription =
(PostgreSQLParameterDescriptionPacket) actualIterator.next();
+ PostgreSQLPacketPayload mockPayload =
mock(PostgreSQLPacketPayload.class);
+ parameterDescription.write(mockPayload);
+ verify(mockPayload).writeInt2(1);
+ verify(mockPayload).writeInt4(PostgreSQLColumnType.INT4.getValue());
+ assertThat(actualIterator.next(),
is(PostgreSQLNoDataPacket.getInstance()));
+ }
+
+ @Test
+ void assertPopulateParameterTypesWithMixedSpecifiedAndUnspecified() throws
SQLException {
+ when(packet.getType()).thenReturn('S');
+ String statementId = "S_mixed_params";
+ when(packet.getName()).thenReturn(statementId);
+ String sql = "SELECT id FROM t_order WHERE id = ? AND k = ?";
+ SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
+ SQLStatementContext sqlStatementContext =
mock(SelectStatementContext.class);
+ when(sqlStatementContext.getSqlStatement()).thenReturn(sqlStatement);
+ prepareJDBCBackendConnectionWithParamTypes(sql, new
int[]{Types.INTEGER, Types.SMALLINT}, new String[]{"int4", "int2"});
+ List<PostgreSQLColumnType> parameterTypes = new
ArrayList<>(Arrays.asList(PostgreSQLColumnType.INT4,
PostgreSQLColumnType.UNSPECIFIED));
+ List<Integer> parameterIndexes = IntStream.range(0,
sqlStatement.getParameterCount()).boxed().collect(Collectors.toList());
+ ContextManager contextManager = mockContextManager();
+
when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
+ ConnectionContext connectionContext = mock(ConnectionContext.class);
+
when(connectionContext.getCurrentDatabaseName()).thenReturn(Optional.of(DATABASE_NAME));
+
when(connectionSession.getConnectionContext()).thenReturn(connectionContext);
+
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(
+ statementId, new PostgreSQLServerPreparedStatement(sql,
sqlStatementContext, new HintValueContext(), parameterTypes, parameterIndexes));
+ Collection<DatabasePacket> actual = executor.execute();
+ Iterator<DatabasePacket> actualIterator = actual.iterator();
+ PostgreSQLParameterDescriptionPacket parameterDescription =
(PostgreSQLParameterDescriptionPacket) actualIterator.next();
+ PostgreSQLPacketPayload mockPayload =
mock(PostgreSQLPacketPayload.class);
+ parameterDescription.write(mockPayload);
+ verify(mockPayload).writeInt2(2);
+ verify(mockPayload).writeInt4(PostgreSQLColumnType.INT4.getValue());
+ verify(mockPayload).writeInt4(PostgreSQLColumnType.INT2.getValue());
}
private ContextManager mockContextManager() {
@@ -406,6 +538,29 @@ class PostgreSQLComDescribeExecutorTest {
return result;
}
+ private ContextManager mockContextManager(final ShardingSphereTable table)
{
+ ContextManager result = mock(ContextManager.class, RETURNS_DEEP_STUBS);
+
when(result.getMetaDataContexts().getMetaData().getProps()).thenReturn(new
ConfigurationProperties(new Properties()));
+
when(connectionSession.getUsedDatabaseName()).thenReturn(DATABASE_NAME);
+
when(connectionSession.getCurrentDatabaseName()).thenReturn(DATABASE_NAME);
+
when(connectionSession.getServerPreparedStatementRegistry()).thenReturn(new
ServerPreparedStatementRegistry());
+ RuleMetaData globalRuleMetaData = new
RuleMetaData(Collections.singleton(new SQLTranslatorRule(new
DefaultSQLTranslatorRuleConfigurationBuilder().build())));
+
when(result.getMetaDataContexts().getMetaData().getGlobalRuleMetaData()).thenReturn(globalRuleMetaData);
+
when(result.getMetaDataContexts().getMetaData().getDatabase(DATABASE_NAME).getProtocolType()).thenReturn(DATABASE_TYPE);
+ StorageUnit storageUnit = mock(StorageUnit.class, RETURNS_DEEP_STUBS);
+ when(storageUnit.getStorageType()).thenReturn(DATABASE_TYPE);
+
when(result.getMetaDataContexts().getMetaData().getDatabase(DATABASE_NAME).getResourceMetaData().getStorageUnits()).thenReturn(Collections.singletonMap("ds_0",
storageUnit));
+
when(result.getMetaDataContexts().getMetaData().containsDatabase(DATABASE_NAME)).thenReturn(true);
+ ShardingSphereSchema schema = mock(ShardingSphereSchema.class);
+
when(result.getMetaDataContexts().getMetaData().getDatabase(DATABASE_NAME).containsSchema("public")).thenReturn(true);
+
when(result.getMetaDataContexts().getMetaData().getDatabase(DATABASE_NAME).getSchema("public")).thenReturn(schema);
+ when(schema.containsTable(table.getName())).thenReturn(true);
+ when(schema.getTable(table.getName())).thenReturn(table);
+ ShardingSphereDatabase database =
result.getMetaDataContexts().getMetaData().getDatabase(DATABASE_NAME);
+ when(result.getDatabase(DATABASE_NAME)).thenReturn(database);
+ return result;
+ }
+
@SuppressWarnings("JDBCResourceOpenedButNotSafelyClosed")
private void prepareJDBCBackendConnection(final String sql) throws
SQLException {
ProxyDatabaseConnectionManager databaseConnectionManager =
mock(ProxyDatabaseConnectionManager.class);
@@ -419,6 +574,43 @@ class PostgreSQLComDescribeExecutorTest {
when(connectionSession.getDatabaseConnectionManager()).thenReturn(databaseConnectionManager);
}
+ private void prepareJDBCBackendConnectionWithNullMetaData(final String
sql) throws SQLException {
+ ProxyDatabaseConnectionManager databaseConnectionManager =
mock(ProxyDatabaseConnectionManager.class);
+ Connection connection = mock(Connection.class, RETURNS_DEEP_STUBS);
+ PreparedStatement preparedStatement = mock(PreparedStatement.class,
RETURNS_DEEP_STUBS);
+ when(preparedStatement.getMetaData()).thenReturn(null);
+ when(connection.prepareStatement(sql)).thenReturn(preparedStatement);
+ when(databaseConnectionManager.getConnections(any(),
nullable(String.class), anyInt(), anyInt(),
any(ConnectionMode.class))).thenReturn(Collections.singletonList(connection));
+
when(connectionSession.getDatabaseConnectionManager()).thenReturn(databaseConnectionManager);
+ }
+
+ private void prepareJDBCBackendConnectionWithPreparedStatement(final
String sql) throws SQLException {
+ ProxyDatabaseConnectionManager databaseConnectionManager =
mock(ProxyDatabaseConnectionManager.class);
+ Connection connection = mock(Connection.class, RETURNS_DEEP_STUBS);
+ PreparedStatement preparedStatement = mock(PreparedStatement.class);
+ when(connection.prepareStatement(sql)).thenReturn(preparedStatement);
+ when(databaseConnectionManager.getConnections(any(),
nullable(String.class), anyInt(), anyInt(),
any(ConnectionMode.class))).thenReturn(Collections.singletonList(connection));
+
when(connectionSession.getDatabaseConnectionManager()).thenReturn(databaseConnectionManager);
+ }
+
+ private void prepareJDBCBackendConnectionWithParamTypes(final String sql,
final int[] paramTypes, final String[] paramTypeNames) throws SQLException {
+ ParameterMetaData parameterMetaData = mock(ParameterMetaData.class);
+ for (int i = 0; i < paramTypes.length; i++) {
+ int index = i + 1;
+
when(parameterMetaData.getParameterType(index)).thenReturn(paramTypes[i]);
+
when(parameterMetaData.getParameterTypeName(index)).thenReturn(paramTypeNames[i]);
+ }
+ PreparedStatement preparedStatement = mock(PreparedStatement.class,
RETURNS_DEEP_STUBS);
+
when(preparedStatement.getParameterMetaData()).thenReturn(parameterMetaData);
+ ResultSetMetaData resultSetMetaData =
prepareResultSetMetaDataForSingleColumn();
+ when(preparedStatement.getMetaData()).thenReturn(resultSetMetaData);
+ Connection connection = mock(Connection.class, RETURNS_DEEP_STUBS);
+ when(connection.prepareStatement(sql)).thenReturn(preparedStatement);
+ ProxyDatabaseConnectionManager databaseConnectionManager =
mock(ProxyDatabaseConnectionManager.class);
+ when(databaseConnectionManager.getConnections(any(),
nullable(String.class), anyInt(), anyInt(),
any(ConnectionMode.class))).thenReturn(Collections.singletonList(connection));
+
when(connectionSession.getDatabaseConnectionManager()).thenReturn(databaseConnectionManager);
+ }
+
private ResultSetMetaData prepareResultSetMetaData() throws SQLException {
ResultSetMetaData result = mock(ResultSetMetaData.class);
when(result.getColumnCount()).thenReturn(4);
@@ -429,6 +621,16 @@ class PostgreSQLComDescribeExecutorTest {
return result;
}
+ private ResultSetMetaData prepareResultSetMetaDataForSingleColumn() throws
SQLException {
+ ResultSetMetaData result = mock(ResultSetMetaData.class);
+ when(result.getColumnCount()).thenReturn(1);
+ when(result.getColumnName(1)).thenReturn("id");
+ when(result.getColumnType(1)).thenReturn(Types.INTEGER);
+ when(result.getColumnDisplaySize(1)).thenReturn(11);
+ when(result.getColumnTypeName(1)).thenReturn("int4");
+ return result;
+ }
+
@SuppressWarnings("unchecked")
@SneakyThrows(ReflectiveOperationException.class)
private List<PostgreSQLColumnDescription> getColumnDescriptions(final
PostgreSQLRowDescriptionPacket packet) {
@@ -439,4 +641,53 @@ class PostgreSQLComDescribeExecutorTest {
void assertDescribeUnknownType() {
    // Executing a describe command built from the shared portalContext, packet and connectionSession fields must
    // raise UnsupportedSQLOperationException.
    // NOTE(review): presumably `packet` is stubbed (outside this view) to report an unsupported describe type — confirm in the test setup.
    assertThrows(UnsupportedSQLOperationException.class, () -> new
            PostgreSQLComDescribeExecutor(portalContext, packet,
            connectionSession).execute());
}
+
+ private static Stream<Arguments> provideInsertMetaDataCases() {
+ return Stream.of(
+ Arguments.of("insert without columns", "S_meta_1", "INSERT
INTO t_order VALUES (?, 0, 'char', ?), (2, ?, ?, '')", 4, 2, 2),
+ Arguments.of("insert with columns", "S_meta_2", "INSERT INTO
t_order (id, k, c, pad) VALUES (1, ?, ?, ?), (?, 2, ?, '')", 5, 2, 3),
+ Arguments.of("insert with case-insensitive columns",
"S_meta_3", "INSERT INTO t_order (iD, k, c, PaD) VALUES (1, ?, ?, ?), (?, 2, ?,
'')", 5, 2, 3));
+ }
+
+ private static Stream<Arguments> provideReturningCases() {
+ return Stream.of(
+ Arguments.of("returning complex columns",
"S_returning_complex",
+ "INSERT INTO t_order (k, c, pad) VALUES (?, ?, ?)
RETURNING"
+ + " id, id alias_id, 'anonymous', 'OK'
literal_string, 1 literal_int, 4294967296 literal_bigint, 1.1 literal_numeric,
t_order.*, t_order, t_order alias_t_order",
+ Arrays.asList(PostgreSQLColumnType.INT4,
PostgreSQLColumnType.CHAR, PostgreSQLColumnType.CHAR),
+ getExpectedReturningColumns()),
+ Arguments.of("returning numeric literal",
"S_numeric_returning",
+ "INSERT INTO t_order (k) VALUES (?) RETURNING 1.2
numeric_value",
+ Collections.singletonList(PostgreSQLColumnType.INT4),
+
Collections.singletonList(expectedColumn("numeric_value", Types.NUMERIC, -1,
"numeric"))),
+ Arguments.of("returning boolean literal",
"S_boolean_returning",
+ "INSERT INTO t_order (k) VALUES (?) RETURNING true
bool_value",
+ Collections.singletonList(PostgreSQLColumnType.INT4),
+ Collections.singletonList(expectedColumn("bool_value",
Types.VARCHAR, -1, "varchar"))),
+ Arguments.of("returning without parameters",
"S_returning_only",
+ "INSERT INTO t_order VALUES (1) RETURNING id",
+ Collections.emptyList(),
+ Collections.singletonList(expectedColumn("id",
Types.INTEGER, 4, "int4"))));
+ }
+
+ private static PostgreSQLColumnDescription expectedColumn(final String
columnName, final int jdbcType, final int columnLength, final String
columnTypeName) {
+ return new PostgreSQLColumnDescription(columnName, 0, jdbcType,
columnLength, columnTypeName);
+ }
+
+ private static List<PostgreSQLColumnDescription>
getExpectedReturningColumns() {
+ return Arrays.asList(
+ expectedColumn("id", Types.INTEGER, 4, "int4"),
+ expectedColumn("alias_id", Types.INTEGER, 4, "int4"),
+ expectedColumn("?column?", Types.VARCHAR, -1, "varchar"),
+ expectedColumn("literal_string", Types.VARCHAR, -1, "varchar"),
+ expectedColumn("literal_int", Types.INTEGER, 4, "int4"),
+ expectedColumn("literal_bigint", Types.BIGINT, 8, "int8"),
+ expectedColumn("literal_numeric", Types.NUMERIC, -1,
"numeric"),
+ expectedColumn("id", Types.INTEGER, 4, "int4"),
+ expectedColumn("k", Types.INTEGER, 4, "int4"),
+ expectedColumn("c", Types.CHAR, -1, "char"),
+ expectedColumn("pad", Types.CHAR, -1, "char"),
+ expectedColumn("t_order", Types.VARCHAR, -1, "varchar"),
+ expectedColumn("alias_t_order", Types.VARCHAR, -1, "varchar"));
+ }
}