This is an automated email from the ASF dual-hosted git repository.

wuweijie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 4f989592509 re-order pg parameters in jdbc style (#25988)
4f989592509 is described below

commit 4f9895925092b18a654aa6efa5976f689af9a0ae
Author: 亥时 <[email protected]>
AuthorDate: Sat Jun 3 19:22:12 2023 +0800

    re-order pg parameters in jdbc style (#25988)
---
 .../PostgreSQLServerPreparedStatement.java         | 36 ++++++++++++++++++++--
 .../extended/bind/PostgreSQLComBindExecutor.java   |  4 ++-
 .../extended/parse/PostgreSQLComParseExecutor.java | 11 ++++---
 .../bind/PostgreSQLComBindExecutorTest.java        | 33 +++++++++++++++++++-
 .../parse/PostgreSQLComParseExecutorTest.java      | 18 +++++++++++
 5 files changed, 94 insertions(+), 8 deletions(-)

diff --git 
a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLServerPreparedStatement.java
 
b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLServerPreparedStatement.java
index 68a28ef8160..38396559ee6 100644
--- 
a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLServerPreparedStatement.java
+++ 
b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLServerPreparedStatement.java
@@ -19,7 +19,6 @@ package 
org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extend
 
 import lombok.AccessLevel;
 import lombok.Getter;
-import lombok.RequiredArgsConstructor;
 import lombok.Setter;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.PostgreSQLPacket;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.PostgreSQLParameterDescriptionPacket;
@@ -27,13 +26,14 @@ import 
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.ext
 import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
 import org.apache.shardingsphere.proxy.backend.session.ServerPreparedStatement;
 
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Optional;
 
 /**
  * Prepared statement for PostgreSQL.
  */
-@RequiredArgsConstructor
 @Getter
 @Setter
 public final class PostgreSQLServerPreparedStatement implements 
ServerPreparedStatement {
@@ -44,9 +44,25 @@ public final class PostgreSQLServerPreparedStatement 
implements ServerPreparedSt
     
     private final List<PostgreSQLColumnType> parameterTypes;
     
+    private final List<Integer> actualParameterMarkerIndexes;
+    
     @Getter(AccessLevel.NONE)
     private PostgreSQLPacket rowDescription;
     
+    public PostgreSQLServerPreparedStatement(final String sql, final SQLStatementContext sqlStatementContext, final List<PostgreSQLColumnType> parameterTypes) {
+        this(sql, sqlStatementContext, parameterTypes, Collections.<Integer>emptyList());
+    }
+    
+    public PostgreSQLServerPreparedStatement(final String sql,
+                                             final SQLStatementContext sqlStatementContext,
+                                             final List<PostgreSQLColumnType> parameterTypes,
+                                             final List<Integer> actualParameterMarkerIndexes) {
+        this.actualParameterMarkerIndexes = actualParameterMarkerIndexes;
+        this.parameterTypes = parameterTypes;
+        this.sqlStatementContext = sqlStatementContext;
+        this.sql = sql;
+    }
+    
     /**
      * Describe parameters of the prepared statement.
      *
@@ -64,4 +80,20 @@ public final class PostgreSQLServerPreparedStatement 
implements ServerPreparedSt
     public Optional<PostgreSQLPacket> describeRows() {
         return Optional.ofNullable(rowDescription);
     }
+    
+    /**
+     * Adjust parameters from PostgreSQL marker order ($n) to JDBC marker order (?); parameters are returned as-is when no marker indexes were recorded.
+     * @param parameters parameters in PostgreSQL marker index order
+     * @return parameters in JDBC style marker index order, one entry per recorded marker
+     */
+    public List<Object> adjustParametersOrder(final List<Object> parameters) {
+        if (null == parameters || parameters.isEmpty() || actualParameterMarkerIndexes.isEmpty()) {
+            return parameters;
+        }
+        List<Object> result = new ArrayList<>(actualParameterMarkerIndexes.size());
+        for (Integer each : actualParameterMarkerIndexes) {
+            result.add(parameters.get(each));
+        }
+        return result;
+    }
 }
diff --git 
a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/bind/PostgreSQLComBindExecutor.java
 
b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/bind/PostgreSQLComBindExecutor.java
index e13b3b44b78..bb02e9ea49e 100644
--- 
a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/bind/PostgreSQLComBindExecutor.java
+++ 
b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/bind/PostgreSQLComBindExecutor.java
@@ -31,6 +31,7 @@ import 
org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extende
 import java.sql.SQLException;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.List;
 
 /**
  * Command bind executor for PostgreSQL.
@@ -48,7 +49,8 @@ public final class PostgreSQLComBindExecutor implements 
CommandExecutor {
     public Collection<DatabasePacket> execute() throws SQLException {
         PostgreSQLServerPreparedStatement preparedStatement = 
connectionSession.getServerPreparedStatementRegistry().getPreparedStatement(packet.getStatementId());
         ProxyDatabaseConnectionManager databaseConnectionManager = 
connectionSession.getDatabaseConnectionManager();
-        Portal portal = new Portal(packet.getPortal(), preparedStatement, 
packet.readParameters(preparedStatement.getParameterTypes()), 
packet.readResultFormats(), databaseConnectionManager);
+        List<Object> parameters = 
preparedStatement.adjustParametersOrder(packet.readParameters(preparedStatement.getParameterTypes()));
+        Portal portal = new Portal(packet.getPortal(), preparedStatement, 
parameters, packet.readResultFormats(), databaseConnectionManager);
         portalContext.add(portal);
         portal.bind();
         return 
Collections.singleton(PostgreSQLBindCompletePacket.getInstance());
diff --git 
a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutor.java
 
b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutor.java
index d132ce68ef4..cb6d5b1f7aa 100644
--- 
a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutor.java
+++ 
b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutor.java
@@ -45,6 +45,7 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.List;
+import java.util.stream.Collectors;
 
 /**
  * PostgreSQL command parse executor.
@@ -61,15 +62,18 @@ public final class PostgreSQLComParseExecutor implements 
CommandExecutor {
         SQLParserEngine sqlParserEngine = 
createShardingSphereSQLParserEngine(connectionSession.getDatabaseName());
         String sql = packet.getSQL();
         SQLStatement sqlStatement = sqlParserEngine.parse(sql, true);
+        List<Integer> actualParameterMarkerIndexes = new ArrayList<>();
         if (sqlStatement.getParameterCount() > 0) {
-            sql = convertSQLToJDBCStyle(sqlStatement, sql);
+            List<ParameterMarkerSegment> parameterMarkerSegments = new 
ArrayList<>(((AbstractSQLStatement) sqlStatement).getParameterMarkerSegments());
+            
actualParameterMarkerIndexes.addAll(parameterMarkerSegments.stream().map(ParameterMarkerSegment::getParameterIndex).collect(Collectors.toList()));
+            sql = convertSQLToJDBCStyle(parameterMarkerSegments, sql);
             sqlStatement = sqlParserEngine.parse(sql, true);
         }
         List<PostgreSQLColumnType> paddedColumnTypes = 
paddingColumnTypes(sqlStatement.getParameterCount(), 
packet.readParameterTypes());
         SQLStatementContext sqlStatementContext = sqlStatement instanceof 
DistSQLStatement ? new DistSQLStatementContext((DistSQLStatement) sqlStatement)
                 : 
SQLStatementContextFactory.newInstance(ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData(),
                         sqlStatement, 
connectionSession.getDefaultDatabaseName());
-        PostgreSQLServerPreparedStatement serverPreparedStatement = new 
PostgreSQLServerPreparedStatement(sql, sqlStatementContext, paddedColumnTypes);
+        PostgreSQLServerPreparedStatement serverPreparedStatement = new 
PostgreSQLServerPreparedStatement(sql, sqlStatementContext, paddedColumnTypes, 
actualParameterMarkerIndexes);
         
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(packet.getStatementId(),
 serverPreparedStatement);
         return 
Collections.singleton(PostgreSQLParseCompletePacket.getInstance());
     }
@@ -80,8 +84,7 @@ public final class PostgreSQLComParseExecutor implements 
CommandExecutor {
         return 
sqlParserRule.getSQLParserEngine(DatabaseTypeEngine.getTrunkDatabaseTypeName(metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType()));
     }
     
-    private String convertSQLToJDBCStyle(final SQLStatement sqlStatement, 
final String sql) {
-        List<ParameterMarkerSegment> parameterMarkerSegments = new 
ArrayList<>(((AbstractSQLStatement) sqlStatement).getParameterMarkerSegments());
+    private String convertSQLToJDBCStyle(final List<ParameterMarkerSegment> 
parameterMarkerSegments, final String sql) {
         
parameterMarkerSegments.sort(Comparator.comparingInt(SQLSegment::getStopIndex));
         StringBuilder result = new StringBuilder(sql);
         for (int i = parameterMarkerSegments.size() - 1; i >= 0; i--) {
diff --git 
a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/bind/PostgreSQLComBindExecutorTest.java
 
b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/bind/PostgreSQLComBindExecutorTest.java
index 0c0c0a819c9..1e746c51120 100644
--- 
a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/bind/PostgreSQLComBindExecutorTest.java
+++ 
b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/bind/PostgreSQLComBindExecutorTest.java
@@ -18,6 +18,7 @@
 package 
org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.bind;
 
 import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
+import 
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.PostgreSQLColumnType;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.bind.PostgreSQLBindCompletePacket;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.bind.PostgreSQLComBindPacket;
 import 
org.apache.shardingsphere.infra.binder.statement.UnknownSQLStatementContext;
@@ -42,8 +43,10 @@ import org.mockito.InjectMocks;
 import org.mockito.Mock;
 
 import java.sql.SQLException;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.List;
 
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
@@ -63,7 +66,7 @@ class PostgreSQLComBindExecutorTest {
     @Mock
     private PostgreSQLComBindPacket bindPacket;
     
-    @Mock
+    @Mock(answer = Answers.CALLS_REAL_METHODS)
     private ConnectionSession connectionSession;
     
     @InjectMocks
@@ -94,4 +97,32 @@ class PostgreSQLComBindExecutorTest {
         assertThat(actual.iterator().next(), 
is(PostgreSQLBindCompletePacket.getInstance()));
         verify(portalContext).add(any(Portal.class));
     }
+    
+    @Test
+    void assertExecuteBindParameters() throws SQLException {
+        String databaseName = "postgres";
+        ShardingSphereDatabase database = mock(ShardingSphereDatabase.class);
+        when(database.getProtocolType()).thenReturn(new PostgreSQLDatabaseType());
+        when(connectionSession.getServerPreparedStatementRegistry()).thenReturn(new ServerPreparedStatementRegistry());
+        ProxyDatabaseConnectionManager databaseConnectionManager = mock(ProxyDatabaseConnectionManager.class);
+        when(databaseConnectionManager.getConnectionSession()).thenReturn(connectionSession);
+        when(connectionSession.getDatabaseConnectionManager()).thenReturn(databaseConnectionManager);
+        when(connectionSession.getDefaultDatabaseName()).thenReturn(databaseName);
+        String statementId = "S_1";
+        List<Object> parameters = Arrays.asList(1, "updated_name");
+        PostgreSQLServerPreparedStatement serverPreparedStatement = new PostgreSQLServerPreparedStatement("update test set name = ? where id = ?",
+                new UnknownSQLStatementContext(new PostgreSQLEmptyStatement()),
+                Arrays.asList(PostgreSQLColumnType.INT4, PostgreSQLColumnType.VARCHAR), // $n order of "set name = $2 where id = $1"
+                Arrays.asList(1, 0)); // 1st ? binds $2, 2nd ? binds $1
+        connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(statementId, serverPreparedStatement);
+        when(bindPacket.getStatementId()).thenReturn(statementId);
+        when(bindPacket.getPortal()).thenReturn("C_1");
+        when(bindPacket.readParameters(anyList())).thenReturn(parameters);
+        when(bindPacket.readResultFormats()).thenReturn(Collections.emptyList());
+        ContextManager contextManager = mock(ContextManager.class, Answers.RETURNS_DEEP_STUBS);
+        when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
+        when(ProxyContext.getInstance().getDatabase(databaseName)).thenReturn(database);
+        executor.execute();
+        assertThat(connectionSession.getQueryContext().getParameters(), is(Arrays.asList(parameters.get(1), parameters.get(0))));
+    }
 }
diff --git 
a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java
 
b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java
index 9667e79c445..ada9e711e68 100644
--- 
a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java
+++ 
b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/parse/PostgreSQLComParseExecutorTest.java
@@ -115,6 +115,24 @@ class PostgreSQLComParseExecutorTest {
         assertThat(actualPreparedStatement.getParameterTypes(), 
is(Arrays.asList(PostgreSQLColumnType.INT4, PostgreSQLColumnType.UNSPECIFIED)));
     }
     
+    @Test
+    void assertExecuteWithNonOrderedParameterizedSQL() throws ReflectiveOperationException {
+        String rawSQL = "update t_test set name=$2 where id=$1";
+        String expectedSQL = "update t_test set name=? where id=?";
+        String statementId = "S_2";
+        when(parsePacket.getSQL()).thenReturn(rawSQL);
+        when(parsePacket.getStatementId()).thenReturn(statementId);
+        when(parsePacket.readParameterTypes()).thenReturn(Arrays.asList(PostgreSQLColumnType.INT4, PostgreSQLColumnType.VARCHAR)); // declared in $n order: $1 = id, $2 = name
+        Plugins.getMemberAccessor().set(PostgreSQLComParseExecutor.class.getDeclaredField("connectionSession"), executor, connectionSession);
+        ContextManager contextManager = mockContextManager();
+        when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);
+        executor.execute();
+        PostgreSQLServerPreparedStatement actualPreparedStatement = connectionSession.getServerPreparedStatementRegistry().getPreparedStatement(statementId);
+        assertThat(actualPreparedStatement.getSql(), is(expectedSQL));
+        assertThat(actualPreparedStatement.getParameterTypes(), is(Arrays.asList(PostgreSQLColumnType.INT4, PostgreSQLColumnType.VARCHAR)));
+        assertThat(actualPreparedStatement.getActualParameterMarkerIndexes(), is(Arrays.asList(1, 0)));
+    }
+    
     @Test
     void assertExecuteWithDistSQL() {
         String sql = "SHOW DIST VARIABLE WHERE NAME = sql_show";

Reply via email to