This is an automated email from the ASF dual-hosted git repository.
panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 1d820a8eb23 PostgreSQL/openGauss support describe insert returning clause (#22379)
1d820a8eb23 is described below
commit 1d820a8eb23800f3cafc0eb7aca9f5fc9cdfd30c
Author: 吴伟杰 <[email protected]>
AuthorDate: Thu Nov 24 15:16:54 2022 +0800
PostgreSQL/openGauss support describe insert returning clause (#22379)
---
.../describe/PostgreSQLComDescribeExecutor.java | 132 +++++++++++++++------
.../PostgreSQLComDescribeExecutorTest.java | 72 +++++++++++
2 files changed, 167 insertions(+), 37 deletions(-)
diff --git
a/proxy/frontend/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutor.java
b/proxy/frontend/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutor.java
index 8ac7d8e3c11..f37f7aac60a 100644
---
a/proxy/frontend/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutor.java
+++
b/proxy/frontend/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutor.java
@@ -46,26 +46,33 @@ import
org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
import
org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
import
org.apache.shardingsphere.proxy.frontend.postgresql.command.PortalContext;
import
org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.PostgreSQLServerPreparedStatement;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.ReturningSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ColumnProjectionSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ExpressionProjectionSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ShorthandProjectionSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
+import
org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.InsertStatementHandler;
import java.sql.Connection;
import java.sql.ParameterMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
+import java.sql.Types;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
-import java.util.Iterator;
+import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Optional;
-import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;
@@ -75,6 +82,8 @@ import java.util.stream.Collectors;
@RequiredArgsConstructor
public final class PostgreSQLComDescribeExecutor implements CommandExecutor {
+ private static final String ANONYMOUS_COLUMN_NAME = "?column?";
+
private final PortalContext portalContext;
private final PostgreSQLComDescribePacket packet;
@@ -116,33 +125,21 @@ public final class PostgreSQLComDescribeExecutor
implements CommandExecutor {
}
private void describeInsertStatementByDatabaseMetaData(final
PostgreSQLServerPreparedStatement preparedStatement) {
- if (!preparedStatement.describeRows().isPresent()) {
- // TODO Consider the SQL `insert into table (col) values ($1)
returning id`
-
preparedStatement.setRowDescription(PostgreSQLNoDataPacket.getInstance());
- }
InsertStatement insertStatement = (InsertStatement)
preparedStatement.getSqlStatementContext().getSqlStatement();
- if (0 == insertStatement.getParameterCount()) {
+ Collection<Integer> unspecifiedTypeParameterIndexes =
getUnspecifiedTypeParameterIndexes(preparedStatement);
+ Optional<ReturningSegment> returningSegment =
InsertStatementHandler.getReturningSegment(insertStatement);
+ if (0 == insertStatement.getParameterCount() &&
unspecifiedTypeParameterIndexes.isEmpty() && !returningSegment.isPresent()) {
return;
}
- Set<Integer> unspecifiedTypeParameterIndexes =
getUnspecifiedTypeParameterIndexes(preparedStatement);
- if (unspecifiedTypeParameterIndexes.isEmpty()) {
- return;
- }
- String databaseName = connectionSession.getDatabaseName();
String logicTableName =
insertStatement.getTable().getTableName().getIdentifier().getValue();
- ShardingSphereDatabase database =
ProxyContext.getInstance().getDatabase(databaseName);
- String schemaName = insertStatement.getTable().getOwner().map(optional
-> optional.getIdentifier()
- .getValue()).orElseGet(() ->
DatabaseTypeEngine.getDefaultSchemaName(database.getProtocolType(),
databaseName));
- ShardingSphereTable table =
database.getSchema(schemaName).getTable(logicTableName);
- Map<String, ShardingSphereColumn> columns = table.getColumns();
- Map<String, ShardingSphereColumn> caseInsensitiveColumns = null;
- List<String> columnNames = insertStatement.getColumns().isEmpty()
- ? new ArrayList<>(table.getColumns().keySet())
- : insertStatement.getColumns().stream().map(each ->
each.getIdentifier().getValue()).collect(Collectors.toList());
- Iterator<InsertValuesSegment> iterator =
insertStatement.getValues().iterator();
+ ShardingSphereTable table =
getTableFromMetaData(connectionSession.getDatabaseName(), insertStatement,
logicTableName);
+ List<String> columnNamesOfInsert =
getColumnNamesOfInsertStatement(insertStatement, table);
+ Map<String, ShardingSphereColumn> columnsOfTable = table.getColumns();
+ Map<String, ShardingSphereColumn> caseInsensitiveColumnsOfTable =
convertToCaseInsensitiveColumnsOfTable(columnsOfTable);
+
preparedStatement.setRowDescription(returningSegment.<PostgreSQLPacket>map(returning
-> describeReturning(returning, columnsOfTable, caseInsensitiveColumnsOfTable))
+ .orElseGet(PostgreSQLNoDataPacket::getInstance));
int parameterMarkerIndex = 0;
- while (iterator.hasNext()) {
- InsertValuesSegment each = iterator.next();
+ for (InsertValuesSegment each : insertStatement.getValues()) {
ListIterator<ExpressionSegment> listIterator =
each.getValues().listIterator();
for (int columnIndex = listIterator.nextIndex();
listIterator.hasNext(); columnIndex = listIterator.nextIndex()) {
ExpressionSegment value = listIterator.next();
@@ -153,37 +150,98 @@ public final class PostgreSQLComDescribeExecutor
implements CommandExecutor {
parameterMarkerIndex++;
continue;
}
- String columnName = columnNames.get(columnIndex);
- ShardingSphereColumn column = columns.get(columnName);
- if (null == column) {
- if (null == caseInsensitiveColumns) {
- caseInsensitiveColumns =
convertToCaseInsensitiveColumnMetaDataMap(columns);
- }
- column = caseInsensitiveColumns.get(columnName);
- }
+ String columnName = columnNamesOfInsert.get(columnIndex);
+ ShardingSphereColumn column =
columnsOfTable.getOrDefault(columnName,
caseInsensitiveColumnsOfTable.get(columnName));
ShardingSpherePreconditions.checkState(null != column, () ->
new ColumnNotFoundException(logicTableName, columnName));
preparedStatement.getParameterTypes().set(parameterMarkerIndex++,
PostgreSQLColumnType.valueOfJDBCType(column.getDataType()));
}
}
}
- private Set<Integer> getUnspecifiedTypeParameterIndexes(final
PostgreSQLServerPreparedStatement preparedStatement) {
- Set<Integer> unspecifiedTypeParameterIndexes = new HashSet<>();
+ private Collection<Integer> getUnspecifiedTypeParameterIndexes(final
PostgreSQLServerPreparedStatement preparedStatement) {
+ Collection<Integer> result = new HashSet<>();
ListIterator<PostgreSQLColumnType> parameterTypesListIterator =
preparedStatement.getParameterTypes().listIterator();
for (int index = parameterTypesListIterator.nextIndex();
parameterTypesListIterator.hasNext(); index =
parameterTypesListIterator.nextIndex()) {
if (PostgreSQLColumnType.POSTGRESQL_TYPE_UNSPECIFIED ==
parameterTypesListIterator.next()) {
- unspecifiedTypeParameterIndexes.add(index);
+ result.add(index);
}
}
- return unspecifiedTypeParameterIndexes;
+ return result;
+ }
+
+ private ShardingSphereTable getTableFromMetaData(final String
databaseName, final InsertStatement insertStatement, final String
logicTableName) {
+ ShardingSphereDatabase database =
ProxyContext.getInstance().getDatabase(databaseName);
+ String schemaName = insertStatement.getTable().getOwner().map(optional
-> optional.getIdentifier()
+ .getValue()).orElseGet(() ->
DatabaseTypeEngine.getDefaultSchemaName(database.getProtocolType(),
databaseName));
+ return database.getSchema(schemaName).getTable(logicTableName);
+ }
+
+ private static List<String> getColumnNamesOfInsertStatement(final
InsertStatement insertStatement, final ShardingSphereTable table) {
+ return insertStatement.getColumns().isEmpty() ? new
ArrayList<>(table.getColumns().keySet())
+ : insertStatement.getColumns().stream().map(each ->
each.getIdentifier().getValue()).collect(Collectors.toList());
}
- private Map<String, ShardingSphereColumn>
convertToCaseInsensitiveColumnMetaDataMap(final Map<String,
ShardingSphereColumn> columns) {
+ private Map<String, ShardingSphereColumn>
convertToCaseInsensitiveColumnsOfTable(final Map<String, ShardingSphereColumn>
columns) {
Map<String, ShardingSphereColumn> result = new
TreeMap<>(String.CASE_INSENSITIVE_ORDER);
result.putAll(columns);
return result;
}
+ private PostgreSQLRowDescriptionPacket describeReturning(final
ReturningSegment returningSegment, final Map<String, ShardingSphereColumn>
columnsOfTable,
+ final Map<String,
ShardingSphereColumn> caseInsensitiveColumnsOfTable) {
+ Collection<PostgreSQLColumnDescription> result = new LinkedList<>();
+ for (ProjectionSegment each :
returningSegment.getProjections().getProjections()) {
+ if (each instanceof ShorthandProjectionSegment) {
+ columnsOfTable.values().stream().map(column -> new
PostgreSQLColumnDescription(column.getName(), 0, column.getDataType(),
estimateColumnLength(column.getDataType()), ""))
+ .forEach(result::add);
+ }
+ if (each instanceof ColumnProjectionSegment) {
+ String columnName = ((ColumnProjectionSegment)
each).getColumn().getIdentifier().getValue();
+ ShardingSphereColumn column =
columnsOfTable.getOrDefault(columnName,
caseInsensitiveColumnsOfTable.get(columnName));
+ String alias = ((ColumnProjectionSegment)
each).getAlias().orElseGet(column::getName);
+ result.add(new PostgreSQLColumnDescription(alias, 0,
column.getDataType(), estimateColumnLength(column.getDataType()), ""));
+ }
+ if (each instanceof ExpressionProjectionSegment) {
+
result.add(convertExpressionToDescription((ExpressionProjectionSegment) each));
+ }
+ }
+ return new PostgreSQLRowDescriptionPacket(result.size(), result);
+ }
+
+ private PostgreSQLColumnDescription convertExpressionToDescription(final
ExpressionProjectionSegment expressionProjectionSegment) {
+ ExpressionSegment expressionSegment =
expressionProjectionSegment.getExpr();
+ String columnName =
expressionProjectionSegment.getAlias().orElse(ANONYMOUS_COLUMN_NAME);
+ if (expressionSegment instanceof LiteralExpressionSegment) {
+ Object value = ((LiteralExpressionSegment)
expressionSegment).getLiterals();
+ if (value instanceof String) {
+ return new PostgreSQLColumnDescription(columnName, 0,
Types.VARCHAR, estimateColumnLength(Types.VARCHAR), "");
+ }
+ if (value instanceof Integer) {
+ return new PostgreSQLColumnDescription(columnName, 0,
Types.INTEGER, estimateColumnLength(Types.INTEGER), "");
+ }
+ if (value instanceof Long) {
+ return new PostgreSQLColumnDescription(columnName, 0,
Types.BIGINT, estimateColumnLength(Types.BIGINT), "");
+ }
+ if (value instanceof Number) {
+ return new PostgreSQLColumnDescription(columnName, 0,
Types.NUMERIC, estimateColumnLength(Types.NUMERIC), "");
+ }
+ }
+ return new PostgreSQLColumnDescription(columnName, 0, Types.VARCHAR,
estimateColumnLength(Types.VARCHAR), "");
+ }
+
+ private int estimateColumnLength(final int jdbcType) {
+ switch (jdbcType) {
+ case Types.SMALLINT:
+ return 2;
+ case Types.INTEGER:
+ return 4;
+ case Types.BIGINT:
+ return 8;
+ default:
+ return -1;
+ }
+ }
+
private void tryDescribePreparedStatementByJDBC(final
PostgreSQLServerPreparedStatement logicPreparedStatement) throws SQLException {
if (!(connectionSession.getBackendConnection() instanceof
JDBCBackendConnection)) {
return;
diff --git
a/proxy/frontend/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java
b/proxy/frontend/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java
index 568bb7e884b..2b0db7e5f9f 100644
---
a/proxy/frontend/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java
+++
b/proxy/frontend/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java
@@ -248,6 +248,78 @@ public final class PostgreSQLComDescribeExecutorTest
extends ProxyContextRestore
executor.execute();
}
+ @SuppressWarnings("rawtypes")
+ @Test
+ public void assertDescribePreparedStatementInsertWithReturningClause()
throws SQLException {
+ when(packet.getType()).thenReturn('S');
+ final String statementId = "S_2";
+ when(packet.getName()).thenReturn(statementId);
+ String sql = "insert into t_order (k, c, pad) values (?, ?, ?) "
+ + "returning id, id alias_id, 'anonymous', 'OK'
literal_string, 1 literal_int, 4294967296 literal_bigint, 1.1 literal_numeric,
t_order.*";
+ SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
+ List<PostgreSQLColumnType> parameterTypes = new
ArrayList<>(sqlStatement.getParameterCount());
+ for (int i = 0; i < sqlStatement.getParameterCount(); i++) {
+
parameterTypes.add(PostgreSQLColumnType.POSTGRESQL_TYPE_UNSPECIFIED);
+ }
+ SQLStatementContext sqlStatementContext =
mock(InsertStatementContext.class);
+ when(sqlStatementContext.getSqlStatement()).thenReturn(sqlStatement);
+
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(statementId,
new PostgreSQLServerPreparedStatement(sql, sqlStatementContext,
parameterTypes));
+ Collection<DatabasePacket<?>> actualPackets = executor.execute();
+ assertThat(actualPackets.size(), is(2));
+ Iterator<DatabasePacket<?>> actualPacketsIterator =
actualPackets.iterator();
+ PostgreSQLParameterDescriptionPacket actualParameterDescription =
(PostgreSQLParameterDescriptionPacket) actualPacketsIterator.next();
+ PostgreSQLPacketPayload mockPayload =
mock(PostgreSQLPacketPayload.class);
+ actualParameterDescription.write(mockPayload);
+ verify(mockPayload).writeInt2(3);
+ verify(mockPayload).writeInt4(23);
+ verify(mockPayload, times(2)).writeInt4(18);
+ DatabasePacket<?> actualRowDescriptionPacket =
actualPacketsIterator.next();
+ assertThat(actualRowDescriptionPacket,
is(instanceOf(PostgreSQLRowDescriptionPacket.class)));
+ List<PostgreSQLColumnDescription> actualColumnDescriptions = new
ArrayList<>(getColumnDescriptionsFromPacket((PostgreSQLRowDescriptionPacket)
actualRowDescriptionPacket));
+ assertThat(actualColumnDescriptions.size(), is(11));
+ assertThat(actualColumnDescriptions.get(0).getColumnName(), is("id"));
+ assertThat(actualColumnDescriptions.get(0).getTypeOID(),
is(PostgreSQLColumnType.POSTGRESQL_TYPE_INT4.getValue()));
+ assertThat(actualColumnDescriptions.get(0).getColumnLength(), is(4));
+ assertThat(actualColumnDescriptions.get(1).getColumnName(),
is("alias_id"));
+ assertThat(actualColumnDescriptions.get(1).getTypeOID(),
is(PostgreSQLColumnType.POSTGRESQL_TYPE_INT4.getValue()));
+ assertThat(actualColumnDescriptions.get(1).getColumnLength(), is(4));
+ assertThat(actualColumnDescriptions.get(2).getColumnName(),
is("?column?"));
+ assertThat(actualColumnDescriptions.get(2).getTypeOID(),
is(PostgreSQLColumnType.POSTGRESQL_TYPE_VARCHAR.getValue()));
+ assertThat(actualColumnDescriptions.get(2).getColumnLength(), is(-1));
+ assertThat(actualColumnDescriptions.get(3).getColumnName(),
is("literal_string"));
+ assertThat(actualColumnDescriptions.get(3).getTypeOID(),
is(PostgreSQLColumnType.POSTGRESQL_TYPE_VARCHAR.getValue()));
+ assertThat(actualColumnDescriptions.get(3).getColumnLength(), is(-1));
+ assertThat(actualColumnDescriptions.get(4).getColumnName(),
is("literal_int"));
+ assertThat(actualColumnDescriptions.get(4).getTypeOID(),
is(PostgreSQLColumnType.POSTGRESQL_TYPE_INT4.getValue()));
+ assertThat(actualColumnDescriptions.get(4).getColumnLength(), is(4));
+ assertThat(actualColumnDescriptions.get(5).getColumnName(),
is("literal_bigint"));
+ assertThat(actualColumnDescriptions.get(5).getTypeOID(),
is(PostgreSQLColumnType.POSTGRESQL_TYPE_INT8.getValue()));
+ assertThat(actualColumnDescriptions.get(5).getColumnLength(), is(8));
+ assertThat(actualColumnDescriptions.get(6).getColumnName(),
is("literal_numeric"));
+ assertThat(actualColumnDescriptions.get(6).getTypeOID(),
is(PostgreSQLColumnType.POSTGRESQL_TYPE_NUMERIC.getValue()));
+ assertThat(actualColumnDescriptions.get(6).getColumnLength(), is(-1));
+ assertThat(actualColumnDescriptions.get(7).getColumnName(), is("id"));
+ assertThat(actualColumnDescriptions.get(7).getTypeOID(),
is(PostgreSQLColumnType.POSTGRESQL_TYPE_INT4.getValue()));
+ assertThat(actualColumnDescriptions.get(7).getColumnLength(), is(4));
+ assertThat(actualColumnDescriptions.get(8).getColumnName(), is("k"));
+ assertThat(actualColumnDescriptions.get(8).getTypeOID(),
is(PostgreSQLColumnType.POSTGRESQL_TYPE_INT4.getValue()));
+ assertThat(actualColumnDescriptions.get(8).getColumnLength(), is(4));
+ assertThat(actualColumnDescriptions.get(9).getColumnName(), is("c"));
+ assertThat(actualColumnDescriptions.get(9).getTypeOID(),
is(PostgreSQLColumnType.POSTGRESQL_TYPE_CHAR.getValue()));
+ assertThat(actualColumnDescriptions.get(9).getColumnLength(), is(-1));
+ assertThat(actualColumnDescriptions.get(10).getColumnName(),
is("pad"));
+ assertThat(actualColumnDescriptions.get(10).getTypeOID(),
is(PostgreSQLColumnType.POSTGRESQL_TYPE_CHAR.getValue()));
+ assertThat(actualColumnDescriptions.get(10).getColumnLength(), is(-1));
+ }
+
+ @SuppressWarnings("unchecked")
+ @SneakyThrows({NoSuchFieldException.class, IllegalAccessException.class})
+ private Collection<PostgreSQLColumnDescription>
getColumnDescriptionsFromPacket(final PostgreSQLRowDescriptionPacket packet) {
+ Field field =
PostgreSQLRowDescriptionPacket.class.getDeclaredField("columnDescriptions");
+ field.setAccessible(true);
+ return (Collection<PostgreSQLColumnDescription>) field.get(packet);
+ }
+
@SuppressWarnings("rawtypes")
@Test
public void assertDescribeSelectPreparedStatement() throws SQLException {