This is an automated email from the ASF dual-hosted git repository.

zhangliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 2c6652b0561 Add more test cases on PostgreSQLComDescribeExecutorTest (#37925)
2c6652b0561 is described below

commit 2c6652b056194d03dfa3d74e0bde601d6a96bdbe
Author: Liang Zhang <[email protected]>
AuthorDate: Sun Feb 1 19:16:58 2026 +0800

    Add more test cases on PostgreSQLComDescribeExecutorTest (#37925)
    
    * Add more test cases on PostgreSQLComDescribeExecutorTest
    
    * Add more test cases on PostgreSQLComDescribeExecutorTest
    
    * Add more test cases on PostgreSQLComDescribeExecutorTest
    
    * Add more test cases on PostgreSQLComDescribeExecutorTest
    
    * Add more test cases on PostgreSQLComDescribeExecutorTest
    
    * Add more test cases on PostgreSQLComDescribeExecutorTest
---
 .../PostgreSQLComDescribeExecutorTest.java         | 443 ++++++++++++++++-----
 1 file changed, 347 insertions(+), 96 deletions(-)

diff --git a/proxy/frontend/dialect/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java b/proxy/frontend/dialect/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java
index 71d5b317f3c..8ebb9d9d2d1 100644
--- a/proxy/frontend/dialect/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java
+++ b/proxy/frontend/dialect/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java
@@ -61,6 +61,9 @@ import 
org.apache.shardingsphere.test.infra.framework.extension.mock.AutoMockExt
 import 
org.apache.shardingsphere.test.infra.framework.extension.mock.StaticMockSettings;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
 import org.mockito.InjectMocks;
 import org.mockito.Mock;
 import org.mockito.internal.configuration.plugins.Plugins;
@@ -69,6 +72,7 @@ import org.mockito.quality.Strictness;
 
 import java.sql.Connection;
 import java.sql.ParameterMetaData;
+import java.sql.PreparedStatement;
 import java.sql.ResultSetMetaData;
 import java.sql.SQLException;
 import java.sql.Types;
@@ -78,10 +82,12 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Map;
 import java.util.Optional;
 import java.util.Properties;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
+import java.util.stream.Stream;
 
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
@@ -135,11 +141,35 @@ class PostgreSQLComDescribeExecutorTest {
     }
     
     @Test
-    void assertDescribePreparedStatementInsertWithoutColumns() throws 
SQLException {
+    void assertDescribePreparedStatementWithExistingRowDescription() throws 
SQLException {
+        when(packet.getType()).thenReturn('S');
+        String statementId = "S_exist";
+        when(packet.getName()).thenReturn(statementId);
+        String sql = "SELECT 1";
+        SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
+        SQLStatementContext sqlStatementContext = 
mock(SelectStatementContext.class);
+        when(sqlStatementContext.getSqlStatement()).thenReturn(sqlStatement);
+        ServerPreparedStatementRegistry serverPreparedStatementRegistry = new 
ServerPreparedStatementRegistry();
+        
when(connectionSession.getServerPreparedStatementRegistry()).thenReturn(serverPreparedStatementRegistry);
+        PostgreSQLServerPreparedStatement preparedStatement = new 
PostgreSQLServerPreparedStatement(
+                sql, sqlStatementContext, new HintValueContext(), new 
ArrayList<>(), Collections.emptyList());
+        
preparedStatement.setRowDescription(PostgreSQLNoDataPacket.getInstance());
+        serverPreparedStatementRegistry.addPreparedStatement(statementId, 
preparedStatement);
+        Collection<DatabasePacket> actual = executor.execute();
+        Iterator<DatabasePacket> actualIterator = actual.iterator();
+        PostgreSQLParameterDescriptionPacket parameterDescription = 
(PostgreSQLParameterDescriptionPacket) actualIterator.next();
+        PostgreSQLPacketPayload mockPayload = 
mock(PostgreSQLPacketPayload.class);
+        parameterDescription.write(mockPayload);
+        verify(mockPayload).writeInt2(0);
+        assertThat(actualIterator.next(), 
is(PostgreSQLNoDataPacket.getInstance()));
+    }
+    
+    @ParameterizedTest(name = "{0}")
+    @MethodSource("provideInsertMetaDataCases")
+    void assertDescribePreparedStatementInsertByMetaData(final String 
testName, final String statementId, final String sql,
+                                                         final int 
expectedParamCount, final int expectedInt4Count, final int expectedCharCount) 
throws SQLException {
         when(packet.getType()).thenReturn('S');
-        final String statementId = "S_1";
         when(packet.getName()).thenReturn(statementId);
-        String sql = "INSERT INTO t_order VALUES (?, 0, 'char', ?), (2, ?, ?, 
'')";
         SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
         List<PostgreSQLColumnType> parameterTypes = new 
ArrayList<>(sqlStatement.getParameterCount());
         for (int i = 0; i < sqlStatement.getParameterCount(); i++) {
@@ -153,45 +183,76 @@ class PostgreSQLComDescribeExecutorTest {
         
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(statementId,
 new PostgreSQLServerPreparedStatement(sql, sqlStatementContext, new 
HintValueContext(), parameterTypes,
                 parameterIndexes));
         Collection<DatabasePacket> actualPackets = executor.execute();
-        assertThat(actualPackets.size(), is(2));
         Iterator<DatabasePacket> actualPacketsIterator = 
actualPackets.iterator();
         PostgreSQLParameterDescriptionPacket actualParameterDescription = 
(PostgreSQLParameterDescriptionPacket) actualPacketsIterator.next();
         PostgreSQLPacketPayload mockPayload = 
mock(PostgreSQLPacketPayload.class);
         actualParameterDescription.write(mockPayload);
-        verify(mockPayload).writeInt2(4);
-        verify(mockPayload, times(2)).writeInt4(23);
-        verify(mockPayload, times(2)).writeInt4(18);
+        verify(mockPayload).writeInt2(expectedParamCount);
+        verify(mockPayload, 
times(expectedInt4Count)).writeInt4(PostgreSQLColumnType.INT4.getValue());
+        verify(mockPayload, 
times(expectedCharCount)).writeInt4(PostgreSQLColumnType.CHAR.getValue());
         assertThat(actualPacketsIterator.next(), 
is(PostgreSQLNoDataPacket.getInstance()));
     }
     
     @Test
-    void assertDescribePreparedStatementInsertWithColumns() throws 
SQLException {
+    void assertDescribePreparedStatementInsertWithoutParameters() throws 
SQLException {
         when(packet.getType()).thenReturn('S');
-        final String statementId = "S_2";
+        String statementId = "S_early_return";
         when(packet.getName()).thenReturn(statementId);
-        String sql = "INSERT INTO t_order (id, k, c, pad) VALUES (1, ?, ?, ?), 
(?, 2, ?, '')";
+        String sql = "INSERT INTO t_order VALUES (1)";
         SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
-        List<PostgreSQLColumnType> parameterTypes = new 
ArrayList<>(sqlStatement.getParameterCount());
-        for (int i = 0; i < sqlStatement.getParameterCount(); i++) {
-            parameterTypes.add(PostgreSQLColumnType.UNSPECIFIED);
-        }
         SQLStatementContext sqlStatementContext = 
mock(InsertStatementContext.class);
         when(sqlStatementContext.getSqlStatement()).thenReturn(sqlStatement);
-        ContextManager contextManager = mockContextManager();
+        ServerPreparedStatementRegistry serverPreparedStatementRegistry = new 
ServerPreparedStatementRegistry();
+        
when(connectionSession.getServerPreparedStatementRegistry()).thenReturn(serverPreparedStatementRegistry);
+        serverPreparedStatementRegistry.addPreparedStatement(statementId,
+                new PostgreSQLServerPreparedStatement(sql, 
sqlStatementContext, new HintValueContext(), new ArrayList<>(), 
Collections.emptyList()));
+        Collection<DatabasePacket> actualPackets = executor.execute();
+        Iterator<DatabasePacket> actualIterator = actualPackets.iterator();
+        PostgreSQLParameterDescriptionPacket parameterDescription = 
(PostgreSQLParameterDescriptionPacket) actualIterator.next();
+        PostgreSQLPacketPayload mockPayload = 
mock(PostgreSQLPacketPayload.class);
+        parameterDescription.write(mockPayload);
+        verify(mockPayload).writeInt2(0);
+        assertThat(actualIterator.hasNext(), is(false));
+    }
+    
+    @Test
+    void 
assertDescribePreparedStatementInsertWithSchemaAndMixedParameterTypes() throws 
SQLException {
+        when(packet.getType()).thenReturn('S');
+        String statementId = "S_schema";
+        when(packet.getName()).thenReturn(statementId);
+        String sql = "INSERT INTO public.t_small (col1, col2) VALUES (?, ?) 
RETURNING *, col1 + col2 expr_sum";
+        SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
+        List<PostgreSQLColumnType> parameterTypes = new 
ArrayList<>(Arrays.asList(PostgreSQLColumnType.INT4, 
PostgreSQLColumnType.UNSPECIFIED));
+        SQLStatementContext sqlStatementContext = 
mock(InsertStatementContext.class);
+        when(sqlStatementContext.getSqlStatement()).thenReturn(sqlStatement);
+        ShardingSphereTable table = new ShardingSphereTable("t_small",
+                Arrays.asList(
+                        new ShardingSphereColumn("col1", Types.INTEGER, true, 
false, false, true, false, false),
+                        new ShardingSphereColumn("col2", Types.SMALLINT, true, 
false, false, true, false, false)),
+                Collections.emptyList(), Collections.emptyList());
+        ContextManager contextManager = mockContextManager(table);
         
when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
         List<Integer> parameterIndexes = IntStream.range(0, 
sqlStatement.getParameterCount()).boxed().collect(Collectors.toList());
-        
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(statementId,
 new PostgreSQLServerPreparedStatement(sql, sqlStatementContext, new 
HintValueContext(), parameterTypes,
-                parameterIndexes));
+        
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(
+                statementId, new PostgreSQLServerPreparedStatement(sql, 
sqlStatementContext, new HintValueContext(), parameterTypes, parameterIndexes));
         Collection<DatabasePacket> actualPackets = executor.execute();
-        assertThat(actualPackets.size(), is(2));
-        Iterator<DatabasePacket> actualPacketsIterator = 
actualPackets.iterator();
-        PostgreSQLParameterDescriptionPacket actualParameterDescription = 
(PostgreSQLParameterDescriptionPacket) actualPacketsIterator.next();
+        Iterator<DatabasePacket> actualIterator = actualPackets.iterator();
+        PostgreSQLParameterDescriptionPacket parameterDescription = 
(PostgreSQLParameterDescriptionPacket) actualIterator.next();
         PostgreSQLPacketPayload mockPayload = 
mock(PostgreSQLPacketPayload.class);
-        actualParameterDescription.write(mockPayload);
-        verify(mockPayload).writeInt2(5);
-        verify(mockPayload, times(2)).writeInt4(23);
-        verify(mockPayload, times(3)).writeInt4(18);
-        assertThat(actualPacketsIterator.next(), 
is(PostgreSQLNoDataPacket.getInstance()));
+        parameterDescription.write(mockPayload);
+        verify(mockPayload).writeInt2(2);
+        verify(mockPayload).writeInt4(PostgreSQLColumnType.INT4.getValue());
+        verify(mockPayload).writeInt4(PostgreSQLColumnType.INT2.getValue());
+        PostgreSQLRowDescriptionPacket rowDescriptionPacket = 
(PostgreSQLRowDescriptionPacket) actualIterator.next();
+        List<PostgreSQLColumnDescription> columnDescriptions = 
getColumnDescriptions(rowDescriptionPacket);
+        assertThat(columnDescriptions.size(), is(3));
+        assertThat(columnDescriptions.get(0).getColumnName(), is("col1"));
+        assertThat(columnDescriptions.get(0).getColumnLength(), is(4));
+        assertThat(columnDescriptions.get(1).getColumnName(), is("col2"));
+        assertThat(columnDescriptions.get(1).getColumnLength(), is(2));
+        assertThat(columnDescriptions.get(2).getColumnName(), is("expr_sum"));
+        assertThat(columnDescriptions.get(2).getTypeOID(), 
is(PostgreSQLColumnType.VARCHAR.getValue()));
+        assertThat(columnDescriptions.get(2).getColumnLength(), is(-1));
     }
     
     @Test
@@ -246,12 +307,35 @@ class PostgreSQLComDescribeExecutorTest {
     }
     
     @Test
-    void assertDescribePreparedStatementInsertWithReturningClause() throws 
SQLException {
+    void 
assertDescribePreparedStatementInsertWithUnspecifiedTypesAndNoMarkers() throws 
SQLException {
+        when(packet.getType()).thenReturn('S');
+        String statementId = "S_mismatch";
+        when(packet.getName()).thenReturn(statementId);
+        String sql = "INSERT INTO t_order VALUES (1)";
+        SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
+        SQLStatementContext sqlStatementContext = 
mock(InsertStatementContext.class);
+        when(sqlStatementContext.getSqlStatement()).thenReturn(sqlStatement);
+        ContextManager contextManager = mockContextManager();
+        
when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
+        List<PostgreSQLColumnType> parameterTypes = new 
ArrayList<>(Collections.singletonList(PostgreSQLColumnType.UNSPECIFIED));
+        
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(
+                statementId, new PostgreSQLServerPreparedStatement(sql, 
sqlStatementContext, new HintValueContext(), parameterTypes, 
Collections.emptyList()));
+        Collection<DatabasePacket> actualPackets = executor.execute();
+        Iterator<DatabasePacket> actualIterator = actualPackets.iterator();
+        PostgreSQLParameterDescriptionPacket parameterDescription = 
(PostgreSQLParameterDescriptionPacket) actualIterator.next();
+        PostgreSQLPacketPayload mockPayload = 
mock(PostgreSQLPacketPayload.class);
+        parameterDescription.write(mockPayload);
+        verify(mockPayload).writeInt2(1);
+        
verify(mockPayload).writeInt4(PostgreSQLColumnType.UNSPECIFIED.getValue());
+        assertThat(actualIterator.next(), 
is(PostgreSQLNoDataPacket.getInstance()));
+    }
+    
+    @ParameterizedTest(name = "{0}")
+    @MethodSource("provideReturningCases")
+    void assertDescribePreparedStatementInsertWithReturning(final String 
testName, final String statementId, final String sql,
+                                                            final 
List<PostgreSQLColumnType> expectedParamTypes, final 
List<PostgreSQLColumnDescription> expectedColumns) throws SQLException {
         when(packet.getType()).thenReturn('S');
-        final String statementId = "S_2";
         when(packet.getName()).thenReturn(statementId);
-        String sql = "INSERT INTO t_order (k, c, pad) VALUES (?, ?, ?) "
-                + "RETURNING id, id alias_id, 'anonymous', 'OK' 
literal_string, 1 literal_int, 4294967296 literal_bigint, 1.1 literal_numeric, 
t_order.*, t_order, t_order alias_t_order";
         SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
         List<PostgreSQLColumnType> parameterTypes = new 
ArrayList<>(sqlStatement.getParameterCount());
         for (int i = 0; i < sqlStatement.getParameterCount(); i++) {
@@ -262,70 +346,29 @@ class PostgreSQLComDescribeExecutorTest {
         ContextManager contextManager = mockContextManager();
         
when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
         List<Integer> parameterIndexes = IntStream.range(0, 
sqlStatement.getParameterCount()).boxed().collect(Collectors.toList());
-        
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(statementId,
 new PostgreSQLServerPreparedStatement(sql, sqlStatementContext, new 
HintValueContext(), parameterTypes,
-                parameterIndexes));
+        
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(
+                statementId, new PostgreSQLServerPreparedStatement(sql, 
sqlStatementContext, new HintValueContext(), parameterTypes, parameterIndexes));
         Collection<DatabasePacket> actualPackets = executor.execute();
-        assertThat(actualPackets.size(), is(2));
-        Iterator<DatabasePacket> actualPacketsIterator = 
actualPackets.iterator();
-        PostgreSQLParameterDescriptionPacket actualParameterDescription = 
(PostgreSQLParameterDescriptionPacket) actualPacketsIterator.next();
+        Iterator<DatabasePacket> actualIterator = actualPackets.iterator();
+        PostgreSQLParameterDescriptionPacket parameterDescription = 
(PostgreSQLParameterDescriptionPacket) actualIterator.next();
         PostgreSQLPacketPayload mockPayload = 
mock(PostgreSQLPacketPayload.class);
-        actualParameterDescription.write(mockPayload);
-        verify(mockPayload).writeInt2(3);
-        verify(mockPayload).writeInt4(23);
-        verify(mockPayload, times(2)).writeInt4(18);
-        DatabasePacket actualRowDescriptionPacket = 
actualPacketsIterator.next();
-        assertThat(actualRowDescriptionPacket, 
is(isA(PostgreSQLRowDescriptionPacket.class)));
-        assertRowDescriptions((PostgreSQLRowDescriptionPacket) 
actualRowDescriptionPacket);
-    }
-    
-    private void assertRowDescriptions(final PostgreSQLRowDescriptionPacket 
actualRowDescriptionPacket) {
-        List<PostgreSQLColumnDescription> actualColumnDescriptions = new 
ArrayList<>(getColumnDescriptionsFromPacket(actualRowDescriptionPacket));
-        assertThat(actualColumnDescriptions.size(), is(13));
-        assertThat(actualColumnDescriptions.get(0).getColumnName(), is("id"));
-        assertThat(actualColumnDescriptions.get(0).getTypeOID(), 
is(PostgreSQLColumnType.INT4.getValue()));
-        assertThat(actualColumnDescriptions.get(0).getColumnLength(), is(4));
-        assertThat(actualColumnDescriptions.get(1).getColumnName(), 
is("alias_id"));
-        assertThat(actualColumnDescriptions.get(1).getTypeOID(), 
is(PostgreSQLColumnType.INT4.getValue()));
-        assertThat(actualColumnDescriptions.get(1).getColumnLength(), is(4));
-        assertThat(actualColumnDescriptions.get(2).getColumnName(), 
is("?column?"));
-        assertThat(actualColumnDescriptions.get(2).getTypeOID(), 
is(PostgreSQLColumnType.VARCHAR.getValue()));
-        assertThat(actualColumnDescriptions.get(2).getColumnLength(), is(-1));
-        assertThat(actualColumnDescriptions.get(3).getColumnName(), 
is("literal_string"));
-        assertThat(actualColumnDescriptions.get(3).getTypeOID(), 
is(PostgreSQLColumnType.VARCHAR.getValue()));
-        assertThat(actualColumnDescriptions.get(3).getColumnLength(), is(-1));
-        assertThat(actualColumnDescriptions.get(4).getColumnName(), 
is("literal_int"));
-        assertThat(actualColumnDescriptions.get(4).getTypeOID(), 
is(PostgreSQLColumnType.INT4.getValue()));
-        assertThat(actualColumnDescriptions.get(4).getColumnLength(), is(4));
-        assertThat(actualColumnDescriptions.get(5).getColumnName(), 
is("literal_bigint"));
-        assertThat(actualColumnDescriptions.get(5).getTypeOID(), 
is(PostgreSQLColumnType.INT8.getValue()));
-        assertThat(actualColumnDescriptions.get(5).getColumnLength(), is(8));
-        assertThat(actualColumnDescriptions.get(6).getColumnName(), 
is("literal_numeric"));
-        assertThat(actualColumnDescriptions.get(6).getTypeOID(), 
is(PostgreSQLColumnType.NUMERIC.getValue()));
-        assertThat(actualColumnDescriptions.get(6).getColumnLength(), is(-1));
-        assertThat(actualColumnDescriptions.get(7).getColumnName(), is("id"));
-        assertThat(actualColumnDescriptions.get(7).getTypeOID(), 
is(PostgreSQLColumnType.INT4.getValue()));
-        assertThat(actualColumnDescriptions.get(7).getColumnLength(), is(4));
-        assertThat(actualColumnDescriptions.get(8).getColumnName(), is("k"));
-        assertThat(actualColumnDescriptions.get(8).getTypeOID(), 
is(PostgreSQLColumnType.INT4.getValue()));
-        assertThat(actualColumnDescriptions.get(8).getColumnLength(), is(4));
-        assertThat(actualColumnDescriptions.get(9).getColumnName(), is("c"));
-        assertThat(actualColumnDescriptions.get(9).getTypeOID(), 
is(PostgreSQLColumnType.CHAR.getValue()));
-        assertThat(actualColumnDescriptions.get(9).getColumnLength(), is(-1));
-        assertThat(actualColumnDescriptions.get(10).getColumnName(), 
is("pad"));
-        assertThat(actualColumnDescriptions.get(10).getTypeOID(), 
is(PostgreSQLColumnType.CHAR.getValue()));
-        assertThat(actualColumnDescriptions.get(10).getColumnLength(), is(-1));
-        assertThat(actualColumnDescriptions.get(11).getColumnName(), 
is("t_order"));
-        assertThat(actualColumnDescriptions.get(11).getTypeOID(), 
is(PostgreSQLColumnType.VARCHAR.getValue()));
-        assertThat(actualColumnDescriptions.get(11).getColumnLength(), is(-1));
-        assertThat(actualColumnDescriptions.get(12).getColumnName(), 
is("alias_t_order"));
-        assertThat(actualColumnDescriptions.get(12).getTypeOID(), 
is(PostgreSQLColumnType.VARCHAR.getValue()));
-        assertThat(actualColumnDescriptions.get(12).getColumnLength(), is(-1));
-    }
-    
-    @SuppressWarnings("unchecked")
-    @SneakyThrows(ReflectiveOperationException.class)
-    private Collection<PostgreSQLColumnDescription> 
getColumnDescriptionsFromPacket(final PostgreSQLRowDescriptionPacket packet) {
-        return (Collection<PostgreSQLColumnDescription>) 
Plugins.getMemberAccessor().get(PostgreSQLRowDescriptionPacket.class.getDeclaredField("columnDescriptions"),
 packet);
+        parameterDescription.write(mockPayload);
+        verify(mockPayload).writeInt2(expectedParamTypes.size());
+        Map<Integer, Long> expectedTypeCounts = expectedParamTypes.stream()
+                .collect(Collectors.groupingBy(PostgreSQLColumnType::getValue, 
Collectors.counting()));
+        for (Map.Entry<Integer, Long> entry : expectedTypeCounts.entrySet()) {
+            verify(mockPayload, 
times(entry.getValue().intValue())).writeInt4(entry.getKey());
+        }
+        PostgreSQLRowDescriptionPacket rowDescriptionPacket = 
(PostgreSQLRowDescriptionPacket) actualIterator.next();
+        List<PostgreSQLColumnDescription> actualColumnDescriptions = 
getColumnDescriptions(rowDescriptionPacket);
+        assertThat(actualColumnDescriptions.size(), 
is(expectedColumns.size()));
+        for (int i = 0; i < expectedColumns.size(); i++) {
+            PostgreSQLColumnDescription expectedColumn = 
expectedColumns.get(i);
+            PostgreSQLColumnDescription actualColumn = 
actualColumnDescriptions.get(i);
+            assertThat(actualColumn.getColumnName(), 
is(expectedColumn.getColumnName()));
+            assertThat(actualColumn.getTypeOID(), 
is(expectedColumn.getTypeOID()));
+            assertThat(actualColumn.getColumnLength(), 
is(expectedColumn.getColumnLength()));
+        }
     }
     
     @Test
@@ -342,7 +385,8 @@ class PostgreSQLComDescribeExecutorTest {
         ContextManager contextManager = mockContextManager();
         
when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
         List<Integer> parameterIndexes = IntStream.range(0, 
sqlStatement.getParameterCount()).boxed().collect(Collectors.toList());
-        ConnectionContext connectionContext = mockConnectionContext();
+        ConnectionContext connectionContext = mock(ConnectionContext.class);
+        
when(connectionContext.getCurrentDatabaseName()).thenReturn(Optional.of(DATABASE_NAME));
         
when(connectionSession.getConnectionContext()).thenReturn(connectionContext);
         
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(
                 statementId, new PostgreSQLServerPreparedStatement(sql, 
sqlStatementContext, new HintValueContext(), parameterTypes, parameterIndexes));
@@ -372,10 +416,98 @@ class PostgreSQLComDescribeExecutorTest {
         }
     }
     
-    private ConnectionContext mockConnectionContext() {
-        ConnectionContext result = mock(ConnectionContext.class);
-        
when(result.getCurrentDatabaseName()).thenReturn(Optional.of(DATABASE_NAME));
-        return result;
+    @Test
+    void assertDescribeSelectPreparedStatementWithNullMetaData() throws 
SQLException {
+        when(packet.getType()).thenReturn('S');
+        String statementId = "S_null_metadata";
+        when(packet.getName()).thenReturn(statementId);
+        String sql = "SELECT id, k FROM t_order";
+        SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
+        SQLStatementContext sqlStatementContext = 
mock(SelectStatementContext.class);
+        when(sqlStatementContext.getSqlStatement()).thenReturn(sqlStatement);
+        prepareJDBCBackendConnectionWithNullMetaData(sql);
+        ContextManager contextManager = mockContextManager();
+        
when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
+        List<PostgreSQLColumnType> parameterTypes = new ArrayList<>();
+        List<Integer> parameterIndexes = Collections.emptyList();
+        ConnectionContext connectionContext = mock(ConnectionContext.class);
+        
when(connectionContext.getCurrentDatabaseName()).thenReturn(Optional.of(DATABASE_NAME));
+        
when(connectionSession.getConnectionContext()).thenReturn(connectionContext);
+        
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(
+                statementId, new PostgreSQLServerPreparedStatement(sql, 
sqlStatementContext, new HintValueContext(), parameterTypes, parameterIndexes));
+        Collection<DatabasePacket> actual = executor.execute();
+        Iterator<DatabasePacket> actualIterator = actual.iterator();
+        PostgreSQLParameterDescriptionPacket parameterDescription = 
(PostgreSQLParameterDescriptionPacket) actualIterator.next();
+        PostgreSQLPacketPayload mockPayload = 
mock(PostgreSQLPacketPayload.class);
+        parameterDescription.write(mockPayload);
+        verify(mockPayload).writeInt2(0);
+        assertThat(actualIterator.next(), 
is(PostgreSQLNoDataPacket.getInstance()));
+    }
+    
+    @SuppressWarnings("unchecked")
+    @Test
+    void assertDescribeSelectPreparedStatementWithPresetRowDescription() 
throws SQLException {
+        when(packet.getType()).thenReturn('S');
+        String statementId = "S_pre_described";
+        when(packet.getName()).thenReturn(statementId);
+        String sql = "SELECT id FROM t_order WHERE id = ?";
+        SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
+        SQLStatementContext sqlStatementContext = 
mock(SelectStatementContext.class);
+        when(sqlStatementContext.getSqlStatement()).thenReturn(sqlStatement);
+        List<PostgreSQLColumnType> parameterTypes = new 
ArrayList<>(Collections.singleton(PostgreSQLColumnType.INT4));
+        List<Integer> parameterIndexes = IntStream.range(0, 
sqlStatement.getParameterCount()).boxed().collect(Collectors.toList());
+        PostgreSQLServerPreparedStatement preparedStatement = 
mock(PostgreSQLServerPreparedStatement.class);
+        when(preparedStatement.describeRows()).thenReturn(Optional.empty(), 
Optional.of(PostgreSQLNoDataPacket.getInstance()));
+        when(preparedStatement.describeParameters()).thenReturn(new 
PostgreSQLParameterDescriptionPacket(parameterTypes));
+        when(preparedStatement.getSql()).thenReturn(sql);
+        
when(preparedStatement.getSqlStatementContext()).thenReturn(sqlStatementContext);
+        when(preparedStatement.getHintValueContext()).thenReturn(new 
HintValueContext());
+        when(preparedStatement.getParameterTypes()).thenReturn(parameterTypes);
+        
when(preparedStatement.getActualParameterMarkerIndexes()).thenReturn(parameterIndexes);
+        ContextManager contextManager = mockContextManager();
+        
when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
+        ConnectionContext connectionContext = mock(ConnectionContext.class);
+        
when(connectionContext.getCurrentDatabaseName()).thenReturn(Optional.of(DATABASE_NAME));
+        
when(connectionSession.getConnectionContext()).thenReturn(connectionContext);
+        prepareJDBCBackendConnectionWithPreparedStatement(sql);
+        
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(statementId,
 preparedStatement);
+        Collection<DatabasePacket> actual = executor.execute();
+        Iterator<DatabasePacket> actualIterator = actual.iterator();
+        PostgreSQLParameterDescriptionPacket parameterDescription = 
(PostgreSQLParameterDescriptionPacket) actualIterator.next();
+        PostgreSQLPacketPayload mockPayload = 
mock(PostgreSQLPacketPayload.class);
+        parameterDescription.write(mockPayload);
+        verify(mockPayload).writeInt2(1);
+        verify(mockPayload).writeInt4(PostgreSQLColumnType.INT4.getValue());
+        assertThat(actualIterator.next(), 
is(PostgreSQLNoDataPacket.getInstance()));
+    }
+    
+    @Test
+    void assertPopulateParameterTypesWithMixedSpecifiedAndUnspecified() throws 
SQLException {
+        when(packet.getType()).thenReturn('S');
+        String statementId = "S_mixed_params";
+        when(packet.getName()).thenReturn(statementId);
+        String sql = "SELECT id FROM t_order WHERE id = ? AND k = ?";
+        SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
+        SQLStatementContext sqlStatementContext = 
mock(SelectStatementContext.class);
+        when(sqlStatementContext.getSqlStatement()).thenReturn(sqlStatement);
+        prepareJDBCBackendConnectionWithParamTypes(sql, new 
int[]{Types.INTEGER, Types.SMALLINT}, new String[]{"int4", "int2"});
+        List<PostgreSQLColumnType> parameterTypes = new 
ArrayList<>(Arrays.asList(PostgreSQLColumnType.INT4, 
PostgreSQLColumnType.UNSPECIFIED));
+        List<Integer> parameterIndexes = IntStream.range(0, 
sqlStatement.getParameterCount()).boxed().collect(Collectors.toList());
+        ContextManager contextManager = mockContextManager();
+        
when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
+        ConnectionContext connectionContext = mock(ConnectionContext.class);
+        
when(connectionContext.getCurrentDatabaseName()).thenReturn(Optional.of(DATABASE_NAME));
+        
when(connectionSession.getConnectionContext()).thenReturn(connectionContext);
+        
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(
+                statementId, new PostgreSQLServerPreparedStatement(sql, 
sqlStatementContext, new HintValueContext(), parameterTypes, parameterIndexes));
+        Collection<DatabasePacket> actual = executor.execute();
+        Iterator<DatabasePacket> actualIterator = actual.iterator();
+        PostgreSQLParameterDescriptionPacket parameterDescription = 
(PostgreSQLParameterDescriptionPacket) actualIterator.next();
+        PostgreSQLPacketPayload mockPayload = 
mock(PostgreSQLPacketPayload.class);
+        parameterDescription.write(mockPayload);
+        verify(mockPayload).writeInt2(2);
+        verify(mockPayload).writeInt4(PostgreSQLColumnType.INT4.getValue());
+        verify(mockPayload).writeInt4(PostgreSQLColumnType.INT2.getValue());
     }
     
     private ContextManager mockContextManager() {
@@ -406,6 +538,29 @@ class PostgreSQLComDescribeExecutorTest {
         return result;
     }
     
+    /**
+     * Build a deep-stubbed {@link ContextManager} whose {@code public} schema of {@code DATABASE_NAME} contains the given table.
+     *
+     * <p>Also stubs {@code connectionSession} so that it reports {@code DATABASE_NAME} as both used and current database
+     * and hands out a new {@link ServerPreparedStatementRegistry}.</p>
+     *
+     * @param table table the mocked schema reports and returns under the table's own name
+     * @return mocked context manager
+     */
+    private ContextManager mockContextManager(final ShardingSphereTable table) {
+        ContextManager result = mock(ContextManager.class, RETURNS_DEEP_STUBS);
+        when(result.getMetaDataContexts().getMetaData().getProps()).thenReturn(new ConfigurationProperties(new Properties()));
+        when(connectionSession.getUsedDatabaseName()).thenReturn(DATABASE_NAME);
+        when(connectionSession.getCurrentDatabaseName()).thenReturn(DATABASE_NAME);
+        when(connectionSession.getServerPreparedStatementRegistry()).thenReturn(new ServerPreparedStatementRegistry());
+        SQLTranslatorRule translatorRule = new SQLTranslatorRule(new DefaultSQLTranslatorRuleConfigurationBuilder().build());
+        when(result.getMetaDataContexts().getMetaData().getGlobalRuleMetaData()).thenReturn(new RuleMetaData(Collections.singleton(translatorRule)));
+        // Deep stubs return the same child mock for the same call chain, so the database is resolved once and stubbed through this reference.
+        ShardingSphereDatabase database = result.getMetaDataContexts().getMetaData().getDatabase(DATABASE_NAME);
+        when(database.getProtocolType()).thenReturn(DATABASE_TYPE);
+        StorageUnit unit = mock(StorageUnit.class, RETURNS_DEEP_STUBS);
+        when(unit.getStorageType()).thenReturn(DATABASE_TYPE);
+        when(database.getResourceMetaData().getStorageUnits()).thenReturn(Collections.singletonMap("ds_0", unit));
+        when(result.getMetaDataContexts().getMetaData().containsDatabase(DATABASE_NAME)).thenReturn(true);
+        ShardingSphereSchema publicSchema = mock(ShardingSphereSchema.class);
+        when(database.containsSchema("public")).thenReturn(true);
+        when(database.getSchema("public")).thenReturn(publicSchema);
+        String tableName = table.getName();
+        when(publicSchema.containsTable(tableName)).thenReturn(true);
+        when(publicSchema.getTable(tableName)).thenReturn(table);
+        when(result.getDatabase(DATABASE_NAME)).thenReturn(database);
+        return result;
+    }
+    
     @SuppressWarnings("JDBCResourceOpenedButNotSafelyClosed")
     private void prepareJDBCBackendConnection(final String sql) throws 
SQLException {
         ProxyDatabaseConnectionManager databaseConnectionManager = 
mock(ProxyDatabaseConnectionManager.class);
@@ -419,6 +574,43 @@ class PostgreSQLComDescribeExecutorTest {
         
when(connectionSession.getDatabaseConnectionManager()).thenReturn(databaseConnectionManager);
     }
     
+    /**
+     * Wire {@code connectionSession} to a mocked connection manager whose single connection prepares a statement
+     * that reports {@code null} result set meta data for the given SQL.
+     *
+     * @param sql SQL text the mocked connection accepts in {@code prepareStatement}
+     * @throws SQLException declared by the stubbed JDBC methods
+     */
+    private void prepareJDBCBackendConnectionWithNullMetaData(final String sql) throws SQLException {
+        PreparedStatement statementWithoutMetaData = mock(PreparedStatement.class, RETURNS_DEEP_STUBS);
+        // Override the deep-stub default so that no result set meta data is available.
+        when(statementWithoutMetaData.getMetaData()).thenReturn(null);
+        Connection jdbcConnection = mock(Connection.class, RETURNS_DEEP_STUBS);
+        when(jdbcConnection.prepareStatement(sql)).thenReturn(statementWithoutMetaData);
+        ProxyDatabaseConnectionManager connectionManager = mock(ProxyDatabaseConnectionManager.class);
+        when(connectionManager.getConnections(any(), nullable(String.class), anyInt(), anyInt(), any(ConnectionMode.class)))
+                .thenReturn(Collections.singletonList(jdbcConnection));
+        when(connectionSession.getDatabaseConnectionManager()).thenReturn(connectionManager);
+    }
+    
+    /**
+     * Wire {@code connectionSession} to a mocked connection manager whose single connection returns
+     * a plain (non deep-stubbed) mocked {@link PreparedStatement} for the given SQL.
+     *
+     * @param sql SQL text the mocked connection accepts in {@code prepareStatement}
+     * @throws SQLException declared by the stubbed JDBC methods
+     */
+    private void prepareJDBCBackendConnectionWithPreparedStatement(final String sql) throws SQLException {
+        PreparedStatement plainStatement = mock(PreparedStatement.class);
+        Connection jdbcConnection = mock(Connection.class, RETURNS_DEEP_STUBS);
+        when(jdbcConnection.prepareStatement(sql)).thenReturn(plainStatement);
+        ProxyDatabaseConnectionManager connectionManager = mock(ProxyDatabaseConnectionManager.class);
+        when(connectionManager.getConnections(any(), nullable(String.class), anyInt(), anyInt(), any(ConnectionMode.class)))
+                .thenReturn(Collections.singletonList(jdbcConnection));
+        when(connectionSession.getDatabaseConnectionManager()).thenReturn(connectionManager);
+    }
+    
+    /**
+     * Wire {@code connectionSession} to a mocked connection manager whose prepared statement exposes
+     * the given parameter types and a single-column result set meta data.
+     *
+     * @param sql SQL text the mocked connection accepts in {@code prepareStatement}
+     * @param paramTypes JDBC type codes, element {@code i} is reported for parameter {@code i + 1}
+     * @param paramTypeNames type names, element {@code i} is reported for parameter {@code i + 1}
+     * @throws SQLException declared by the stubbed JDBC methods
+     */
+    private void prepareJDBCBackendConnectionWithParamTypes(final String sql, final int[] paramTypes, final String[] paramTypeNames) throws SQLException {
+        ParameterMetaData parameterMetaData = mock(ParameterMetaData.class);
+        // JDBC parameter positions are 1-based while the arrays are 0-based.
+        for (int position = 1; position <= paramTypes.length; position++) {
+            when(parameterMetaData.getParameterType(position)).thenReturn(paramTypes[position - 1]);
+            when(parameterMetaData.getParameterTypeName(position)).thenReturn(paramTypeNames[position - 1]);
+        }
+        ResultSetMetaData singleColumnMetaData = prepareResultSetMetaDataForSingleColumn();
+        PreparedStatement describedStatement = mock(PreparedStatement.class, RETURNS_DEEP_STUBS);
+        when(describedStatement.getParameterMetaData()).thenReturn(parameterMetaData);
+        when(describedStatement.getMetaData()).thenReturn(singleColumnMetaData);
+        Connection jdbcConnection = mock(Connection.class, RETURNS_DEEP_STUBS);
+        when(jdbcConnection.prepareStatement(sql)).thenReturn(describedStatement);
+        ProxyDatabaseConnectionManager connectionManager = mock(ProxyDatabaseConnectionManager.class);
+        when(connectionManager.getConnections(any(), nullable(String.class), anyInt(), anyInt(), any(ConnectionMode.class)))
+                .thenReturn(Collections.singletonList(jdbcConnection));
+        when(connectionSession.getDatabaseConnectionManager()).thenReturn(connectionManager);
+    }
+    
     private ResultSetMetaData prepareResultSetMetaData() throws SQLException {
         ResultSetMetaData result = mock(ResultSetMetaData.class);
         when(result.getColumnCount()).thenReturn(4);
@@ -429,6 +621,16 @@ class PostgreSQLComDescribeExecutorTest {
         return result;
     }
     
+    private ResultSetMetaData prepareResultSetMetaDataForSingleColumn() throws 
SQLException {
+        ResultSetMetaData result = mock(ResultSetMetaData.class);
+        when(result.getColumnCount()).thenReturn(1);
+        when(result.getColumnName(1)).thenReturn("id");
+        when(result.getColumnType(1)).thenReturn(Types.INTEGER);
+        when(result.getColumnDisplaySize(1)).thenReturn(11);
+        when(result.getColumnTypeName(1)).thenReturn("int4");
+        return result;
+    }
+    
     @SuppressWarnings("unchecked")
     @SneakyThrows(ReflectiveOperationException.class)
     private List<PostgreSQLColumnDescription> getColumnDescriptions(final 
PostgreSQLRowDescriptionPacket packet) {
@@ -439,4 +641,53 @@ class PostgreSQLComDescribeExecutorTest {
     void assertDescribeUnknownType() {
         // Constructs a new executor from the shared portalContext, packet and connectionSession fixtures inside the lambda,
         // and asserts that calling execute() on it throws UnsupportedSQLOperationException.
         assertThrows(UnsupportedSQLOperationException.class, () -> new 
PostgreSQLComDescribeExecutor(portalContext, packet, 
connectionSession).execute());
     }
+    
+    /**
+     * Supply argument sets for INSERT meta data test cases.
+     *
+     * <p>Each set holds a case description, a statement identifier string, the SQL text and three integers;
+     * the meaning of the integers is defined by the consuming test, which is not part of this method.</p>
+     *
+     * @return stream of three argument sets
+     */
+    private static Stream<Arguments> provideInsertMetaDataCases() {
+        Arguments withoutColumns = Arguments.of("insert without columns", "S_meta_1", "INSERT INTO t_order VALUES (?, 0, 'char', ?), (2, ?, ?, '')", 4, 2, 2);
+        Arguments withColumns = Arguments.of("insert with columns", "S_meta_2", "INSERT INTO t_order (id, k, c, pad) VALUES (1, ?, ?, ?), (?, 2, ?, '')", 5, 2, 3);
+        Arguments withCaseInsensitiveColumns = Arguments.of("insert with case-insensitive columns",
+                "S_meta_3", "INSERT INTO t_order (iD, k, c, PaD) VALUES (1, ?, ?, ?), (?, 2, ?, '')", 5, 2, 3);
+        return Stream.of(withoutColumns, withColumns, withCaseInsensitiveColumns);
+    }
+    
+    /**
+     * Supply argument sets for INSERT ... RETURNING test cases.
+     *
+     * <p>Each set holds a case description, a statement identifier string, the SQL text,
+     * a list of {@link PostgreSQLColumnType} values and a list of expected {@link PostgreSQLColumnDescription} values.</p>
+     *
+     * @return stream of four argument sets
+     */
+    private static Stream<Arguments> provideReturningCases() {
+        Arguments complexColumns = Arguments.of("returning complex columns", "S_returning_complex",
+                "INSERT INTO t_order (k, c, pad) VALUES (?, ?, ?) RETURNING"
+                        + " id, id alias_id, 'anonymous', 'OK' literal_string, 1 literal_int, 4294967296 literal_bigint, 1.1 literal_numeric, t_order.*, t_order, t_order alias_t_order",
+                Arrays.asList(PostgreSQLColumnType.INT4, PostgreSQLColumnType.CHAR, PostgreSQLColumnType.CHAR),
+                getExpectedReturningColumns());
+        Arguments numericLiteral = Arguments.of("returning numeric literal", "S_numeric_returning",
+                "INSERT INTO t_order (k) VALUES (?) RETURNING 1.2 numeric_value",
+                Collections.singletonList(PostgreSQLColumnType.INT4),
+                Collections.singletonList(expectedColumn("numeric_value", Types.NUMERIC, -1, "numeric")));
+        Arguments booleanLiteral = Arguments.of("returning boolean literal", "S_boolean_returning",
+                "INSERT INTO t_order (k) VALUES (?) RETURNING true bool_value",
+                Collections.singletonList(PostgreSQLColumnType.INT4),
+                Collections.singletonList(expectedColumn("bool_value", Types.VARCHAR, -1, "varchar")));
+        Arguments withoutParameters = Arguments.of("returning without parameters", "S_returning_only",
+                "INSERT INTO t_order VALUES (1) RETURNING id",
+                Collections.emptyList(),
+                Collections.singletonList(expectedColumn("id", Types.INTEGER, 4, "int4")));
+        return Stream.of(complexColumns, numericLiteral, booleanLiteral, withoutParameters);
+    }
+    
+    /**
+     * Create an expected column description, passing {@code 0} as the second constructor argument.
+     *
+     * @param columnName column name
+     * @param jdbcType JDBC type code from {@link Types}
+     * @param columnLength column length
+     * @param columnTypeName column type name
+     * @return new column description
+     */
+    private static PostgreSQLColumnDescription expectedColumn(final String columnName, final int jdbcType, final int columnLength, final String columnTypeName) {
+        PostgreSQLColumnDescription result = new PostgreSQLColumnDescription(columnName, 0, jdbcType, columnLength, columnTypeName);
+        return result;
+    }
+    
+    /**
+     * Create the thirteen expected column descriptions, in order, used by the "returning complex columns" case.
+     *
+     * @return fixed-size list of expected column descriptions
+     */
+    private static List<PostgreSQLColumnDescription> getExpectedReturningColumns() {
+        PostgreSQLColumnDescription[] expectedColumns = {
+                expectedColumn("id", Types.INTEGER, 4, "int4"),
+                expectedColumn("alias_id", Types.INTEGER, 4, "int4"),
+                expectedColumn("?column?", Types.VARCHAR, -1, "varchar"),
+                expectedColumn("literal_string", Types.VARCHAR, -1, "varchar"),
+                expectedColumn("literal_int", Types.INTEGER, 4, "int4"),
+                expectedColumn("literal_bigint", Types.BIGINT, 8, "int8"),
+                expectedColumn("literal_numeric", Types.NUMERIC, -1, "numeric"),
+                expectedColumn("id", Types.INTEGER, 4, "int4"),
+                expectedColumn("k", Types.INTEGER, 4, "int4"),
+                expectedColumn("c", Types.CHAR, -1, "char"),
+                expectedColumn("pad", Types.CHAR, -1, "char"),
+                expectedColumn("t_order", Types.VARCHAR, -1, "varchar"),
+                expectedColumn("alias_t_order", Types.VARCHAR, -1, "varchar"),
+        };
+        return Arrays.asList(expectedColumns);
+    }
 }


Reply via email to