This is an automated email from the ASF dual-hosted git repository.

zhangliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 1511687  Supports describe insert statement without column in PostgreSQL Proxy (#14808)
1511687 is described below

commit 15116870ad904fa6403fb71b5a9187686bf0e824
Author: 吴伟杰 <[email protected]>
AuthorDate: Sun Jan 16 23:13:28 2022 +0800

    Supports describe insert statement without column in PostgreSQL Proxy (#14808)
    
    * Supports describe insert statement without columns
    
    * Complete tests for PostgreSQLComDescribeExecutor
---
 .../describe/PostgreSQLComDescribeExecutor.java    |  16 +-
 .../PostgreSQLComDescribeExecutorTest.java         | 192 ++++++++++++++++++++-
 2 files changed, 196 insertions(+), 12 deletions(-)

diff --git 
a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutor.java
 
b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutor.java
index f149e64..447cf6d 100644
--- 
a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutor.java
+++ 
b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutor.java
@@ -44,7 +44,6 @@ import 
org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
 import 
org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
 import 
org.apache.shardingsphere.proxy.frontend.postgresql.command.PostgreSQLConnectionContext;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
-import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
@@ -64,6 +63,7 @@ import java.util.ListIterator;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.stream.Collectors;
 
 /**
  * Command describe for PostgreSQL.
@@ -128,7 +128,12 @@ public final class PostgreSQLComDescribeExecutor 
implements CommandExecutor {
         String logicTableName = 
insertStatement.getTable().getTableName().getIdentifier().getValue();
         TableMetaData tableMetaData = 
ProxyContext.getInstance().getMetaData(schemaName).getSchema().get(logicTableName);
         Map<String, ColumnMetaData> columnMetaData = 
tableMetaData.getColumns();
-        List<ColumnSegment> columns = new 
ArrayList<>(insertStatement.getColumns());
+        List<String> columnNames;
+        if (insertStatement.getColumns().isEmpty()) {
+            columnNames = new ArrayList<>(tableMetaData.getColumns().keySet());
+        } else {
+            columnNames = insertStatement.getColumns().stream().map(each -> 
each.getIdentifier().getValue()).collect(Collectors.toList());
+        }
         Iterator<InsertValuesSegment> iterator = 
insertStatement.getValues().iterator();
         int parameterMarkerIndex = 0;
         while (iterator.hasNext()) {
@@ -136,11 +141,14 @@ public final class PostgreSQLComDescribeExecutor 
implements CommandExecutor {
             ListIterator<ExpressionSegment> listIterator = 
each.getValues().listIterator();
             for (int columnIndex = listIterator.nextIndex(); 
listIterator.hasNext(); columnIndex = listIterator.nextIndex()) {
                 ExpressionSegment value = listIterator.next();
-                if (!(value instanceof ParameterMarkerExpressionSegment) || 
!unspecifiedTypeParameterIndexes.contains(parameterMarkerIndex)) {
+                if (!(value instanceof ParameterMarkerExpressionSegment)) {
+                    continue;
+                }
+                if 
(!unspecifiedTypeParameterIndexes.contains(parameterMarkerIndex)) {
                     parameterMarkerIndex++;
                     continue;
                 }
-                String columnName = 
columns.get(columnIndex).getIdentifier().getValue();
+                String columnName = columnNames.get(columnIndex);
                 PostgreSQLColumnType parameterType = 
PostgreSQLColumnType.valueOfJDBCType(columnMetaData.get(columnName).getDataType());
                 
preparedStatement.getParameterTypes().set(parameterMarkerIndex++, 
parameterType);
             }
diff --git 
a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java
 
b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java
index cadafc8..8f7cdbe 100644
--- 
a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java
+++ 
b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java
@@ -17,36 +17,79 @@
 
 package 
org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.describe;
 
+import lombok.SneakyThrows;
 import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
+import 
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.PostgreSQLColumnDescription;
+import 
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.PostgreSQLNoDataPacket;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.PostgreSQLParameterDescriptionPacket;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.PostgreSQLRowDescriptionPacket;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.PostgreSQLColumnType;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.PostgreSQLPreparedStatementRegistry;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.describe.PostgreSQLComDescribePacket;
 import 
org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
+import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
+import 
org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
+import org.apache.shardingsphere.infra.metadata.schema.model.ColumnMetaData;
+import org.apache.shardingsphere.infra.metadata.schema.model.TableMetaData;
+import org.apache.shardingsphere.infra.parser.ShardingSphereSQLParserEngine;
+import org.apache.shardingsphere.mode.manager.ContextManager;
+import org.apache.shardingsphere.parser.rule.SQLParserRule;
+import 
org.apache.shardingsphere.parser.rule.builder.DefaultSQLParserRuleConfigurationBuilder;
+import 
org.apache.shardingsphere.proxy.backend.communication.jdbc.connection.JDBCBackendConnection;
+import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
 import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
 import 
org.apache.shardingsphere.proxy.frontend.postgresql.command.PostgreSQLConnectionContext;
 import 
org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.PostgreSQLPortal;
-import 
org.apache.shardingsphere.sql.parser.sql.dialect.statement.postgresql.dml.PostgreSQLSelectStatement;
+import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
+import org.junit.After;
+import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.Answers;
+import org.mockito.InjectMocks;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
 
+import java.lang.reflect.Field;
+import java.sql.Connection;
+import java.sql.ParameterMetaData;
+import java.sql.ResultSetMetaData;
 import java.sql.SQLException;
+import java.sql.Types;
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
 
 import static org.hamcrest.CoreMatchers.instanceOf;
 import static org.hamcrest.CoreMatchers.is;
 import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.nullable;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 @RunWith(MockitoJUnitRunner.class)
 public final class PostgreSQLComDescribeExecutorTest {
     
+    private static final String SCHEMA_NAME = "postgres";
+    
+    private static final String TABLE_NAME = "t_order";
+    
+    private static final ShardingSphereSQLParserEngine SQL_PARSER_ENGINE = new 
ShardingSphereSQLParserEngine("PostgreSQL", new SQLParserRule(new 
DefaultSQLParserRuleConfigurationBuilder().build()));
+    
+    private ContextManager contextManagerBefore;
+    
+    @Mock(answer = Answers.RETURNS_DEEP_STUBS)
+    private ContextManager mockContextManager;
+    
     @Mock
     private PostgreSQLConnectionContext connectionContext;
     
@@ -56,6 +99,28 @@ public final class PostgreSQLComDescribeExecutorTest {
     @Mock
     private ConnectionSession connectionSession;
     
+    @InjectMocks
+    private PostgreSQLComDescribeExecutor executor;
+    
+    @Before
+    public void setup() {
+        contextManagerBefore = ProxyContext.getInstance().getContextManager();
+        ProxyContext.getInstance().init(mockContextManager);
+        
when(ProxyContext.getInstance().getContextManager().getMetaDataContexts().getProps().getValue(ConfigurationPropertyKey.SQL_SHOW)).thenReturn(false);
+        when(connectionSession.getSchemaName()).thenReturn(SCHEMA_NAME);
+        
when(mockContextManager.getMetaDataContexts().getAllSchemaNames().contains(SCHEMA_NAME)).thenReturn(true);
+        prepareTableMetaData();
+    }
+    
+    private void prepareTableMetaData() {
+        Collection<ColumnMetaData> columnMetaData = Arrays.asList(
+                new ColumnMetaData("id", Types.INTEGER, true, false, false),
+                new ColumnMetaData("k", Types.INTEGER, true, false, false),
+                new ColumnMetaData("c", Types.CHAR, true, false, false),
+                new ColumnMetaData("pad", Types.CHAR, true, false, false));
+        
when(mockContextManager.getMetaDataContexts().getMetaData(SCHEMA_NAME).getSchema().get(TABLE_NAME)).thenReturn(new
 TableMetaData(TABLE_NAME, columnMetaData, Collections.emptyList()));
+    }
+    
     @Test
     public void assertDescribePortal() throws SQLException {
         when(packet.getType()).thenReturn('P');
@@ -64,30 +129,141 @@ public final class PostgreSQLComDescribeExecutorTest {
         PostgreSQLRowDescriptionPacket expected = 
mock(PostgreSQLRowDescriptionPacket.class);
         when(portal.describe()).thenReturn(expected);
         when(connectionContext.getPortal("P_1")).thenReturn(portal);
-        Collection<DatabasePacket<?>> actual = new 
PostgreSQLComDescribeExecutor(connectionContext, packet, 
connectionSession).execute();
+        Collection<DatabasePacket<?>> actual = executor.execute();
         assertThat(actual.size(), is(1));
         assertThat(actual.iterator().next(), is(expected));
     }
     
     @Test
-    public void assertDescribePreparedStatement() throws SQLException {
+    public void assertDescribePreparedStatementInsertWithoutColumns() throws 
SQLException {
+        when(packet.getType()).thenReturn('S');
+        final String statementId = "S_1";
+        when(packet.getName()).thenReturn(statementId);
+        final int connectionId = 1;
+        when(connectionSession.getConnectionId()).thenReturn(connectionId);
+        String sql = "insert into t_order values (?, 0, 'char', ?), (2, ?, ?, 
'')";
+        SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
+        
PostgreSQLPreparedStatementRegistry.getInstance().register(connectionId);
+        List<PostgreSQLColumnType> parameterTypes = new 
ArrayList<>(sqlStatement.getParameterCount());
+        for (int i = 0; i < sqlStatement.getParameterCount(); i++) {
+            
parameterTypes.add(PostgreSQLColumnType.POSTGRESQL_TYPE_UNSPECIFIED);
+        }
+        
PostgreSQLPreparedStatementRegistry.getInstance().register(connectionId, 
statementId, sql, sqlStatement, parameterTypes);
+        Collection<DatabasePacket<?>> actualPackets = executor.execute();
+        assertThat(actualPackets.size(), is(2));
+        Iterator<DatabasePacket<?>> actualPacketsIterator = 
actualPackets.iterator();
+        PostgreSQLParameterDescriptionPacket actualParameterDescription = 
(PostgreSQLParameterDescriptionPacket) actualPacketsIterator.next();
+        PostgreSQLPacketPayload mockPayload = 
mock(PostgreSQLPacketPayload.class);
+        actualParameterDescription.write(mockPayload);
+        verify(mockPayload).writeInt2(4);
+        verify(mockPayload, times(2)).writeInt4(23);
+        verify(mockPayload, times(2)).writeInt4(18);
+        assertTrue(actualPacketsIterator.next() instanceof 
PostgreSQLNoDataPacket);
+    }
+    
+    @Test
+    public void assertDescribePreparedStatementInsertWithColumns() throws 
SQLException {
         when(packet.getType()).thenReturn('S');
-        when(packet.getName()).thenReturn("S_1");
+        final String statementId = "S_2";
+        when(packet.getName()).thenReturn(statementId);
+        final int connectionId = 1;
+        when(connectionSession.getConnectionId()).thenReturn(connectionId);
+        String sql = "insert into t_order (id, k, c, pad) values (1, ?, ?, ?), 
(?, 2, ?, '')";
+        SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
+        
PostgreSQLPreparedStatementRegistry.getInstance().register(connectionId);
+        List<PostgreSQLColumnType> parameterTypes = new 
ArrayList<>(sqlStatement.getParameterCount());
+        for (int i = 0; i < sqlStatement.getParameterCount(); i++) {
+            
parameterTypes.add(PostgreSQLColumnType.POSTGRESQL_TYPE_UNSPECIFIED);
+        }
+        
PostgreSQLPreparedStatementRegistry.getInstance().register(connectionId, 
statementId, sql, sqlStatement, parameterTypes);
+        Collection<DatabasePacket<?>> actualPackets = executor.execute();
+        assertThat(actualPackets.size(), is(2));
+        Iterator<DatabasePacket<?>> actualPacketsIterator = 
actualPackets.iterator();
+        PostgreSQLParameterDescriptionPacket actualParameterDescription = 
(PostgreSQLParameterDescriptionPacket) actualPacketsIterator.next();
+        PostgreSQLPacketPayload mockPayload = 
mock(PostgreSQLPacketPayload.class);
+        actualParameterDescription.write(mockPayload);
+        verify(mockPayload).writeInt2(5);
+        verify(mockPayload, times(2)).writeInt4(23);
+        verify(mockPayload, times(3)).writeInt4(18);
+        assertTrue(actualPacketsIterator.next() instanceof 
PostgreSQLNoDataPacket);
+    }
+    
+    @Test
+    public void assertDescribeSelectPreparedStatement() throws SQLException {
+        when(packet.getType()).thenReturn('S');
+        String statementId = "S_3";
+        when(packet.getName()).thenReturn(statementId);
         when(connectionSession.getConnectionId()).thenReturn(1);
+        final String sql = "select id, k, c, pad from t_order where id = ?";
+        SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
+        prepareJDBCBackendConnection(sql);
         PostgreSQLPreparedStatementRegistry.getInstance().register(1);
-        PostgreSQLPreparedStatementRegistry.getInstance().register(1, "S_1", 
"", new PostgreSQLSelectStatement(), 
Collections.singletonList(PostgreSQLColumnType.POSTGRESQL_TYPE_INT4));
-        Collection<DatabasePacket<?>> actual = new 
PostgreSQLComDescribeExecutor(connectionContext, packet, 
connectionSession).execute();
-        assertThat(actual.size(), is(1));
-        PostgreSQLParameterDescriptionPacket actualParameterDescription = 
(PostgreSQLParameterDescriptionPacket) actual.iterator().next();
+        List<PostgreSQLColumnType> parameterTypes = new 
ArrayList<>(Collections.singletonList(PostgreSQLColumnType.POSTGRESQL_TYPE_UNSPECIFIED));
+        PostgreSQLPreparedStatementRegistry.getInstance().register(1, 
statementId, sql, sqlStatement, parameterTypes);
+        Collection<DatabasePacket<?>> actual = executor.execute();
+        assertThat(actual.size(), is(2));
+        Iterator<DatabasePacket<?>> actualPacketsIterator = actual.iterator();
+        PostgreSQLParameterDescriptionPacket actualParameterDescription = 
(PostgreSQLParameterDescriptionPacket) actualPacketsIterator.next();
         assertThat(actualParameterDescription, 
instanceOf(PostgreSQLParameterDescriptionPacket.class));
         PostgreSQLPacketPayload mockPayload = 
mock(PostgreSQLPacketPayload.class);
         actualParameterDescription.write(mockPayload);
         verify(mockPayload).writeInt2(1);
         
verify(mockPayload).writeInt4(PostgreSQLColumnType.POSTGRESQL_TYPE_INT4.getValue());
+        PostgreSQLRowDescriptionPacket actualRowDescription = 
(PostgreSQLRowDescriptionPacket) actualPacketsIterator.next();
+        List<PostgreSQLColumnDescription> actualColumnDescriptions = 
getColumnDescriptions(actualRowDescription);
+        List<PostgreSQLColumnDescription> expectedColumnDescriptions = 
Arrays.asList(
+                new PostgreSQLColumnDescription("id", 1, Types.INTEGER, 11, 
"int4"),
+                new PostgreSQLColumnDescription("k", 2, Types.INTEGER, 11, 
"int4"),
+                new PostgreSQLColumnDescription("c", 3, Types.CHAR, 60, 
"int4"),
+                new PostgreSQLColumnDescription("pad", 4, Types.CHAR, 120, 
"int4")
+        );
+        for (int i = 0; i < expectedColumnDescriptions.size(); i++) {
+            PostgreSQLColumnDescription expectedColumnDescription = 
expectedColumnDescriptions.get(i);
+            PostgreSQLColumnDescription actualColumnDescription = 
actualColumnDescriptions.get(i);
+            assertThat(actualColumnDescription.getColumnName(), 
is(expectedColumnDescription.getColumnName()));
+            assertThat(actualColumnDescription.getColumnIndex(), 
is(expectedColumnDescription.getColumnIndex()));
+            assertThat(actualColumnDescription.getColumnLength(), 
is(expectedColumnDescription.getColumnLength()));
+            assertThat(actualColumnDescription.getTypeOID(), 
is(expectedColumnDescription.getTypeOID()));
+        }
+    }
+    
+    private void prepareJDBCBackendConnection(final String sql) throws 
SQLException {
+        JDBCBackendConnection backendConnection = 
mock(JDBCBackendConnection.class);
+        Connection connection = mock(Connection.class, RETURNS_DEEP_STUBS);
+        ParameterMetaData parameterMetaData = mock(ParameterMetaData.class);
+        when(parameterMetaData.getParameterType(1)).thenReturn(Types.INTEGER);
+        
when(connection.prepareStatement(sql).getParameterMetaData()).thenReturn(parameterMetaData);
+        ResultSetMetaData resultSetMetaData = prepareResultSetMetaData();
+        
when(connection.prepareStatement(sql).getMetaData()).thenReturn(resultSetMetaData);
+        when(backendConnection.getConnections(nullable(String.class), 
anyInt(), 
any(ConnectionMode.class))).thenReturn(Collections.singletonList(connection));
+        
when(connectionSession.getBackendConnection()).thenReturn(backendConnection);
+    }
+    
+    private ResultSetMetaData prepareResultSetMetaData() throws SQLException {
+        ResultSetMetaData result = mock(ResultSetMetaData.class);
+        when(result.getColumnCount()).thenReturn(4);
+        when(result.getColumnName(anyInt())).thenReturn("id", "k", "c", "pad");
+        when(result.getColumnType(anyInt())).thenReturn(Types.INTEGER, 
Types.INTEGER, Types.CHAR, Types.CHAR);
+        when(result.getColumnDisplaySize(anyInt())).thenReturn(11, 11, 60, 
120);
+        when(result.getColumnTypeName(anyInt())).thenReturn("int4", "int4", 
"char", "char");
+        return result;
+    }
+    
+    @SuppressWarnings("unchecked")
+    @SneakyThrows
+    private List<PostgreSQLColumnDescription> getColumnDescriptions(final 
PostgreSQLRowDescriptionPacket packet) {
+        Field columnDescriptionsField = 
PostgreSQLRowDescriptionPacket.class.getDeclaredField("columnDescriptions");
+        columnDescriptionsField.setAccessible(true);
+        return (List<PostgreSQLColumnDescription>) 
columnDescriptionsField.get(packet);
     }
     
     @Test(expected = UnsupportedOperationException.class)
     public void assertDescribeUnknownType() throws SQLException {
         new PostgreSQLComDescribeExecutor(connectionContext, packet, 
connectionSession).execute();
     }
+    
+    @After
+    public void tearDown() {
+        ProxyContext.getInstance().init(contextManagerBefore);
+    }
 }

Reply via email to