This is an automated email from the ASF dual-hosted git repository.

panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 8ff0322  Fix literals may be replaced by mistake in PostgreSQL/openGauss protocol (#16370)
8ff0322 is described below

commit 8ff0322a2f4b31affc26cd2aec0e109a9a4398c6
Author: 吴伟杰 <[email protected]>
AuthorDate: Fri Mar 25 18:38:57 2022 +0800

    Fix literals may be replaced by mistake in PostgreSQL/openGauss protocol (#16370)
    
    * Make EmptyStatement extend AbstractSQLStatement
    
    * Refactor SQL rewriting in PostgreSQL protocol
    
    * Avoid unused overhead of parsing empty statement
    
    * Complete PostgreSQLComParseExecutorTest
---
 .../extended/parse/PostgreSQLComParsePacket.java   |  6 +-
 .../extended/parse/PostgreSQLComParseExecutor.java | 38 ++++++---
 .../parse/PostgreSQLComParseExecutorTest.java      | 89 ++++++++++++++--------
 .../sql/common/statement/dml/EmptyStatement.java   |  4 +-
 .../common/statement/dml/EmptyStatementTest.java}  | 16 ++--
 5 files changed, 98 insertions(+), 55 deletions(-)

diff --git 
a/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/command/query/extended/parse/PostgreSQLComParsePacket.java
 
b/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/command/query/extended/parse/PostgreSQLComParsePacket.java
index 07f4989..b3ebd6c 100644
--- 
a/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/command/query/extended/parse/PostgreSQLComParsePacket.java
+++ 
b/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/command/query/extended/parse/PostgreSQLComParsePacket.java
@@ -45,11 +45,7 @@ public final class PostgreSQLComParsePacket extends 
PostgreSQLCommandPacket {
         this.payload = payload;
         payload.readInt4();
         statementId = payload.readStringNul();
-        sql = alterSQLToJDBCStyle(payload.readStringNul());
-    }
-    
-    private String alterSQLToJDBCStyle(final String sql) {
-        return sql.replaceAll("\\$[0-9]+", "?");
+        sql = payload.readStringNul();
     }
     
     /**
diff --git 
a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutor.java
 
b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutor.java
index fcc2b82..2386db8 100644
--- 
a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutor.java
+++ 
b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutor.java
@@ -30,12 +30,17 @@ import org.apache.shardingsphere.parser.rule.SQLParserRule;
 import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
 import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
 import 
org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
+import 
org.apache.shardingsphere.sql.parser.sql.common.constant.ParameterMarkerType;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.SQLSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.ParameterMarkerSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.statement.AbstractSQLStatement;
 import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.EmptyStatement;
 
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.List;
 
 /**
@@ -50,21 +55,33 @@ public final class PostgreSQLComParseExecutor implements 
CommandExecutor {
     
     @Override
     public Collection<DatabasePacket<?>> execute() {
-        SQLStatement sqlStatement = parseSql(packet.getSql(), 
connectionSession.getSchemaName());
+        ShardingSphereSQLParserEngine sqlParserEngine = null;
+        String sql = packet.getSql();
+        SQLStatement sqlStatement = sql.trim().isEmpty() ? new 
EmptyStatement() : (sqlParserEngine = 
createShardingSphereSQLParserEngine(connectionSession.getSchemaName())).parse(sql,
 true);
+        if (sqlStatement.getParameterCount() > 0) {
+            sql = convertSQLToJDBCStyle(sqlStatement, sql);
+            sqlStatement = sqlParserEngine.parse(sql, true);
+        }
         List<PostgreSQLColumnType> paddedColumnTypes = 
paddingColumnTypes(sqlStatement.getParameterCount(), 
packet.readParameterTypes());
-        
PostgreSQLPreparedStatementRegistry.getInstance().register(connectionSession.getConnectionId(),
 packet.getStatementId(), packet.getSql(), sqlStatement, paddedColumnTypes);
+        
PostgreSQLPreparedStatementRegistry.getInstance().register(connectionSession.getConnectionId(),
 packet.getStatementId(), sql, sqlStatement, paddedColumnTypes);
         return 
Collections.singletonList(PostgreSQLParseCompletePacket.getInstance());
     }
     
-    private SQLStatement parseSql(final String sql, final String schemaName) {
-        if (sql.isEmpty()) {
-            return new EmptyStatement();
-        }
+    private ShardingSphereSQLParserEngine 
createShardingSphereSQLParserEngine(final String schemaName) {
         MetaDataContexts metaDataContexts = 
ProxyContext.getInstance().getContextManager().getMetaDataContexts();
-        ShardingSphereSQLParserEngine sqlStatementParserEngine = new 
ShardingSphereSQLParserEngine(
-                
DatabaseTypeRegistry.getTrunkDatabaseTypeName(metaDataContexts.getMetaData(schemaName).getResource().getDatabaseType()),
+        return new 
ShardingSphereSQLParserEngine(DatabaseTypeRegistry.getTrunkDatabaseTypeName(metaDataContexts.getMetaData(schemaName).getResource().getDatabaseType()),
                 
metaDataContexts.getGlobalRuleMetaData().findSingleRule(SQLParserRule.class).orElse(null));
-        return sqlStatementParserEngine.parse(sql, true);
+    }
+    
+    private String convertSQLToJDBCStyle(final SQLStatement sqlStatement, 
final String sql) {
+        List<ParameterMarkerSegment> parameterMarkerSegments = new 
ArrayList<>(((AbstractSQLStatement) sqlStatement).getParameterMarkerSegments());
+        
parameterMarkerSegments.sort(Comparator.comparingInt(SQLSegment::getStopIndex));
+        StringBuilder result = new StringBuilder(sql);
+        for (int i = parameterMarkerSegments.size() - 1; i >= 0; i--) {
+            ParameterMarkerSegment each = parameterMarkerSegments.get(i);
+            result.replace(each.getStartIndex(), each.getStopIndex() + 1, 
ParameterMarkerType.QUESTION.getMarker());
+        }
+        return result.toString();
     }
     
     private List<PostgreSQLColumnType> paddingColumnTypes(final int 
parameterCount, final List<PostgreSQLColumnType> specifiedColumnTypes) {
@@ -73,7 +90,8 @@ public final class PostgreSQLComParseExecutor implements 
CommandExecutor {
         }
         List<PostgreSQLColumnType> result = new ArrayList<>(parameterCount);
         result.addAll(specifiedColumnTypes);
-        for (int i = 0; i < parameterCount; i++) {
+        int unspecifiedCount = parameterCount - specifiedColumnTypes.size();
+        for (int i = 0; i < unspecifiedCount; i++) {
             result.add(PostgreSQLColumnType.POSTGRESQL_TYPE_UNSPECIFIED);
         }
         return result;
diff --git 
a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java
 
b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java
index fa22e6c..7a95baf 100644
--- 
a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java
+++ 
b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java
@@ -17,37 +17,40 @@
 
 package 
org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.parse;
 
+import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
+import 
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.PostgreSQLColumnType;
+import 
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.PostgreSQLPreparedStatement;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.PostgreSQLPreparedStatementRegistry;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.parse.PostgreSQLComParsePacket;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.parse.PostgreSQLParseCompletePacket;
-import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
 import org.apache.shardingsphere.infra.database.type.dialect.MySQLDatabaseType;
-import org.apache.shardingsphere.infra.executor.kernel.ExecutorEngine;
-import 
org.apache.shardingsphere.infra.federation.optimizer.context.OptimizerContext;
+import 
org.apache.shardingsphere.infra.database.type.dialect.PostgreSQLDatabaseType;
 import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
-import 
org.apache.shardingsphere.infra.metadata.rule.ShardingSphereRuleMetaData;
 import org.apache.shardingsphere.mode.manager.ContextManager;
-import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
-import org.apache.shardingsphere.mode.metadata.persist.MetaDataPersistService;
 import org.apache.shardingsphere.parser.rule.SQLParserRule;
 import 
org.apache.shardingsphere.parser.rule.builder.DefaultSQLParserRuleConfigurationBuilder;
 import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
 import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.EmptyStatement;
+import 
org.apache.shardingsphere.sql.parser.sql.dialect.statement.postgresql.dml.PostgreSQLInsertStatement;
+import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.Answers;
+import org.mockito.InjectMocks;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
 
-import java.lang.reflect.Field;
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Map;
 import java.util.Optional;
-import java.util.Properties;
 
 import static org.hamcrest.CoreMatchers.is;
 import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -55,6 +58,8 @@ import static org.mockito.Mockito.when;
 @RunWith(MockitoJUnitRunner.class)
 public final class PostgreSQLComParseExecutorTest {
     
+    private static final int CONNECTION_ID = 1;
+    
     private final SQLParserRule sqlParserRule = new SQLParserRule(new 
DefaultSQLParserRuleConfigurationBuilder().build());
     
     @Mock
@@ -63,28 +68,55 @@ public final class PostgreSQLComParseExecutorTest {
     @Mock
     private ConnectionSession connectionSession;
     
+    @InjectMocks
+    private PostgreSQLComParseExecutor executor;
+    
+    private ContextManager contextManagerBefore;
+    
+    @Mock(answer = Answers.RETURNS_DEEP_STUBS)
+    private ContextManager mockedContextManager;
+    
     @Before
     public void setup() {
-        PostgreSQLPreparedStatementRegistry.getInstance().register(1);
-        PostgreSQLPreparedStatementRegistry.getInstance().register(1, "2", "", 
new EmptyStatement(), Collections.emptyList());
-        when(connectionSession.getConnectionId()).thenReturn(1);
+        contextManagerBefore = ProxyContext.getInstance().getContextManager();
+        ProxyContext.getInstance().init(mockedContextManager);
+        
PostgreSQLPreparedStatementRegistry.getInstance().register(CONNECTION_ID);
+        when(connectionSession.getConnectionId()).thenReturn(CONNECTION_ID);
     }
     
     @Test
-    public void assertNewInstance() throws NoSuchFieldException, 
IllegalAccessException {
-        when(parsePacket.getSql()).thenReturn("SELECT 1");
-        when(parsePacket.getStatementId()).thenReturn("2");
+    public void assertExecuteWithEmptySQL() {
+        final String expectedSQL = "";
+        final String statementId = "S_1";
+        when(parsePacket.getSql()).thenReturn(expectedSQL);
+        when(parsePacket.getStatementId()).thenReturn(statementId);
+        Collection<DatabasePacket<?>> actualPackets = executor.execute();
+        assertThat(actualPackets.size(), is(1));
+        assertThat(actualPackets.iterator().next(), 
is(PostgreSQLParseCompletePacket.getInstance()));
+        PostgreSQLPreparedStatement actualPreparedStatement = 
PostgreSQLPreparedStatementRegistry.getInstance().get(CONNECTION_ID, 
statementId);
+        assertTrue(actualPreparedStatement.getSqlStatement() instanceof 
EmptyStatement);
+        assertThat(actualPreparedStatement.getSql(), is(expectedSQL));
+        assertThat(actualPreparedStatement.getParameterTypes(), 
is(Collections.emptyList()));
+    }
+    
+    @Test
+    public void assertExecuteWithParameterizedSQL() {
+        final String rawSQL = "/*$0*/insert into sbtest1 /* $1 */ -- $2 \n 
(id, k, c, pad) \r values \r\n($1, $2, 'apsbd$31a', '$99')/*$0*/ \n--$0";
+        final String expectedSQL = "/*$0*/insert into sbtest1 /* $1 */ -- $2 
\n (id, k, c, pad) \r values \r\n(?, ?, 'apsbd$31a', '$99')/*$0*/ \n--$0";
+        final String statementId = "S_2";
+        when(parsePacket.getSql()).thenReturn(rawSQL);
+        when(parsePacket.getStatementId()).thenReturn(statementId);
+        
when(parsePacket.readParameterTypes()).thenReturn(Collections.singletonList(PostgreSQLColumnType.POSTGRESQL_TYPE_INT4));
         when(connectionSession.getSchemaName()).thenReturn("schema");
-        Field contextManagerField = 
ProxyContext.getInstance().getClass().getDeclaredField("contextManager");
-        contextManagerField.setAccessible(true);
-        ContextManager contextManager = mock(ContextManager.class, 
RETURNS_DEEP_STUBS);
-        MetaDataContexts metaDataContexts = new 
MetaDataContexts(mock(MetaDataPersistService.class), getMetaDataMap(),
-                mock(ShardingSphereRuleMetaData.class), 
mock(ExecutorEngine.class), mock(OptimizerContext.class), new 
ConfigurationProperties(new Properties()));
-        
when(contextManager.getMetaDataContexts()).thenReturn(metaDataContexts);
-        contextManagerField.set(ProxyContext.getInstance(), contextManager);
-        
when(contextManager.getMetaDataContexts().getGlobalRuleMetaData().findSingleRule(SQLParserRule.class)).thenReturn(Optional.of(sqlParserRule));
-        PostgreSQLComParseExecutor actual = new 
PostgreSQLComParseExecutor(parsePacket, connectionSession);
-        assertThat(actual.execute().iterator().next(), 
is(PostgreSQLParseCompletePacket.getInstance()));
+        
when(mockedContextManager.getMetaDataContexts().getMetaData("schema").getResource().getDatabaseType()).thenReturn(new
 PostgreSQLDatabaseType());
+        
when(mockedContextManager.getMetaDataContexts().getGlobalRuleMetaData().findSingleRule(SQLParserRule.class)).thenReturn(Optional.of(sqlParserRule));
+        Collection<DatabasePacket<?>> actualPackets = executor.execute();
+        assertThat(actualPackets.size(), is(1));
+        assertThat(actualPackets.iterator().next(), 
is(PostgreSQLParseCompletePacket.getInstance()));
+        PostgreSQLPreparedStatement actualPreparedStatement = 
PostgreSQLPreparedStatementRegistry.getInstance().get(CONNECTION_ID, 
statementId);
+        assertTrue(actualPreparedStatement.getSqlStatement() instanceof 
PostgreSQLInsertStatement);
+        assertThat(actualPreparedStatement.getSql(), is(expectedSQL));
+        assertThat(actualPreparedStatement.getParameterTypes(), 
is(Arrays.asList(PostgreSQLColumnType.POSTGRESQL_TYPE_INT4, 
PostgreSQLColumnType.POSTGRESQL_TYPE_UNSPECIFIED)));
     }
     
     private Map<String, ShardingSphereMetaData> getMetaDataMap() {
@@ -93,11 +125,8 @@ public final class PostgreSQLComParseExecutorTest {
         return Collections.singletonMap("schema", metaData);
     }
     
-    @Test
-    public void assertGetSqlWithNull() {
-        when(parsePacket.getStatementId()).thenReturn("");
-        when(parsePacket.getSql()).thenReturn("");
-        PostgreSQLComParseExecutor actual = new 
PostgreSQLComParseExecutor(parsePacket, connectionSession);
-        assertThat(actual.execute().iterator().next(), 
is(PostgreSQLParseCompletePacket.getInstance()));
+    @After
+    public void tearDown() {
+        ProxyContext.getInstance().init(contextManagerBefore);
     }
 }
diff --git 
a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/EmptyStatement.java
 
b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/EmptyStatement.java
index c6799d9..002eb1c 100644
--- 
a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/EmptyStatement.java
+++ 
b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/EmptyStatement.java
@@ -17,12 +17,12 @@
 
 package org.apache.shardingsphere.sql.parser.sql.common.statement.dml;
 
-import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
+import 
org.apache.shardingsphere.sql.parser.sql.common.statement.AbstractSQLStatement;
 
 /**
  * Empty statement.
  */
-public final class EmptyStatement implements SQLStatement {
+public final class EmptyStatement extends AbstractSQLStatement {
     
     @Override
     public int getParameterCount() {
diff --git 
a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/EmptyStatement.java
 
b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/EmptyStatementTest.java
similarity index 76%
copy from 
shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/EmptyStatement.java
copy to 
shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/EmptyStatementTest.java
index c6799d9..692bf05 100644
--- 
a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/EmptyStatement.java
+++ 
b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/EmptyStatementTest.java
@@ -17,15 +17,15 @@
 
 package org.apache.shardingsphere.sql.parser.sql.common.statement.dml;
 
-import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
+import org.junit.Test;
 
-/**
- * Empty statement.
- */
-public final class EmptyStatement implements SQLStatement {
+import static org.hamcrest.CoreMatchers.is;
+import static org.junit.Assert.assertThat;
+
+public final class EmptyStatementTest {
     
-    @Override
-    public int getParameterCount() {
-        return 0;
+    @Test
+    public void assertGetParameterCount() {
+        assertThat(new EmptyStatement().getParameterCount(), is(0));
     }
 }

Reply via email to