This is an automated email from the ASF dual-hosted git repository.

wuweijie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new a3288a0a569 escape pg ? operator in proxy (#27909)
a3288a0a569 is described below

commit a3288a0a56948d31710c661ae1744b95f8f50b79
Author: 亥时 <[email protected]>
AuthorDate: Tue Aug 15 18:50:54 2023 +0800

    escape pg ? operator in proxy (#27909)
    
    * escape pg ? operator in proxy
    
    * check style
    
    * only escape ? operator in dml
---
 .../query/extended/parse/PostgreSQLComParseExecutor.java  | 15 ++++++++++++++-
 .../extended/parse/PostgreSQLComParseExecutorTest.java    | 15 +++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutor.java b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutor.java
index 77aeee2c08d..82a912c6183 100644
--- a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutor.java
+++ b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutor.java
@@ -23,8 +23,8 @@ import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.ext
 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.parse.PostgreSQLComParsePacket;
 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.parse.PostgreSQLParseCompletePacket;
 import org.apache.shardingsphere.distsql.parser.statement.DistSQLStatement;
-import org.apache.shardingsphere.infra.binder.engine.SQLBindEngine;
 import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
+import org.apache.shardingsphere.infra.binder.engine.SQLBindEngine;
 import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
 import org.apache.shardingsphere.infra.parser.SQLParserEngine;
 import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
@@ -39,6 +39,7 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.SQLSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.ParameterMarkerSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.statement.AbstractSQLStatement;
 import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
+import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.DMLStatement;
 
 import java.util.ArrayList;
 import java.util.Collection;
@@ -61,6 +62,11 @@ public final class PostgreSQLComParseExecutor implements CommandExecutor {
         SQLParserEngine sqlParserEngine = createShardingSphereSQLParserEngine(connectionSession.getDatabaseName());
         String sql = packet.getSQL();
         SQLStatement sqlStatement = sqlParserEngine.parse(sql, true);
+        String escapedSql = escape(sqlStatement, sql);
+        if (!escapedSql.equalsIgnoreCase(sql)) {
+            sqlStatement = sqlParserEngine.parse(escapedSql, true);
+            sql = escapedSql;
+        }
         List<Integer> actualParameterMarkerIndexes = new ArrayList<>();
         if (sqlStatement.getParameterCount() > 0) {
             List<ParameterMarkerSegment> parameterMarkerSegments = new ArrayList<>(((AbstractSQLStatement) sqlStatement).getParameterMarkerSegments());
@@ -87,6 +93,13 @@ public final class PostgreSQLComParseExecutor implements CommandExecutor {
         return sqlParserRule.getSQLParserEngine(protocolType.getTrunkDatabaseType().orElse(protocolType));
     }
     
+    private String escape(final SQLStatement sqlStatement, final String sql) {
+        if (sqlStatement instanceof DMLStatement) {
+            return sql.replace("?", "??");
+        }
+        return sql;
+    }
+    
     private String convertSQLToJDBCStyle(final List<ParameterMarkerSegment> parameterMarkerSegments, final String sql) {
         parameterMarkerSegments.sort(Comparator.comparingInt(SQLSegment::getStopIndex));
         StringBuilder result = new StringBuilder(sql);
diff --git a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java
index 8fa7257db05..aa07e2ea81e 100644
--- a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java
+++ b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java
@@ -134,6 +134,21 @@ class PostgreSQLComParseExecutorTest {
         assertThat(actualPreparedStatement.getActualParameterMarkerIndexes(), is(Arrays.asList(1, 0)));
     }
     
+    @Test
+    void assetExecuteWithQuestionOperator() throws ReflectiveOperationException {
+        final String rawSQL = "update t_test set enabled = $1 where name ?& $2";
+        final String expectedSQL = "update t_test set enabled = ? where name ??& ?";
+        final String statementId = "S_2";
+        when(parsePacket.getSQL()).thenReturn(rawSQL);
+        when(parsePacket.getStatementId()).thenReturn(statementId);
+        Plugins.getMemberAccessor().set(PostgreSQLComParseExecutor.class.getDeclaredField("connectionSession"), executor, connectionSession);
+        ContextManager contextManager = mockContextManager();
+        when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
+        executor.execute();
+        PostgreSQLServerPreparedStatement actualPreparedStatement = connectionSession.getServerPreparedStatementRegistry().getPreparedStatement(statementId);
+        assertThat(actualPreparedStatement.getSql(), is(expectedSQL));
+    }
+    
     @Test
     void assertExecuteWithDistSQL() {
         String sql = "SHOW DIST VARIABLE WHERE NAME = sql_show";

Reply via email to