This is an automated email from the ASF dual-hosted git repository.

tuichenchuxin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new a343cc642fc Minor refactor PaginationContextEngine init logic (#30753)
a343cc642fc is described below

commit a343cc642fc8b00fdf4a57a804ff604e422fb080
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Wed Apr 3 15:42:21 2024 +0800

    Minor refactor PaginationContextEngine init logic (#30753)
---
 .../pagination/engine/PaginationContextEngine.java     |  7 ++++++-
 .../engine/RowNumberPaginationContextEngine.java       | 17 ++++++++++++++---
 .../context/statement/dml/SelectStatementContext.java  |  4 ++--
 .../pagination/engine/PaginationContextEngineTest.java | 18 +++++++++++-------
 .../engine/RowNumberPaginationContextEngineTest.java   | 17 +++++++++++------
 5 files changed, 44 insertions(+), 19 deletions(-)

diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/pagination/engine/PaginationContextEngine.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/pagination/engine/PaginationContextEngine.java
index ea045fa783c..e17c49e40dd 100644
--- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/pagination/engine/PaginationContextEngine.java
+++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/pagination/engine/PaginationContextEngine.java
@@ -17,8 +17,10 @@
 
 package org.apache.shardingsphere.infra.binder.context.segment.select.pagination.engine;
 
+import lombok.RequiredArgsConstructor;
 import org.apache.shardingsphere.infra.binder.context.segment.select.pagination.PaginationContext;
 import org.apache.shardingsphere.infra.binder.context.segment.select.projection.ProjectionsContext;
+import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionSegment;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.pagination.limit.LimitSegment;
@@ -40,8 +42,11 @@ import java.util.Optional;
 /**
  * Pagination context engine.
  */
+@RequiredArgsConstructor
 public final class PaginationContextEngine {
     
+    private final DatabaseType databaseType;
+    
     /**
      * Create pagination context.
      * 
@@ -66,7 +71,7 @@ public final class PaginationContextEngine {
             return new TopPaginationContextEngine().createPaginationContext(topProjectionSegment.get(), expressions, params);
         }
         if (!expressions.isEmpty() && containsRowNumberPagination(selectStatement)) {
-            return new RowNumberPaginationContextEngine().createPaginationContext(expressions, projectionsContext, params);
+            return new RowNumberPaginationContextEngine(databaseType).createPaginationContext(expressions, projectionsContext, params);
         }
         return new PaginationContext(null, null, params);
     }
diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/pagination/engine/RowNumberPaginationContextEngine.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/pagination/engine/RowNumberPaginationContextEngine.java
index 2cc26aaa460..0e7bc8d1403 100644
--- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/pagination/engine/RowNumberPaginationContextEngine.java
+++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/pagination/engine/RowNumberPaginationContextEngine.java
@@ -17,8 +17,12 @@
 
 package 
org.apache.shardingsphere.infra.binder.context.segment.select.pagination.engine;
 
+import com.cedarsoftware.util.CaseInsensitiveMap;
+import com.cedarsoftware.util.CaseInsensitiveSet;
+import lombok.RequiredArgsConstructor;
 import 
org.apache.shardingsphere.infra.binder.context.segment.select.pagination.PaginationContext;
 import 
org.apache.shardingsphere.infra.binder.context.segment.select.projection.ProjectionsContext;
+import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
@@ -34,20 +38,27 @@ import 
org.apache.shardingsphere.sql.parser.sql.common.util.ExpressionExtractUti
 import java.util.Collection;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Map;
 import java.util.Optional;
-import java.util.TreeSet;
 import java.util.stream.Collectors;
 
 /**
  * Pagination context engine for row number.
  */
+@RequiredArgsConstructor
 public final class RowNumberPaginationContextEngine {
     
-    private static final Collection<String> ROW_NUMBER_IDENTIFIERS = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
+    private static final Collection<String> ROW_NUMBER_IDENTIFIERS = new CaseInsensitiveSet<>(2, 1F);
+    
+    private static final Map<String, String> DATABASE_TYPE_ROW_NUMBER_IDENTIFIERS = new CaseInsensitiveMap<>(2, 1F);
+    
+    private final DatabaseType databaseType;
     
     static {
         ROW_NUMBER_IDENTIFIERS.add("ROWNUM");
         ROW_NUMBER_IDENTIFIERS.add("ROW_NUMBER");
+        DATABASE_TYPE_ROW_NUMBER_IDENTIFIERS.put("Oracle", "ROWNUM");
+        DATABASE_TYPE_ROW_NUMBER_IDENTIFIERS.put("SQLServer", "ROW_NUMBER");
     }
     
     /**
@@ -87,7 +98,7 @@ public final class RowNumberPaginationContextEngine {
                 return result;
             }
         }
-        return Optional.empty();
+        return Optional.ofNullable(DATABASE_TYPE_ROW_NUMBER_IDENTIFIERS.get(databaseType.getType()));
     }
     
     private boolean isRowNumberColumn(final ExpressionSegment predicate, final 
String rowNumberAlias) {
diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/SelectStatementContext.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/SelectStatementContext.java
index d1600b2bcd0..e7fd9d0f27d 100644
--- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/SelectStatementContext.java
+++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/SelectStatementContext.java
@@ -122,7 +122,7 @@ public final class SelectStatementContext extends 
CommonSQLStatementContext impl
         groupByContext = new 
GroupByContextEngine().createGroupByContext(sqlStatement);
         orderByContext = new 
OrderByContextEngine().createOrderBy(sqlStatement, groupByContext);
         projectionsContext = new 
ProjectionsContextEngine(getDatabaseType()).createProjectionsContext(getSqlStatement().getProjections(),
 groupByContext, orderByContext);
-        paginationContext = new 
PaginationContextEngine().createPaginationContext(sqlStatement, 
projectionsContext, params, whereSegments);
+        paginationContext = new 
PaginationContextEngine(getDatabaseType()).createPaginationContext(sqlStatement,
 projectionsContext, params, whereSegments);
         String databaseName = 
tablesContext.getDatabaseName().orElse(defaultDatabaseName);
         containsEnhancedTable = isContainsEnhancedTable(metaData, 
databaseName, getTablesContext().getTableNames());
     }
@@ -393,6 +393,6 @@ public final class SelectStatementContext extends 
CommonSQLStatementContext impl
     
     @Override
     public void setUpParameters(final List<Object> params) {
-        paginationContext = new 
PaginationContextEngine().createPaginationContext(getSqlStatement(), 
projectionsContext, params, whereSegments);
+        paginationContext = new 
PaginationContextEngine(getDatabaseType()).createPaginationContext(getSqlStatement(),
 projectionsContext, params, whereSegments);
     }
 }
diff --git a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/pagination/engine/PaginationContextEngineTest.java b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/pagination/engine/PaginationContextEngineTest.java
index 39343b29443..807958a980e 100644
--- a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/pagination/engine/PaginationContextEngineTest.java
+++ b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/pagination/engine/PaginationContextEngineTest.java
@@ -19,6 +19,10 @@ package 
org.apache.shardingsphere.infra.binder.context.segment.select.pagination
 
 import 
org.apache.shardingsphere.infra.binder.context.segment.select.pagination.PaginationContext;
 import 
org.apache.shardingsphere.infra.binder.context.segment.select.projection.ProjectionsContext;
+import org.apache.shardingsphere.infra.database.mysql.type.MySQLDatabaseType;
+import 
org.apache.shardingsphere.infra.database.postgresql.type.PostgreSQLDatabaseType;
+import org.apache.shardingsphere.infra.database.sql92.type.SQL92DatabaseType;
+import 
org.apache.shardingsphere.infra.database.sqlserver.type.SQLServerDatabaseType;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionsSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.SubqueryProjectionSegment;
@@ -49,7 +53,7 @@ class PaginationContextEngineTest {
         MySQLSelectStatement selectStatement = new MySQLSelectStatement();
         selectStatement.setLimit(new LimitSegment(0, 10, new 
NumberLiteralLimitValueSegment(0, 10, 100L),
                 new NumberLiteralLimitValueSegment(0, 10, 100L)));
-        PaginationContext paginationContext = new 
PaginationContextEngine().createPaginationContext(
+        PaginationContext paginationContext = new PaginationContextEngine(new 
MySQLDatabaseType()).createPaginationContext(
                 selectStatement, mock(ProjectionsContext.class), 
Collections.emptyList(), Collections.emptyList());
         assertTrue(paginationContext.getOffsetSegment().isPresent());
         assertTrue(paginationContext.getRowCountSegment().isPresent());
@@ -60,7 +64,7 @@ class PaginationContextEngineTest {
         PostgreSQLSelectStatement selectStatement = new 
PostgreSQLSelectStatement();
         selectStatement.setLimit(new LimitSegment(0, 10, new 
NumberLiteralLimitValueSegment(0, 10, 100L),
                 new NumberLiteralLimitValueSegment(0, 10, 100L)));
-        PaginationContext paginationContext = new 
PaginationContextEngine().createPaginationContext(
+        PaginationContext paginationContext = new PaginationContextEngine(new 
PostgreSQLDatabaseType()).createPaginationContext(
                 selectStatement, mock(ProjectionsContext.class), 
Collections.emptyList(), Collections.emptyList());
         assertTrue(paginationContext.getOffsetSegment().isPresent());
         assertTrue(paginationContext.getRowCountSegment().isPresent());
@@ -71,7 +75,7 @@ class PaginationContextEngineTest {
         SQL92SelectStatement selectStatement = new SQL92SelectStatement();
         selectStatement.setLimit(new LimitSegment(0, 10, new 
NumberLiteralLimitValueSegment(0, 10, 100L),
                 new NumberLiteralLimitValueSegment(0, 10, 100L)));
-        PaginationContext paginationContext = new 
PaginationContextEngine().createPaginationContext(
+        PaginationContext paginationContext = new PaginationContextEngine(new 
SQL92DatabaseType()).createPaginationContext(
                 selectStatement, mock(ProjectionsContext.class), 
Collections.emptyList(), Collections.emptyList());
         assertTrue(paginationContext.getOffsetSegment().isPresent());
         assertTrue(paginationContext.getRowCountSegment().isPresent());
@@ -82,7 +86,7 @@ class PaginationContextEngineTest {
         SQLServerSelectStatement selectStatement = new 
SQLServerSelectStatement();
         selectStatement.setLimit(new LimitSegment(0, 10, new 
NumberLiteralLimitValueSegment(0, 10, 100L),
                 new NumberLiteralLimitValueSegment(0, 10, 100L)));
-        PaginationContext paginationContext = new 
PaginationContextEngine().createPaginationContext(
+        PaginationContext paginationContext = new PaginationContextEngine(new 
SQLServerDatabaseType()).createPaginationContext(
                 selectStatement, mock(ProjectionsContext.class), 
Collections.emptyList(), Collections.emptyList());
         assertTrue(paginationContext.getOffsetSegment().isPresent());
         assertTrue(paginationContext.getRowCountSegment().isPresent());
@@ -96,7 +100,7 @@ class PaginationContextEngineTest {
         SQLServerSelectStatement selectStatement = new 
SQLServerSelectStatement();
         selectStatement.setProjections(new ProjectionsSegment(0, 0));
         selectStatement.getProjections().getProjections().add(new 
SubqueryProjectionSegment(new SubquerySegment(0, 0, subquerySelectStatement, 
""), ""));
-        PaginationContext paginationContext = new 
PaginationContextEngine().createPaginationContext(
+        PaginationContext paginationContext = new PaginationContextEngine(new 
SQLServerDatabaseType()).createPaginationContext(
                 selectStatement, mock(ProjectionsContext.class), 
Collections.emptyList(), Collections.emptyList());
         assertFalse(paginationContext.getOffsetSegment().isPresent());
         assertFalse(paginationContext.getRowCountSegment().isPresent());
@@ -109,7 +113,7 @@ class PaginationContextEngineTest {
         WhereSegment where = new WhereSegment(0, 10, null);
         selectStatement.setWhere(where);
         ProjectionsContext projectionsContext = new ProjectionsContext(0, 0, 
false, Collections.emptyList());
-        PaginationContext paginationContext = new 
PaginationContextEngine().createPaginationContext(
+        PaginationContext paginationContext = new PaginationContextEngine(new 
SQLServerDatabaseType()).createPaginationContext(
                 selectStatement, projectionsContext, Collections.emptyList(), 
Collections.singletonList(where));
         assertFalse(paginationContext.getOffsetSegment().isPresent());
         assertFalse(paginationContext.getRowCountSegment().isPresent());
@@ -143,7 +147,7 @@ class PaginationContextEngineTest {
     private void 
assertCreatePaginationContextWhenResultIsPaginationContext(final 
SelectStatement selectStatement) {
         selectStatement.setProjections(new ProjectionsSegment(0, 0));
         ProjectionsContext projectionsContext = new ProjectionsContext(0, 0, 
false, Collections.emptyList());
-        assertThat(new PaginationContextEngine().createPaginationContext(
+        assertThat(new PaginationContextEngine(new 
MySQLDatabaseType()).createPaginationContext(
                 selectStatement, projectionsContext, Collections.emptyList(), 
Collections.emptyList()), instanceOf(PaginationContext.class));
     }
 }
diff --git a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/pagination/engine/RowNumberPaginationContextEngineTest.java b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/pagination/engine/RowNumberPaginationContextEngineTest.java
index fa0aa57030a..e08f5339182 100644
--- a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/pagination/engine/RowNumberPaginationContextEngineTest.java
+++ b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/pagination/engine/RowNumberPaginationContextEngineTest.java
@@ -22,6 +22,7 @@ import 
org.apache.shardingsphere.infra.binder.context.segment.select.projection.
 import 
org.apache.shardingsphere.infra.binder.context.segment.select.projection.ProjectionsContext;
 import 
org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ColumnProjection;
 import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
+import org.apache.shardingsphere.infra.database.oracle.type.OracleDatabaseType;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
@@ -52,7 +53,8 @@ class RowNumberPaginationContextEngineTest {
     @Test
     void assertCreatePaginationContextWhenRowNumberAliasNotPresent() {
         ProjectionsContext projectionsContext = new ProjectionsContext(0, 0, 
false, Collections.emptyList());
-        PaginationContext paginationContext = new 
RowNumberPaginationContextEngine().createPaginationContext(Collections.emptyList(),
 projectionsContext, Collections.emptyList());
+        PaginationContext paginationContext =
+                new RowNumberPaginationContextEngine(new 
OracleDatabaseType()).createPaginationContext(Collections.emptyList(), 
projectionsContext, Collections.emptyList());
         assertFalse(paginationContext.getOffsetSegment().isPresent());
         assertFalse(paginationContext.getRowCountSegment().isPresent());
     }
@@ -61,7 +63,8 @@ class RowNumberPaginationContextEngineTest {
     void 
assertCreatePaginationContextWhenRowNumberAliasIsPresentAndRowNumberPredicatesIsEmpty()
 {
         Projection projectionWithRowNumberAlias = new ColumnProjection(null, 
ROW_NUMBER_COLUMN_NAME, ROW_NUMBER_COLUMN_ALIAS, mock(DatabaseType.class));
         ProjectionsContext projectionsContext = new ProjectionsContext(0, 0, 
false, Collections.singleton(projectionWithRowNumberAlias));
-        PaginationContext paginationContext = new 
RowNumberPaginationContextEngine().createPaginationContext(Collections.emptyList(),
 projectionsContext, Collections.emptyList());
+        PaginationContext paginationContext =
+                new RowNumberPaginationContextEngine(new 
OracleDatabaseType()).createPaginationContext(Collections.emptyList(), 
projectionsContext, Collections.emptyList());
         assertFalse(paginationContext.getOffsetSegment().isPresent());
         assertFalse(paginationContext.getRowCountSegment().isPresent());
     }
@@ -94,7 +97,8 @@ class RowNumberPaginationContextEngineTest {
         ColumnSegment left = new ColumnSegment(0, 10, new 
IdentifierValue(ROW_NUMBER_COLUMN_NAME));
         BinaryOperationExpression predicateSegment = new 
BinaryOperationExpression(0, 0, left, null, null, null);
         andPredicate.getPredicates().add(predicateSegment);
-        PaginationContext paginationContext = new 
RowNumberPaginationContextEngine().createPaginationContext(Collections.emptyList(),
 projectionsContext, Collections.emptyList());
+        PaginationContext paginationContext =
+                new RowNumberPaginationContextEngine(new 
OracleDatabaseType()).createPaginationContext(Collections.emptyList(), 
projectionsContext, Collections.emptyList());
         assertFalse(paginationContext.getOffsetSegment().isPresent());
         assertFalse(paginationContext.getRowCountSegment().isPresent());
     }
@@ -106,7 +110,7 @@ class RowNumberPaginationContextEngineTest {
         ColumnSegment left = new ColumnSegment(0, 10, new 
IdentifierValue(ROW_NUMBER_COLUMN_NAME));
         ParameterMarkerExpressionSegment right = new 
ParameterMarkerExpressionSegment(0, 10, 0);
         BinaryOperationExpression expression = new 
BinaryOperationExpression(0, 0, left, right, ">", null);
-        PaginationContext paginationContext = new 
RowNumberPaginationContextEngine()
+        PaginationContext paginationContext = new 
RowNumberPaginationContextEngine(new OracleDatabaseType())
                 
.createPaginationContext(Collections.singletonList(expression), 
projectionsContext, Collections.singletonList(1));
         Optional<PaginationValueSegment> offSetSegmentPaginationValue = 
paginationContext.getOffsetSegment();
         assertTrue(offSetSegmentPaginationValue.isPresent());
@@ -120,7 +124,8 @@ class RowNumberPaginationContextEngineTest {
         ColumnSegment left = new ColumnSegment(0, 10, new 
IdentifierValue(ROW_NUMBER_COLUMN_NAME));
         LiteralExpressionSegment right = new LiteralExpressionSegment(0, 10, 
100);
         BinaryOperationExpression expression = new 
BinaryOperationExpression(0, 0, left, right, operator, null);
-        PaginationContext paginationContext = new 
RowNumberPaginationContextEngine().createPaginationContext(Collections.singletonList(expression),
 projectionsContext, Collections.emptyList());
+        PaginationContext paginationContext =
+                new RowNumberPaginationContextEngine(new 
OracleDatabaseType()).createPaginationContext(Collections.singletonList(expression),
 projectionsContext, Collections.emptyList());
         assertFalse(paginationContext.getOffsetSegment().isPresent());
         Optional<PaginationValueSegment> paginationValueSegment = 
paginationContext.getRowCountSegment();
         assertTrue(paginationValueSegment.isPresent());
@@ -136,7 +141,7 @@ class RowNumberPaginationContextEngineTest {
         ColumnSegment left = new ColumnSegment(0, 10, new 
IdentifierValue(ROW_NUMBER_COLUMN_NAME));
         LiteralExpressionSegment right = new LiteralExpressionSegment(0, 10, 
100);
         BinaryOperationExpression expression = new 
BinaryOperationExpression(0, 0, left, right, operator, null);
-        PaginationContext rowNumberPaginationContextEngine = new 
RowNumberPaginationContextEngine()
+        PaginationContext rowNumberPaginationContextEngine = new 
RowNumberPaginationContextEngine(new OracleDatabaseType())
                 
.createPaginationContext(Collections.singletonList(expression), 
projectionsContext, Collections.emptyList());
         Optional<PaginationValueSegment> paginationValueSegment = 
rowNumberPaginationContextEngine.getOffsetSegment();
         assertTrue(paginationValueSegment.isPresent());

Reply via email to