This is an automated email from the ASF dual-hosted git repository.
zhangliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 1511687 Supports describe insert statement without column in PostgreSQL Proxy (#14808)
1511687 is described below
commit 15116870ad904fa6403fb71b5a9187686bf0e824
Author: 吴伟杰 <[email protected]>
AuthorDate: Sun Jan 16 23:13:28 2022 +0800
Supports describe insert statement without column in PostgreSQL Proxy (#14808)
* Supports describe insert statement without columns
* Complete tests for PostgreSQLComDescribeExecutor
---
.../describe/PostgreSQLComDescribeExecutor.java | 16 +-
.../PostgreSQLComDescribeExecutorTest.java | 192 ++++++++++++++++++++-
2 files changed, 196 insertions(+), 12 deletions(-)
diff --git a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutor.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutor.java
index f149e64..447cf6d 100644
--- a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutor.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutor.java
@@ -44,7 +44,6 @@ import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
import
org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
import
org.apache.shardingsphere.proxy.frontend.postgresql.command.PostgreSQLConnectionContext;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
-import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
@@ -64,6 +63,7 @@ import java.util.ListIterator;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
+import java.util.stream.Collectors;
/**
* Command describe for PostgreSQL.
@@ -128,7 +128,12 @@ public final class PostgreSQLComDescribeExecutor implements CommandExecutor {
String logicTableName =
insertStatement.getTable().getTableName().getIdentifier().getValue();
TableMetaData tableMetaData =
ProxyContext.getInstance().getMetaData(schemaName).getSchema().get(logicTableName);
Map<String, ColumnMetaData> columnMetaData =
tableMetaData.getColumns();
- List<ColumnSegment> columns = new
ArrayList<>(insertStatement.getColumns());
+ List<String> columnNames;
+ if (insertStatement.getColumns().isEmpty()) {
+ columnNames = new ArrayList<>(tableMetaData.getColumns().keySet());
+ } else {
+ columnNames = insertStatement.getColumns().stream().map(each ->
each.getIdentifier().getValue()).collect(Collectors.toList());
+ }
Iterator<InsertValuesSegment> iterator =
insertStatement.getValues().iterator();
int parameterMarkerIndex = 0;
while (iterator.hasNext()) {
@@ -136,11 +141,14 @@ public final class PostgreSQLComDescribeExecutor implements CommandExecutor {
ListIterator<ExpressionSegment> listIterator =
each.getValues().listIterator();
for (int columnIndex = listIterator.nextIndex();
listIterator.hasNext(); columnIndex = listIterator.nextIndex()) {
ExpressionSegment value = listIterator.next();
- if (!(value instanceof ParameterMarkerExpressionSegment) ||
!unspecifiedTypeParameterIndexes.contains(parameterMarkerIndex)) {
+ if (!(value instanceof ParameterMarkerExpressionSegment)) {
+ continue;
+ }
+ if
(!unspecifiedTypeParameterIndexes.contains(parameterMarkerIndex)) {
parameterMarkerIndex++;
continue;
}
- String columnName =
columns.get(columnIndex).getIdentifier().getValue();
+ String columnName = columnNames.get(columnIndex);
PostgreSQLColumnType parameterType =
PostgreSQLColumnType.valueOfJDBCType(columnMetaData.get(columnName).getDataType());
preparedStatement.getParameterTypes().set(parameterMarkerIndex++,
parameterType);
}
diff --git a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java
index cadafc8..8f7cdbe 100644
--- a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java
@@ -17,36 +17,79 @@
package
org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.describe;
+import lombok.SneakyThrows;
import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
+import
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.PostgreSQLColumnDescription;
+import
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.PostgreSQLNoDataPacket;
import
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.PostgreSQLParameterDescriptionPacket;
import
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.PostgreSQLRowDescriptionPacket;
import
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.PostgreSQLColumnType;
import
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.PostgreSQLPreparedStatementRegistry;
import
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.describe.PostgreSQLComDescribePacket;
import
org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
+import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
+import
org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
+import org.apache.shardingsphere.infra.metadata.schema.model.ColumnMetaData;
+import org.apache.shardingsphere.infra.metadata.schema.model.TableMetaData;
+import org.apache.shardingsphere.infra.parser.ShardingSphereSQLParserEngine;
+import org.apache.shardingsphere.mode.manager.ContextManager;
+import org.apache.shardingsphere.parser.rule.SQLParserRule;
+import
org.apache.shardingsphere.parser.rule.builder.DefaultSQLParserRuleConfigurationBuilder;
+import
org.apache.shardingsphere.proxy.backend.communication.jdbc.connection.JDBCBackendConnection;
+import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
import
org.apache.shardingsphere.proxy.frontend.postgresql.command.PostgreSQLConnectionContext;
import
org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.PostgreSQLPortal;
-import
org.apache.shardingsphere.sql.parser.sql.dialect.statement.postgresql.dml.PostgreSQLSelectStatement;
+import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
+import org.junit.After;
+import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
+import org.mockito.Answers;
+import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
+import java.lang.reflect.Field;
+import java.sql.Connection;
+import java.sql.ParameterMetaData;
+import java.sql.ResultSetMetaData;
import java.sql.SQLException;
+import java.sql.Types;
+import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.nullable;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public final class PostgreSQLComDescribeExecutorTest {
+ private static final String SCHEMA_NAME = "postgres";
+
+ private static final String TABLE_NAME = "t_order";
+
+ private static final ShardingSphereSQLParserEngine SQL_PARSER_ENGINE = new
ShardingSphereSQLParserEngine("PostgreSQL", new SQLParserRule(new
DefaultSQLParserRuleConfigurationBuilder().build()));
+
+ private ContextManager contextManagerBefore;
+
+ @Mock(answer = Answers.RETURNS_DEEP_STUBS)
+ private ContextManager mockContextManager;
+
@Mock
private PostgreSQLConnectionContext connectionContext;
@@ -56,6 +99,28 @@ public final class PostgreSQLComDescribeExecutorTest {
@Mock
private ConnectionSession connectionSession;
+ @InjectMocks
+ private PostgreSQLComDescribeExecutor executor;
+
+ @Before
+ public void setup() {
+ contextManagerBefore = ProxyContext.getInstance().getContextManager();
+ ProxyContext.getInstance().init(mockContextManager);
+
when(ProxyContext.getInstance().getContextManager().getMetaDataContexts().getProps().getValue(ConfigurationPropertyKey.SQL_SHOW)).thenReturn(false);
+ when(connectionSession.getSchemaName()).thenReturn(SCHEMA_NAME);
+
when(mockContextManager.getMetaDataContexts().getAllSchemaNames().contains(SCHEMA_NAME)).thenReturn(true);
+ prepareTableMetaData();
+ }
+
+ private void prepareTableMetaData() {
+ Collection<ColumnMetaData> columnMetaData = Arrays.asList(
+ new ColumnMetaData("id", Types.INTEGER, true, false, false),
+ new ColumnMetaData("k", Types.INTEGER, true, false, false),
+ new ColumnMetaData("c", Types.CHAR, true, false, false),
+ new ColumnMetaData("pad", Types.CHAR, true, false, false));
+
when(mockContextManager.getMetaDataContexts().getMetaData(SCHEMA_NAME).getSchema().get(TABLE_NAME)).thenReturn(new
TableMetaData(TABLE_NAME, columnMetaData, Collections.emptyList()));
+ }
+
@Test
public void assertDescribePortal() throws SQLException {
when(packet.getType()).thenReturn('P');
@@ -64,30 +129,141 @@ public final class PostgreSQLComDescribeExecutorTest {
PostgreSQLRowDescriptionPacket expected =
mock(PostgreSQLRowDescriptionPacket.class);
when(portal.describe()).thenReturn(expected);
when(connectionContext.getPortal("P_1")).thenReturn(portal);
- Collection<DatabasePacket<?>> actual = new
PostgreSQLComDescribeExecutor(connectionContext, packet,
connectionSession).execute();
+ Collection<DatabasePacket<?>> actual = executor.execute();
assertThat(actual.size(), is(1));
assertThat(actual.iterator().next(), is(expected));
}
@Test
- public void assertDescribePreparedStatement() throws SQLException {
+ public void assertDescribePreparedStatementInsertWithoutColumns() throws
SQLException {
+ when(packet.getType()).thenReturn('S');
+ final String statementId = "S_1";
+ when(packet.getName()).thenReturn(statementId);
+ final int connectionId = 1;
+ when(connectionSession.getConnectionId()).thenReturn(connectionId);
+ String sql = "insert into t_order values (?, 0, 'char', ?), (2, ?, ?,
'')";
+ SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
+
PostgreSQLPreparedStatementRegistry.getInstance().register(connectionId);
+ List<PostgreSQLColumnType> parameterTypes = new
ArrayList<>(sqlStatement.getParameterCount());
+ for (int i = 0; i < sqlStatement.getParameterCount(); i++) {
+
parameterTypes.add(PostgreSQLColumnType.POSTGRESQL_TYPE_UNSPECIFIED);
+ }
+
PostgreSQLPreparedStatementRegistry.getInstance().register(connectionId,
statementId, sql, sqlStatement, parameterTypes);
+ Collection<DatabasePacket<?>> actualPackets = executor.execute();
+ assertThat(actualPackets.size(), is(2));
+ Iterator<DatabasePacket<?>> actualPacketsIterator =
actualPackets.iterator();
+ PostgreSQLParameterDescriptionPacket actualParameterDescription =
(PostgreSQLParameterDescriptionPacket) actualPacketsIterator.next();
+ PostgreSQLPacketPayload mockPayload =
mock(PostgreSQLPacketPayload.class);
+ actualParameterDescription.write(mockPayload);
+ verify(mockPayload).writeInt2(4);
+ verify(mockPayload, times(2)).writeInt4(23);
+ verify(mockPayload, times(2)).writeInt4(18);
+ assertTrue(actualPacketsIterator.next() instanceof
PostgreSQLNoDataPacket);
+ }
+
+ @Test
+ public void assertDescribePreparedStatementInsertWithColumns() throws
SQLException {
when(packet.getType()).thenReturn('S');
- when(packet.getName()).thenReturn("S_1");
+ final String statementId = "S_2";
+ when(packet.getName()).thenReturn(statementId);
+ final int connectionId = 1;
+ when(connectionSession.getConnectionId()).thenReturn(connectionId);
+ String sql = "insert into t_order (id, k, c, pad) values (1, ?, ?, ?),
(?, 2, ?, '')";
+ SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
+
PostgreSQLPreparedStatementRegistry.getInstance().register(connectionId);
+ List<PostgreSQLColumnType> parameterTypes = new
ArrayList<>(sqlStatement.getParameterCount());
+ for (int i = 0; i < sqlStatement.getParameterCount(); i++) {
+
parameterTypes.add(PostgreSQLColumnType.POSTGRESQL_TYPE_UNSPECIFIED);
+ }
+
PostgreSQLPreparedStatementRegistry.getInstance().register(connectionId,
statementId, sql, sqlStatement, parameterTypes);
+ Collection<DatabasePacket<?>> actualPackets = executor.execute();
+ assertThat(actualPackets.size(), is(2));
+ Iterator<DatabasePacket<?>> actualPacketsIterator =
actualPackets.iterator();
+ PostgreSQLParameterDescriptionPacket actualParameterDescription =
(PostgreSQLParameterDescriptionPacket) actualPacketsIterator.next();
+ PostgreSQLPacketPayload mockPayload =
mock(PostgreSQLPacketPayload.class);
+ actualParameterDescription.write(mockPayload);
+ verify(mockPayload).writeInt2(5);
+ verify(mockPayload, times(2)).writeInt4(23);
+ verify(mockPayload, times(3)).writeInt4(18);
+ assertTrue(actualPacketsIterator.next() instanceof
PostgreSQLNoDataPacket);
+ }
+
+ @Test
+ public void assertDescribeSelectPreparedStatement() throws SQLException {
+ when(packet.getType()).thenReturn('S');
+ String statementId = "S_3";
+ when(packet.getName()).thenReturn(statementId);
when(connectionSession.getConnectionId()).thenReturn(1);
+ final String sql = "select id, k, c, pad from t_order where id = ?";
+ SQLStatement sqlStatement = SQL_PARSER_ENGINE.parse(sql, false);
+ prepareJDBCBackendConnection(sql);
PostgreSQLPreparedStatementRegistry.getInstance().register(1);
- PostgreSQLPreparedStatementRegistry.getInstance().register(1, "S_1",
"", new PostgreSQLSelectStatement(),
Collections.singletonList(PostgreSQLColumnType.POSTGRESQL_TYPE_INT4));
- Collection<DatabasePacket<?>> actual = new
PostgreSQLComDescribeExecutor(connectionContext, packet,
connectionSession).execute();
- assertThat(actual.size(), is(1));
- PostgreSQLParameterDescriptionPacket actualParameterDescription =
(PostgreSQLParameterDescriptionPacket) actual.iterator().next();
+ List<PostgreSQLColumnType> parameterTypes = new
ArrayList<>(Collections.singletonList(PostgreSQLColumnType.POSTGRESQL_TYPE_UNSPECIFIED));
+ PostgreSQLPreparedStatementRegistry.getInstance().register(1,
statementId, sql, sqlStatement, parameterTypes);
+ Collection<DatabasePacket<?>> actual = executor.execute();
+ assertThat(actual.size(), is(2));
+ Iterator<DatabasePacket<?>> actualPacketsIterator = actual.iterator();
+ PostgreSQLParameterDescriptionPacket actualParameterDescription =
(PostgreSQLParameterDescriptionPacket) actualPacketsIterator.next();
assertThat(actualParameterDescription,
instanceOf(PostgreSQLParameterDescriptionPacket.class));
PostgreSQLPacketPayload mockPayload =
mock(PostgreSQLPacketPayload.class);
actualParameterDescription.write(mockPayload);
verify(mockPayload).writeInt2(1);
verify(mockPayload).writeInt4(PostgreSQLColumnType.POSTGRESQL_TYPE_INT4.getValue());
+ PostgreSQLRowDescriptionPacket actualRowDescription =
(PostgreSQLRowDescriptionPacket) actualPacketsIterator.next();
+ List<PostgreSQLColumnDescription> actualColumnDescriptions =
getColumnDescriptions(actualRowDescription);
+ List<PostgreSQLColumnDescription> expectedColumnDescriptions =
Arrays.asList(
+ new PostgreSQLColumnDescription("id", 1, Types.INTEGER, 11,
"int4"),
+ new PostgreSQLColumnDescription("k", 2, Types.INTEGER, 11,
"int4"),
+ new PostgreSQLColumnDescription("c", 3, Types.CHAR, 60,
"int4"),
+ new PostgreSQLColumnDescription("pad", 4, Types.CHAR, 120,
"int4")
+ );
+ for (int i = 0; i < expectedColumnDescriptions.size(); i++) {
+ PostgreSQLColumnDescription expectedColumnDescription =
expectedColumnDescriptions.get(i);
+ PostgreSQLColumnDescription actualColumnDescription =
actualColumnDescriptions.get(i);
+ assertThat(actualColumnDescription.getColumnName(),
is(expectedColumnDescription.getColumnName()));
+ assertThat(actualColumnDescription.getColumnIndex(),
is(expectedColumnDescription.getColumnIndex()));
+ assertThat(actualColumnDescription.getColumnLength(),
is(expectedColumnDescription.getColumnLength()));
+ assertThat(actualColumnDescription.getTypeOID(),
is(expectedColumnDescription.getTypeOID()));
+ }
+ }
+
+ private void prepareJDBCBackendConnection(final String sql) throws
SQLException {
+ JDBCBackendConnection backendConnection =
mock(JDBCBackendConnection.class);
+ Connection connection = mock(Connection.class, RETURNS_DEEP_STUBS);
+ ParameterMetaData parameterMetaData = mock(ParameterMetaData.class);
+ when(parameterMetaData.getParameterType(1)).thenReturn(Types.INTEGER);
+
when(connection.prepareStatement(sql).getParameterMetaData()).thenReturn(parameterMetaData);
+ ResultSetMetaData resultSetMetaData = prepareResultSetMetaData();
+
when(connection.prepareStatement(sql).getMetaData()).thenReturn(resultSetMetaData);
+ when(backendConnection.getConnections(nullable(String.class),
anyInt(),
any(ConnectionMode.class))).thenReturn(Collections.singletonList(connection));
+
when(connectionSession.getBackendConnection()).thenReturn(backendConnection);
+ }
+
+ private ResultSetMetaData prepareResultSetMetaData() throws SQLException {
+ ResultSetMetaData result = mock(ResultSetMetaData.class);
+ when(result.getColumnCount()).thenReturn(4);
+ when(result.getColumnName(anyInt())).thenReturn("id", "k", "c", "pad");
+ when(result.getColumnType(anyInt())).thenReturn(Types.INTEGER,
Types.INTEGER, Types.CHAR, Types.CHAR);
+ when(result.getColumnDisplaySize(anyInt())).thenReturn(11, 11, 60,
120);
+ when(result.getColumnTypeName(anyInt())).thenReturn("int4", "int4",
"char", "char");
+ return result;
+ }
+
+ @SuppressWarnings("unchecked")
+ @SneakyThrows
+ private List<PostgreSQLColumnDescription> getColumnDescriptions(final
PostgreSQLRowDescriptionPacket packet) {
+ Field columnDescriptionsField =
PostgreSQLRowDescriptionPacket.class.getDeclaredField("columnDescriptions");
+ columnDescriptionsField.setAccessible(true);
+ return (List<PostgreSQLColumnDescription>)
columnDescriptionsField.get(packet);
}
@Test(expected = UnsupportedOperationException.class)
public void assertDescribeUnknownType() throws SQLException {
new PostgreSQLComDescribeExecutor(connectionContext, packet,
connectionSession).execute();
}
+
+ @After
+ public void tearDown() {
+ ProxyContext.getInstance().init(contextManagerBefore);
+ }
}