This is an automated email from the ASF dual-hosted git repository.

panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 1d820a8eb23 PostgreSQL/openGauss support describe insert returning clause (#22379)
1d820a8eb23 is described below

commit 1d820a8eb23800f3cafc0eb7aca9f5fc9cdfd30c
Author: 吴伟杰 <[email protected]>
AuthorDate: Thu Nov 24 15:16:54 2022 +0800

    PostgreSQL/openGauss support describe insert returning clause (#22379)
---
 .../describe/PostgreSQLComDescribeExecutor.java    | 132 +++++++++++++++------
 .../PostgreSQLComDescribeExecutorTest.java         |  72 +++++++++++
 2 files changed, 167 insertions(+), 37 deletions(-)

diff --git a/proxy/frontend/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutor.java b/proxy/frontend/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutor.java
index 8ac7d8e3c11..f37f7aac60a 100644
--- a/proxy/frontend/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutor.java
+++ b/proxy/frontend/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutor.java
@@ -46,26 +46,33 @@ import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
 import 
org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
 import 
org.apache.shardingsphere.proxy.frontend.postgresql.command.PortalContext;
 import 
org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.PostgreSQLServerPreparedStatement;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.ReturningSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ColumnProjectionSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ExpressionProjectionSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ShorthandProjectionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
+import 
org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.InsertStatementHandler;
 
 import java.sql.Connection;
 import java.sql.ParameterMetaData;
 import java.sql.PreparedStatement;
 import java.sql.ResultSetMetaData;
 import java.sql.SQLException;
+import java.sql.Types;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
-import java.util.Iterator;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.ListIterator;
 import java.util.Map;
 import java.util.Optional;
-import java.util.Set;
 import java.util.TreeMap;
 import java.util.stream.Collectors;
 
@@ -75,6 +82,8 @@ import java.util.stream.Collectors;
 @RequiredArgsConstructor
 public final class PostgreSQLComDescribeExecutor implements CommandExecutor {
     
+    private static final String ANONYMOUS_COLUMN_NAME = "?column?";
+    
     private final PortalContext portalContext;
     
     private final PostgreSQLComDescribePacket packet;
@@ -116,33 +125,21 @@ public final class PostgreSQLComDescribeExecutor implements CommandExecutor {
     }
     
     private void describeInsertStatementByDatabaseMetaData(final 
PostgreSQLServerPreparedStatement preparedStatement) {
-        if (!preparedStatement.describeRows().isPresent()) {
-            // TODO Consider the SQL `insert into table (col) values ($1) 
returning id`
-            
preparedStatement.setRowDescription(PostgreSQLNoDataPacket.getInstance());
-        }
         InsertStatement insertStatement = (InsertStatement) 
preparedStatement.getSqlStatementContext().getSqlStatement();
-        if (0 == insertStatement.getParameterCount()) {
+        Collection<Integer> unspecifiedTypeParameterIndexes = 
getUnspecifiedTypeParameterIndexes(preparedStatement);
+        Optional<ReturningSegment> returningSegment = 
InsertStatementHandler.getReturningSegment(insertStatement);
+        if (0 == insertStatement.getParameterCount() && 
unspecifiedTypeParameterIndexes.isEmpty() && !returningSegment.isPresent()) {
             return;
         }
-        Set<Integer> unspecifiedTypeParameterIndexes = 
getUnspecifiedTypeParameterIndexes(preparedStatement);
-        if (unspecifiedTypeParameterIndexes.isEmpty()) {
-            return;
-        }
-        String databaseName = connectionSession.getDatabaseName();
         String logicTableName = 
insertStatement.getTable().getTableName().getIdentifier().getValue();
-        ShardingSphereDatabase database = 
ProxyContext.getInstance().getDatabase(databaseName);
-        String schemaName = insertStatement.getTable().getOwner().map(optional 
-> optional.getIdentifier()
-                .getValue()).orElseGet(() -> 
DatabaseTypeEngine.getDefaultSchemaName(database.getProtocolType(), 
databaseName));
-        ShardingSphereTable table = 
database.getSchema(schemaName).getTable(logicTableName);
-        Map<String, ShardingSphereColumn> columns = table.getColumns();
-        Map<String, ShardingSphereColumn> caseInsensitiveColumns = null;
-        List<String> columnNames = insertStatement.getColumns().isEmpty()
-                ? new ArrayList<>(table.getColumns().keySet())
-                : insertStatement.getColumns().stream().map(each -> 
each.getIdentifier().getValue()).collect(Collectors.toList());
-        Iterator<InsertValuesSegment> iterator = 
insertStatement.getValues().iterator();
+        ShardingSphereTable table = 
getTableFromMetaData(connectionSession.getDatabaseName(), insertStatement, 
logicTableName);
+        List<String> columnNamesOfInsert = 
getColumnNamesOfInsertStatement(insertStatement, table);
+        Map<String, ShardingSphereColumn> columnsOfTable = table.getColumns();
+        Map<String, ShardingSphereColumn> caseInsensitiveColumnsOfTable = 
convertToCaseInsensitiveColumnsOfTable(columnsOfTable);
+        
preparedStatement.setRowDescription(returningSegment.<PostgreSQLPacket>map(returning
 -> describeReturning(returning, columnsOfTable, caseInsensitiveColumnsOfTable))
+                .orElseGet(PostgreSQLNoDataPacket::getInstance));
         int parameterMarkerIndex = 0;
-        while (iterator.hasNext()) {
-            InsertValuesSegment each = iterator.next();
+        for (InsertValuesSegment each : insertStatement.getValues()) {
             ListIterator<ExpressionSegment> listIterator = 
each.getValues().listIterator();
             for (int columnIndex = listIterator.nextIndex(); 
listIterator.hasNext(); columnIndex = listIterator.nextIndex()) {
                 ExpressionSegment value = listIterator.next();
@@ -153,37 +150,98 @@ public final class PostgreSQLComDescribeExecutor implements CommandExecutor {
                     parameterMarkerIndex++;
                     continue;
                 }
-                String columnName = columnNames.get(columnIndex);
-                ShardingSphereColumn column = columns.get(columnName);
-                if (null == column) {
-                    if (null == caseInsensitiveColumns) {
-                        caseInsensitiveColumns = 
convertToCaseInsensitiveColumnMetaDataMap(columns);
-                    }
-                    column = caseInsensitiveColumns.get(columnName);
-                }
+                String columnName = columnNamesOfInsert.get(columnIndex);
+                ShardingSphereColumn column = 
columnsOfTable.getOrDefault(columnName, 
caseInsensitiveColumnsOfTable.get(columnName));
                 ShardingSpherePreconditions.checkState(null != column, () -> 
new ColumnNotFoundException(logicTableName, columnName));
                 
preparedStatement.getParameterTypes().set(parameterMarkerIndex++, 
PostgreSQLColumnType.valueOfJDBCType(column.getDataType()));
             }
         }
     }
     
-    private Set<Integer> getUnspecifiedTypeParameterIndexes(final 
PostgreSQLServerPreparedStatement preparedStatement) {
-        Set<Integer> unspecifiedTypeParameterIndexes = new HashSet<>();
+    private Collection<Integer> getUnspecifiedTypeParameterIndexes(final 
PostgreSQLServerPreparedStatement preparedStatement) {
+        Collection<Integer> result = new HashSet<>();
         ListIterator<PostgreSQLColumnType> parameterTypesListIterator = 
preparedStatement.getParameterTypes().listIterator();
         for (int index = parameterTypesListIterator.nextIndex(); 
parameterTypesListIterator.hasNext(); index = 
parameterTypesListIterator.nextIndex()) {
             if (PostgreSQLColumnType.POSTGRESQL_TYPE_UNSPECIFIED == 
parameterTypesListIterator.next()) {
-                unspecifiedTypeParameterIndexes.add(index);
+                result.add(index);
             }
         }
-        return unspecifiedTypeParameterIndexes;
+        return result;
+    }
+    
+    private ShardingSphereTable getTableFromMetaData(final String 
databaseName, final InsertStatement insertStatement, final String 
logicTableName) {
+        ShardingSphereDatabase database = 
ProxyContext.getInstance().getDatabase(databaseName);
+        String schemaName = insertStatement.getTable().getOwner().map(optional 
-> optional.getIdentifier()
+                .getValue()).orElseGet(() -> 
DatabaseTypeEngine.getDefaultSchemaName(database.getProtocolType(), 
databaseName));
+        return database.getSchema(schemaName).getTable(logicTableName);
+    }
+    
+    private static List<String> getColumnNamesOfInsertStatement(final 
InsertStatement insertStatement, final ShardingSphereTable table) {
+        return insertStatement.getColumns().isEmpty() ? new 
ArrayList<>(table.getColumns().keySet())
+                : insertStatement.getColumns().stream().map(each -> 
each.getIdentifier().getValue()).collect(Collectors.toList());
     }
     
-    private Map<String, ShardingSphereColumn> 
convertToCaseInsensitiveColumnMetaDataMap(final Map<String, 
ShardingSphereColumn> columns) {
+    private Map<String, ShardingSphereColumn> 
convertToCaseInsensitiveColumnsOfTable(final Map<String, ShardingSphereColumn> 
columns) {
         Map<String, ShardingSphereColumn> result = new 
TreeMap<>(String.CASE_INSENSITIVE_ORDER);
         result.putAll(columns);
         return result;
     }
     
+    private PostgreSQLRowDescriptionPacket describeReturning(final 
ReturningSegment returningSegment, final Map<String, ShardingSphereColumn> 
columnsOfTable,
+                                                             final Map<String, 
ShardingSphereColumn> caseInsensitiveColumnsOfTable) {
+        Collection<PostgreSQLColumnDescription> result = new LinkedList<>();
+        for (ProjectionSegment each : 
returningSegment.getProjections().getProjections()) {
+            if (each instanceof ShorthandProjectionSegment) {
+                columnsOfTable.values().stream().map(column -> new 
PostgreSQLColumnDescription(column.getName(), 0, column.getDataType(), 
estimateColumnLength(column.getDataType()), ""))
+                        .forEach(result::add);
+            }
+            if (each instanceof ColumnProjectionSegment) {
+                String columnName = ((ColumnProjectionSegment) 
each).getColumn().getIdentifier().getValue();
+                ShardingSphereColumn column = 
columnsOfTable.getOrDefault(columnName, 
caseInsensitiveColumnsOfTable.get(columnName));
+                String alias = ((ColumnProjectionSegment) 
each).getAlias().orElseGet(column::getName);
+                result.add(new PostgreSQLColumnDescription(alias, 0, 
column.getDataType(), estimateColumnLength(column.getDataType()), ""));
+            }
+            if (each instanceof ExpressionProjectionSegment) {
+                
result.add(convertExpressionToDescription((ExpressionProjectionSegment) each));
+            }
+        }
+        return new PostgreSQLRowDescriptionPacket(result.size(), result);
+    }
+    
+    private PostgreSQLColumnDescription convertExpressionToDescription(final 
ExpressionProjectionSegment expressionProjectionSegment) {
+        ExpressionSegment expressionSegment = 
expressionProjectionSegment.getExpr();
+        String columnName = 
expressionProjectionSegment.getAlias().orElse(ANONYMOUS_COLUMN_NAME);
+        if (expressionSegment instanceof LiteralExpressionSegment) {
+            Object value = ((LiteralExpressionSegment) 
expressionSegment).getLiterals();
+            if (value instanceof String) {
+                return new PostgreSQLColumnDescription(columnName, 0, 
Types.VARCHAR, estimateColumnLength(Types.VARCHAR), "");
+            }
+            if (value instanceof Integer) {
+                return new PostgreSQLColumnDescription(columnName, 0, 
Types.INTEGER, estimateColumnLength(Types.INTEGER), "");
+            }
+            if (value instanceof Long) {
+                return new PostgreSQLColumnDescription(columnName, 0, 
Types.BIGINT, estimateColumnLength(Types.BIGINT), "");
+            }
+            if (value instanceof Number) {
+                return new PostgreSQLColumnDescription(columnName, 0, 
Types.NUMERIC, estimateColumnLength(Types.NUMERIC), "");
+            }
+        }
+        return new PostgreSQLColumnDescription(columnName, 0, Types.VARCHAR, 
estimateColumnLength(Types.VARCHAR), "");
+    }
+    
+    private int estimateColumnLength(final int jdbcType) {
+        switch (jdbcType) {
+            case Types.SMALLINT:
+                return 2;
+            case Types.INTEGER:
+                return 4;
+            case Types.BIGINT:
+                return 8;
+            default:
+                return -1;
+        }
+    }
+    
     private void tryDescribePreparedStatementByJDBC(final 
PostgreSQLServerPreparedStatement logicPreparedStatement) throws SQLException {
         if (!(connectionSession.getBackendConnection() instanceof 
JDBCBackendConnection)) {
             return;
diff --git a/proxy/frontend/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java b/proxy/frontend/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java
index 568bb7e884b..2b0db7e5f9f 100644
--- a/proxy/frontend/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java
+++ b/proxy/frontend/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java
@@ -248,6 +248,78 @@ public final class PostgreSQLComDescribeExecutorTest extends ProxyContextRestore
         executor.execute();
     }
     
+    @SuppressWarnings("rawtypes")
+    @Test
+    public void assertDescribePreparedStatementInsertWithReturningClause() 
throws SQLException {
+        when(packet.getType()).thenReturn('S');
+        final String statementId = "S_2";
+        when(packet.getName()).thenReturn(statementId);
+        String sql = "insert into t_order (k, c, pad) values (?, ?, ?) "
+                + "returning id, id alias_id, 'anonymous', 'OK' 
literal_string, 1 literal_int, 4294967296 literal_bigint, 1.1 literal_numeric, 
t_order.*";
+        SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
+        List<PostgreSQLColumnType> parameterTypes = new 
ArrayList<>(sqlStatement.getParameterCount());
+        for (int i = 0; i < sqlStatement.getParameterCount(); i++) {
+            
parameterTypes.add(PostgreSQLColumnType.POSTGRESQL_TYPE_UNSPECIFIED);
+        }
+        SQLStatementContext sqlStatementContext = 
mock(InsertStatementContext.class);
+        when(sqlStatementContext.getSqlStatement()).thenReturn(sqlStatement);
+        
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(statementId,
 new PostgreSQLServerPreparedStatement(sql, sqlStatementContext, 
parameterTypes));
+        Collection<DatabasePacket<?>> actualPackets = executor.execute();
+        assertThat(actualPackets.size(), is(2));
+        Iterator<DatabasePacket<?>> actualPacketsIterator = 
actualPackets.iterator();
+        PostgreSQLParameterDescriptionPacket actualParameterDescription = 
(PostgreSQLParameterDescriptionPacket) actualPacketsIterator.next();
+        PostgreSQLPacketPayload mockPayload = 
mock(PostgreSQLPacketPayload.class);
+        actualParameterDescription.write(mockPayload);
+        verify(mockPayload).writeInt2(3);
+        verify(mockPayload).writeInt4(23);
+        verify(mockPayload, times(2)).writeInt4(18);
+        DatabasePacket<?> actualRowDescriptionPacket = 
actualPacketsIterator.next();
+        assertThat(actualRowDescriptionPacket, 
is(instanceOf(PostgreSQLRowDescriptionPacket.class)));
+        List<PostgreSQLColumnDescription> actualColumnDescriptions = new 
ArrayList<>(getColumnDescriptionsFromPacket((PostgreSQLRowDescriptionPacket) 
actualRowDescriptionPacket));
+        assertThat(actualColumnDescriptions.size(), is(11));
+        assertThat(actualColumnDescriptions.get(0).getColumnName(), is("id"));
+        assertThat(actualColumnDescriptions.get(0).getTypeOID(), 
is(PostgreSQLColumnType.POSTGRESQL_TYPE_INT4.getValue()));
+        assertThat(actualColumnDescriptions.get(0).getColumnLength(), is(4));
+        assertThat(actualColumnDescriptions.get(1).getColumnName(), 
is("alias_id"));
+        assertThat(actualColumnDescriptions.get(1).getTypeOID(), 
is(PostgreSQLColumnType.POSTGRESQL_TYPE_INT4.getValue()));
+        assertThat(actualColumnDescriptions.get(1).getColumnLength(), is(4));
+        assertThat(actualColumnDescriptions.get(2).getColumnName(), 
is("?column?"));
+        assertThat(actualColumnDescriptions.get(2).getTypeOID(), 
is(PostgreSQLColumnType.POSTGRESQL_TYPE_VARCHAR.getValue()));
+        assertThat(actualColumnDescriptions.get(2).getColumnLength(), is(-1));
+        assertThat(actualColumnDescriptions.get(3).getColumnName(), 
is("literal_string"));
+        assertThat(actualColumnDescriptions.get(3).getTypeOID(), 
is(PostgreSQLColumnType.POSTGRESQL_TYPE_VARCHAR.getValue()));
+        assertThat(actualColumnDescriptions.get(3).getColumnLength(), is(-1));
+        assertThat(actualColumnDescriptions.get(4).getColumnName(), 
is("literal_int"));
+        assertThat(actualColumnDescriptions.get(4).getTypeOID(), 
is(PostgreSQLColumnType.POSTGRESQL_TYPE_INT4.getValue()));
+        assertThat(actualColumnDescriptions.get(4).getColumnLength(), is(4));
+        assertThat(actualColumnDescriptions.get(5).getColumnName(), 
is("literal_bigint"));
+        assertThat(actualColumnDescriptions.get(5).getTypeOID(), 
is(PostgreSQLColumnType.POSTGRESQL_TYPE_INT8.getValue()));
+        assertThat(actualColumnDescriptions.get(5).getColumnLength(), is(8));
+        assertThat(actualColumnDescriptions.get(6).getColumnName(), 
is("literal_numeric"));
+        assertThat(actualColumnDescriptions.get(6).getTypeOID(), 
is(PostgreSQLColumnType.POSTGRESQL_TYPE_NUMERIC.getValue()));
+        assertThat(actualColumnDescriptions.get(6).getColumnLength(), is(-1));
+        assertThat(actualColumnDescriptions.get(7).getColumnName(), is("id"));
+        assertThat(actualColumnDescriptions.get(7).getTypeOID(), 
is(PostgreSQLColumnType.POSTGRESQL_TYPE_INT4.getValue()));
+        assertThat(actualColumnDescriptions.get(7).getColumnLength(), is(4));
+        assertThat(actualColumnDescriptions.get(8).getColumnName(), is("k"));
+        assertThat(actualColumnDescriptions.get(8).getTypeOID(), 
is(PostgreSQLColumnType.POSTGRESQL_TYPE_INT4.getValue()));
+        assertThat(actualColumnDescriptions.get(8).getColumnLength(), is(4));
+        assertThat(actualColumnDescriptions.get(9).getColumnName(), is("c"));
+        assertThat(actualColumnDescriptions.get(9).getTypeOID(), 
is(PostgreSQLColumnType.POSTGRESQL_TYPE_CHAR.getValue()));
+        assertThat(actualColumnDescriptions.get(9).getColumnLength(), is(-1));
+        assertThat(actualColumnDescriptions.get(10).getColumnName(), 
is("pad"));
+        assertThat(actualColumnDescriptions.get(10).getTypeOID(), 
is(PostgreSQLColumnType.POSTGRESQL_TYPE_CHAR.getValue()));
+        assertThat(actualColumnDescriptions.get(10).getColumnLength(), is(-1));
+    }
+    
+    @SuppressWarnings("unchecked")
+    @SneakyThrows({NoSuchFieldException.class, IllegalAccessException.class})
+    private Collection<PostgreSQLColumnDescription> 
getColumnDescriptionsFromPacket(final PostgreSQLRowDescriptionPacket packet) {
+        Field field = 
PostgreSQLRowDescriptionPacket.class.getDeclaredField("columnDescriptions");
+        field.setAccessible(true);
+        return (Collection<PostgreSQLColumnDescription>) field.get(packet);
+    }
+    
     @SuppressWarnings("rawtypes")
     @Test
     public void assertDescribeSelectPreparedStatement() throws SQLException {

Reply via email to