This is an automated email from the ASF dual-hosted git repository.
panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 8ff0322 Fix literals may be replaced by mistake in PostgreSQL/openGauss protocol (#16370)
8ff0322 is described below
commit 8ff0322a2f4b31affc26cd2aec0e109a9a4398c6
Author: 吴伟杰 <[email protected]>
AuthorDate: Fri Mar 25 18:38:57 2022 +0800
Fix literals may be replaced by mistake in PostgreSQL/openGauss protocol
(#16370)
* Make EmptyStatement extend AbstractSQLStatement
* Refactor SQL rewriting in PostgreSQL protocol
* Avoid unused overhead of parsing empty statement
* Complete PostgreSQLComParseExecutorTest
---
.../extended/parse/PostgreSQLComParsePacket.java | 6 +-
.../extended/parse/PostgreSQLComParseExecutor.java | 38 ++++++---
.../parse/PostgreSQLComParseExecutorTest.java | 89 ++++++++++++++--------
.../sql/common/statement/dml/EmptyStatement.java | 4 +-
.../common/statement/dml/EmptyStatementTest.java} | 16 ++--
5 files changed, 98 insertions(+), 55 deletions(-)
diff --git a/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/command/query/extended/parse/PostgreSQLComParsePacket.java b/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/command/query/extended/parse/PostgreSQLComParsePacket.java
index 07f4989..b3ebd6c 100644
--- a/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/command/query/extended/parse/PostgreSQLComParsePacket.java
+++ b/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/command/query/extended/parse/PostgreSQLComParsePacket.java
@@ -45,11 +45,7 @@ public final class PostgreSQLComParsePacket extends PostgreSQLCommandPacket {
this.payload = payload;
payload.readInt4();
statementId = payload.readStringNul();
- sql = alterSQLToJDBCStyle(payload.readStringNul());
- }
-
- private String alterSQLToJDBCStyle(final String sql) {
- return sql.replaceAll("\\$[0-9]+", "?");
+ sql = payload.readStringNul();
}
/**
diff --git a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutor.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutor.java
index fcc2b82..2386db8 100644
--- a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutor.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutor.java
@@ -30,12 +30,17 @@ import org.apache.shardingsphere.parser.rule.SQLParserRule;
import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
import
org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
+import
org.apache.shardingsphere.sql.parser.sql.common.constant.ParameterMarkerType;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.SQLSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.ParameterMarkerSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.statement.AbstractSQLStatement;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.EmptyStatement;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
+import java.util.Comparator;
import java.util.List;
/**
@@ -50,21 +55,33 @@ public final class PostgreSQLComParseExecutor implements
CommandExecutor {
@Override
public Collection<DatabasePacket<?>> execute() {
- SQLStatement sqlStatement = parseSql(packet.getSql(),
connectionSession.getSchemaName());
+ ShardingSphereSQLParserEngine sqlParserEngine = null;
+ String sql = packet.getSql();
+ SQLStatement sqlStatement = sql.trim().isEmpty() ? new EmptyStatement() : (sqlParserEngine = createShardingSphereSQLParserEngine(connectionSession.getSchemaName())).parse(sql, true);
+ if (sqlStatement.getParameterCount() > 0) {
+ sql = convertSQLToJDBCStyle(sqlStatement, sql);
+ sqlStatement = sqlParserEngine.parse(sql, true);
+ }
List<PostgreSQLColumnType> paddedColumnTypes =
paddingColumnTypes(sqlStatement.getParameterCount(),
packet.readParameterTypes());
-
PostgreSQLPreparedStatementRegistry.getInstance().register(connectionSession.getConnectionId(),
packet.getStatementId(), packet.getSql(), sqlStatement, paddedColumnTypes);
+
PostgreSQLPreparedStatementRegistry.getInstance().register(connectionSession.getConnectionId(),
packet.getStatementId(), sql, sqlStatement, paddedColumnTypes);
return
Collections.singletonList(PostgreSQLParseCompletePacket.getInstance());
}
- private SQLStatement parseSql(final String sql, final String schemaName) {
- if (sql.isEmpty()) {
- return new EmptyStatement();
- }
+ private ShardingSphereSQLParserEngine
createShardingSphereSQLParserEngine(final String schemaName) {
MetaDataContexts metaDataContexts =
ProxyContext.getInstance().getContextManager().getMetaDataContexts();
- ShardingSphereSQLParserEngine sqlStatementParserEngine = new
ShardingSphereSQLParserEngine(
-
DatabaseTypeRegistry.getTrunkDatabaseTypeName(metaDataContexts.getMetaData(schemaName).getResource().getDatabaseType()),
+ return new
ShardingSphereSQLParserEngine(DatabaseTypeRegistry.getTrunkDatabaseTypeName(metaDataContexts.getMetaData(schemaName).getResource().getDatabaseType()),
metaDataContexts.getGlobalRuleMetaData().findSingleRule(SQLParserRule.class).orElse(null));
- return sqlStatementParserEngine.parse(sql, true);
+ }
+
+ private String convertSQLToJDBCStyle(final SQLStatement sqlStatement,
final String sql) {
+ List<ParameterMarkerSegment> parameterMarkerSegments = new
ArrayList<>(((AbstractSQLStatement) sqlStatement).getParameterMarkerSegments());
+
parameterMarkerSegments.sort(Comparator.comparingInt(SQLSegment::getStopIndex));
+ StringBuilder result = new StringBuilder(sql);
+ for (int i = parameterMarkerSegments.size() - 1; i >= 0; i--) {
+ ParameterMarkerSegment each = parameterMarkerSegments.get(i);
+ result.replace(each.getStartIndex(), each.getStopIndex() + 1,
ParameterMarkerType.QUESTION.getMarker());
+ }
+ return result.toString();
}
private List<PostgreSQLColumnType> paddingColumnTypes(final int
parameterCount, final List<PostgreSQLColumnType> specifiedColumnTypes) {
@@ -73,7 +90,8 @@ public final class PostgreSQLComParseExecutor implements
CommandExecutor {
}
List<PostgreSQLColumnType> result = new ArrayList<>(parameterCount);
result.addAll(specifiedColumnTypes);
- for (int i = 0; i < parameterCount; i++) {
+ int unspecifiedCount = parameterCount - specifiedColumnTypes.size();
+ for (int i = 0; i < unspecifiedCount; i++) {
result.add(PostgreSQLColumnType.POSTGRESQL_TYPE_UNSPECIFIED);
}
return result;
diff --git a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java
index fa22e6c..7a95baf 100644
--- a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java
@@ -17,37 +17,40 @@
package
org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.parse;
+import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
+import
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.PostgreSQLColumnType;
+import
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.PostgreSQLPreparedStatement;
import
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.PostgreSQLPreparedStatementRegistry;
import
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.parse.PostgreSQLComParsePacket;
import
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.parse.PostgreSQLParseCompletePacket;
-import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.database.type.dialect.MySQLDatabaseType;
-import org.apache.shardingsphere.infra.executor.kernel.ExecutorEngine;
-import
org.apache.shardingsphere.infra.federation.optimizer.context.OptimizerContext;
+import
org.apache.shardingsphere.infra.database.type.dialect.PostgreSQLDatabaseType;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
-import
org.apache.shardingsphere.infra.metadata.rule.ShardingSphereRuleMetaData;
import org.apache.shardingsphere.mode.manager.ContextManager;
-import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
-import org.apache.shardingsphere.mode.metadata.persist.MetaDataPersistService;
import org.apache.shardingsphere.parser.rule.SQLParserRule;
import
org.apache.shardingsphere.parser.rule.builder.DefaultSQLParserRuleConfigurationBuilder;
import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
import
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.EmptyStatement;
+import
org.apache.shardingsphere.sql.parser.sql.dialect.statement.postgresql.dml.PostgreSQLInsertStatement;
+import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
+import org.mockito.Answers;
+import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
-import java.lang.reflect.Field;
+import java.util.Arrays;
+import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
-import java.util.Properties;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@@ -55,6 +58,8 @@ import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public final class PostgreSQLComParseExecutorTest {
+ private static final int CONNECTION_ID = 1;
+
private final SQLParserRule sqlParserRule = new SQLParserRule(new
DefaultSQLParserRuleConfigurationBuilder().build());
@Mock
@@ -63,28 +68,55 @@ public final class PostgreSQLComParseExecutorTest {
@Mock
private ConnectionSession connectionSession;
+ @InjectMocks
+ private PostgreSQLComParseExecutor executor;
+
+ private ContextManager contextManagerBefore;
+
+ @Mock(answer = Answers.RETURNS_DEEP_STUBS)
+ private ContextManager mockedContextManager;
+
@Before
public void setup() {
- PostgreSQLPreparedStatementRegistry.getInstance().register(1);
- PostgreSQLPreparedStatementRegistry.getInstance().register(1, "2", "",
new EmptyStatement(), Collections.emptyList());
- when(connectionSession.getConnectionId()).thenReturn(1);
+ contextManagerBefore = ProxyContext.getInstance().getContextManager();
+ ProxyContext.getInstance().init(mockedContextManager);
+
PostgreSQLPreparedStatementRegistry.getInstance().register(CONNECTION_ID);
+ when(connectionSession.getConnectionId()).thenReturn(CONNECTION_ID);
}
@Test
- public void assertNewInstance() throws NoSuchFieldException,
IllegalAccessException {
- when(parsePacket.getSql()).thenReturn("SELECT 1");
- when(parsePacket.getStatementId()).thenReturn("2");
+ public void assertExecuteWithEmptySQL() {
+ final String expectedSQL = "";
+ final String statementId = "S_1";
+ when(parsePacket.getSql()).thenReturn(expectedSQL);
+ when(parsePacket.getStatementId()).thenReturn(statementId);
+ Collection<DatabasePacket<?>> actualPackets = executor.execute();
+ assertThat(actualPackets.size(), is(1));
+ assertThat(actualPackets.iterator().next(),
is(PostgreSQLParseCompletePacket.getInstance()));
+ PostgreSQLPreparedStatement actualPreparedStatement =
PostgreSQLPreparedStatementRegistry.getInstance().get(CONNECTION_ID,
statementId);
+ assertTrue(actualPreparedStatement.getSqlStatement() instanceof
EmptyStatement);
+ assertThat(actualPreparedStatement.getSql(), is(expectedSQL));
+ assertThat(actualPreparedStatement.getParameterTypes(),
is(Collections.emptyList()));
+ }
+
+ @Test
+ public void assertExecuteWithParameterizedSQL() {
+ final String rawSQL = "/*$0*/insert into sbtest1 /* $1 */ -- $2 \n (id, k, c, pad) \r values \r\n($1, $2, 'apsbd$31a', '$99')/*$0*/ \n--$0";
+ final String expectedSQL = "/*$0*/insert into sbtest1 /* $1 */ -- $2 \n (id, k, c, pad) \r values \r\n(?, ?, 'apsbd$31a', '$99')/*$0*/ \n--$0";
+ final String statementId = "S_2";
+ when(parsePacket.getSql()).thenReturn(rawSQL);
+ when(parsePacket.getStatementId()).thenReturn(statementId);
+
when(parsePacket.readParameterTypes()).thenReturn(Collections.singletonList(PostgreSQLColumnType.POSTGRESQL_TYPE_INT4));
when(connectionSession.getSchemaName()).thenReturn("schema");
- Field contextManagerField =
ProxyContext.getInstance().getClass().getDeclaredField("contextManager");
- contextManagerField.setAccessible(true);
- ContextManager contextManager = mock(ContextManager.class,
RETURNS_DEEP_STUBS);
- MetaDataContexts metaDataContexts = new
MetaDataContexts(mock(MetaDataPersistService.class), getMetaDataMap(),
- mock(ShardingSphereRuleMetaData.class),
mock(ExecutorEngine.class), mock(OptimizerContext.class), new
ConfigurationProperties(new Properties()));
-
when(contextManager.getMetaDataContexts()).thenReturn(metaDataContexts);
- contextManagerField.set(ProxyContext.getInstance(), contextManager);
-
when(contextManager.getMetaDataContexts().getGlobalRuleMetaData().findSingleRule(SQLParserRule.class)).thenReturn(Optional.of(sqlParserRule));
- PostgreSQLComParseExecutor actual = new
PostgreSQLComParseExecutor(parsePacket, connectionSession);
- assertThat(actual.execute().iterator().next(),
is(PostgreSQLParseCompletePacket.getInstance()));
+
when(mockedContextManager.getMetaDataContexts().getMetaData("schema").getResource().getDatabaseType()).thenReturn(new
PostgreSQLDatabaseType());
+
when(mockedContextManager.getMetaDataContexts().getGlobalRuleMetaData().findSingleRule(SQLParserRule.class)).thenReturn(Optional.of(sqlParserRule));
+ Collection<DatabasePacket<?>> actualPackets = executor.execute();
+ assertThat(actualPackets.size(), is(1));
+ assertThat(actualPackets.iterator().next(),
is(PostgreSQLParseCompletePacket.getInstance()));
+ PostgreSQLPreparedStatement actualPreparedStatement =
PostgreSQLPreparedStatementRegistry.getInstance().get(CONNECTION_ID,
statementId);
+ assertTrue(actualPreparedStatement.getSqlStatement() instanceof
PostgreSQLInsertStatement);
+ assertThat(actualPreparedStatement.getSql(), is(expectedSQL));
+ assertThat(actualPreparedStatement.getParameterTypes(),
is(Arrays.asList(PostgreSQLColumnType.POSTGRESQL_TYPE_INT4,
PostgreSQLColumnType.POSTGRESQL_TYPE_UNSPECIFIED)));
}
private Map<String, ShardingSphereMetaData> getMetaDataMap() {
@@ -93,11 +125,8 @@ public final class PostgreSQLComParseExecutorTest {
return Collections.singletonMap("schema", metaData);
}
- @Test
- public void assertGetSqlWithNull() {
- when(parsePacket.getStatementId()).thenReturn("");
- when(parsePacket.getSql()).thenReturn("");
- PostgreSQLComParseExecutor actual = new
PostgreSQLComParseExecutor(parsePacket, connectionSession);
- assertThat(actual.execute().iterator().next(),
is(PostgreSQLParseCompletePacket.getInstance()));
+ @After
+ public void tearDown() {
+ ProxyContext.getInstance().init(contextManagerBefore);
}
}
diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/EmptyStatement.java b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/EmptyStatement.java
index c6799d9..002eb1c 100644
--- a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/EmptyStatement.java
+++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/EmptyStatement.java
@@ -17,12 +17,12 @@
package org.apache.shardingsphere.sql.parser.sql.common.statement.dml;
-import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
+import
org.apache.shardingsphere.sql.parser.sql.common.statement.AbstractSQLStatement;
/**
* Empty statement.
*/
-public final class EmptyStatement implements SQLStatement {
+public final class EmptyStatement extends AbstractSQLStatement {
@Override
public int getParameterCount() {
diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/EmptyStatement.java b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/EmptyStatementTest.java
similarity index 76%
copy from shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/EmptyStatement.java
copy to shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/EmptyStatementTest.java
index c6799d9..692bf05 100644
--- a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/EmptyStatement.java
+++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/EmptyStatementTest.java
@@ -17,15 +17,15 @@
package org.apache.shardingsphere.sql.parser.sql.common.statement.dml;
-import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
+import org.junit.Test;
-/**
- * Empty statement.
- */
-public final class EmptyStatement implements SQLStatement {
+import static org.hamcrest.CoreMatchers.is;
+import static org.junit.Assert.assertThat;
+
+public final class EmptyStatementTest {
- @Override
- public int getParameterCount() {
- return 0;
+ @Test
+ public void assertGetParameterCount() {
+ assertThat(new EmptyStatement().getParameterCount(), is(0));
}
}