This is an automated email from the ASF dual-hosted git repository.

duanzhengqiang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new be717e1733d Parse unpivot segment and extract column in sql binder (#28438)
be717e1733d is described below

commit be717e1733d2f9ef99c6348b069be099a6904c8f
Author: ZhangCheng <[email protected]>
AuthorDate: Fri Sep 15 10:44:48 2023 +0800

    Parse unpivot segment and extract column in sql binder (#28438)
    
    * Parse unpivot segment and extract column in sql binder
    
    * Parse unpivot segment and extract column in sql binder
---
 .../expression/impl/ColumnSegmentBinder.java       |  3 +--
 .../from/impl/SimpleTableSegmentBinder.java        |  5 ++++
 .../from/impl/SubqueryTableSegmentBinder.java      |  8 ++----
 .../src/main/antlr4/imports/oracle/DMLStatement.g4 |  6 ++++-
 .../statement/type/OracleDMLStatementVisitor.java  | 23 +++++++++++++++++
 .../sql/common/segment/generic/PivotSegment.java   | 29 ++++++++++++++++++++++
 6 files changed, 65 insertions(+), 9 deletions(-)

diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/expression/impl/ColumnSegmentBinder.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/expression/impl/ColumnSegmentBinder.java
index faeb0bca37d..d59a8884b18 100644
--- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/expression/impl/ColumnSegmentBinder.java
+++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/expression/impl/ColumnSegmentBinder.java
@@ -166,8 +166,7 @@ public final class ColumnSegmentBinder {
             return Optional.empty();
         }
         if 
(pivotColumnNames.contains(segment.getIdentifier().getValue().toLowerCase())) {
-            ColumnSegment result = new ColumnSegment(0, 0, 
segment.getIdentifier());
-            return Optional.of(result);
+            return Optional.of(new ColumnSegment(0, 0, 
segment.getIdentifier()));
         }
         return Optional.empty();
     }
diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/from/impl/SimpleTableSegmentBinder.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/from/impl/SimpleTableSegmentBinder.java
index 76431790969..8e2df609fae 100644
--- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/from/impl/SimpleTableSegmentBinder.java
+++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/from/impl/SimpleTableSegmentBinder.java
@@ -76,6 +76,7 @@ public final class SimpleTableSegmentBinder {
      * @return bounded simple table segment
      */
     public static SimpleTableSegment bind(final SimpleTableSegment segment, 
final SQLStatementBinderContext statementBinderContext, final Map<String, 
TableSegmentBinderContext> tableBinderContexts) {
+        fillPivotColumnNamesInBinderContext(segment, statementBinderContext);
         IdentifierValue originalDatabase = getDatabaseName(segment, 
statementBinderContext);
         IdentifierValue originalSchema = getSchemaName(segment, 
statementBinderContext);
         checkTableExists(segment.getTableName().getIdentifier().getValue(), 
statementBinderContext, originalDatabase.getValue(), originalSchema.getValue());
@@ -90,6 +91,10 @@ public final class SimpleTableSegmentBinder {
         return result;
     }
     
+    private static void fillPivotColumnNamesInBinderContext(final 
SimpleTableSegment segment, final SQLStatementBinderContext 
statementBinderContext) {
+        segment.getPivot().ifPresent(optional -> 
optional.getPivotColumns().forEach(each -> 
statementBinderContext.getPivotColumnNames().add(each.getIdentifier().getValue().toLowerCase())));
+    }
+    
     private static IdentifierValue getDatabaseName(final SimpleTableSegment 
tableSegment, final SQLStatementBinderContext statementBinderContext) {
         DialectDatabaseMetaData dialectDatabaseMetaData = new 
DatabaseTypeRegistry(statementBinderContext.getDatabaseType()).getDialectDatabaseMetaData();
         Optional<OwnerSegment> owner = 
dialectDatabaseMetaData.getDefaultSchema().isPresent() ? 
tableSegment.getOwner().flatMap(OwnerSegment::getOwner) : 
tableSegment.getOwner();
diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/from/impl/SubqueryTableSegmentBinder.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/from/impl/SubqueryTableSegmentBinder.java
index e82fab4b026..871a159d9b4 100644
--- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/from/impl/SubqueryTableSegmentBinder.java
+++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/from/impl/SubqueryTableSegmentBinder.java
@@ -57,12 +57,12 @@ public final class SubqueryTableSegmentBinder {
      */
     public static SubqueryTableSegment bind(final SubqueryTableSegment 
segment, final SQLStatementBinderContext statementBinderContext,
                                             final Map<String, 
TableSegmentBinderContext> tableBinderContexts, final Map<String, 
TableSegmentBinderContext> outerTableBinderContexts) {
+        fillPivotColumnNamesInBinderContext(segment, statementBinderContext);
         SelectStatement boundedSelect = new 
SelectStatementBinder().bindCorrelateSubquery(segment.getSubquery().getSelect(),
 statementBinderContext.getMetaData(),
                 statementBinderContext.getDefaultDatabaseName(), 
outerTableBinderContexts, 
statementBinderContext.getExternalTableBinderContexts());
         SubquerySegment boundedSubquerySegment = new 
SubquerySegment(segment.getSubquery().getStartIndex(), 
segment.getSubquery().getStopIndex(), boundedSelect);
         
boundedSubquerySegment.setSubqueryType(segment.getSubquery().getSubqueryType());
         SubqueryTableSegment result = new 
SubqueryTableSegment(boundedSubquerySegment);
-        fillPivotColumnNamesInBinderContext(segment, statementBinderContext);
         segment.getAliasSegment().ifPresent(result::setAlias);
         IdentifierValue subqueryTableName = 
segment.getAliasSegment().map(AliasSegment::getIdentifier).orElseGet(() -> new 
IdentifierValue(""));
         tableBinderContexts.put(subqueryTableName.getValue().toLowerCase(),
@@ -71,11 +71,7 @@ public final class SubqueryTableSegmentBinder {
     }
     
     private static void fillPivotColumnNamesInBinderContext(final 
SubqueryTableSegment segment, final SQLStatementBinderContext 
statementBinderContext) {
-        segment.getPivot().ifPresent(optional -> {
-            for (ColumnSegment each : optional.getPivotInColumns()) {
-                
statementBinderContext.getPivotColumnNames().add(each.getIdentifier().getValue().toLowerCase());
-            }
-        });
+        segment.getPivot().ifPresent(optional -> 
optional.getPivotColumns().forEach(each -> 
statementBinderContext.getPivotColumnNames().add(each.getIdentifier().getValue().toLowerCase())));
     }
     
     private static Collection<ProjectionSegment> 
createSubqueryProjections(final Collection<ProjectionSegment> projections, 
final IdentifierValue subqueryTableName) {
diff --git a/parser/sql/dialect/oracle/src/main/antlr4/imports/oracle/DMLStatement.g4 b/parser/sql/dialect/oracle/src/main/antlr4/imports/oracle/DMLStatement.g4
index 6b850168f86..c511e5458f6 100644
--- a/parser/sql/dialect/oracle/src/main/antlr4/imports/oracle/DMLStatement.g4
+++ b/parser/sql/dialect/oracle/src/main/antlr4/imports/oracle/DMLStatement.g4
@@ -535,7 +535,11 @@ unpivotClause
     ;
 
 unpivotInClause
-    : IN LP_ (columnName | columnNames) (AS (literals | LP_ literals (COMMA_ 
literals)* RP_))? (COMMA_ (columnName | columnNames) (AS (literals | LP_ 
literals (COMMA_ literals)* RP_))?)* RP_
+    : IN LP_ unpivotInClauseExpr (COMMA_ unpivotInClauseExpr)* RP_
+    ;
+
+unpivotInClauseExpr
+    : (columnName | columnNames) (AS (literals | LP_ literals (COMMA_ 
literals)* RP_))?
     ;
 
 sampleClause
diff --git a/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java b/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java
index 08ce6ff45d3..b11f3d12d41 100644
--- a/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java
+++ b/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java
@@ -95,6 +95,7 @@ import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.Subque
 import 
org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.SubqueryFactoringClauseContext;
 import 
org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.TableCollectionExprContext;
 import 
org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.TableNameContext;
+import 
org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.UnpivotClauseContext;
 import 
org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.UpdateContext;
 import 
org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.UpdateSetClauseContext;
 import 
org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.UpdateSetColumnClauseContext;
@@ -437,6 +438,19 @@ public final class OracleDMLStatementVisitor extends OracleStatementVisitor impl
         return new PivotSegment(ctx.getStart().getStartIndex(), 
ctx.getStop().getStopIndex(), pivotForColumn, pivotInColumns);
     }
     
+    @Override
+    public ASTNode visitUnpivotClause(final UnpivotClauseContext ctx) {
+        ColumnSegment unpivotColumn = (ColumnSegment) 
visitColumnName(ctx.columnName());
+        ColumnSegment unpivotForColumn = (ColumnSegment) 
visitColumnName(ctx.pivotForClause().columnName());
+        Collection<ColumnSegment> unpivotInColumns = new LinkedList<>();
+        if (null != ctx.unpivotInClause()) {
+            ctx.unpivotInClause().unpivotInClauseExpr().forEach(each -> 
unpivotInColumns.add((ColumnSegment) visit(each.columnName())));
+        }
+        PivotSegment result = new PivotSegment(ctx.getStart().getStartIndex(), 
ctx.getStop().getStopIndex(), unpivotForColumn, unpivotInColumns, true);
+        result.setUnpivotColumn(unpivotColumn);
+        return result;
+    }
+    
     @Override
     public ASTNode visitDmlTableClause(final DmlTableClauseContext ctx) {
         return visit(ctx.tableName());
@@ -1043,6 +1057,15 @@ public final class OracleDMLStatementVisitor extends OracleStatementVisitor impl
                 ((SimpleTableSegment) result).setPivot(pivotClause);
             }
         }
+        if (null != ctx.unpivotClause()) {
+            PivotSegment pivotClause = (PivotSegment) 
visit(ctx.unpivotClause());
+            if (result instanceof SubqueryTableSegment) {
+                ((SubqueryTableSegment) result).setPivot(pivotClause);
+            }
+            if (result instanceof SimpleTableSegment) {
+                ((SimpleTableSegment) result).setPivot(pivotClause);
+            }
+        }
         return result;
     }
     
diff --git a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/generic/PivotSegment.java b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/generic/PivotSegment.java
index d2bbfc70948..8dcea1768a2 100644
--- a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/generic/PivotSegment.java
+++ b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/generic/PivotSegment.java
@@ -19,10 +19,12 @@ package org.apache.shardingsphere.sql.parser.sql.common.segment.generic;
 
 import lombok.Getter;
 import lombok.RequiredArgsConstructor;
+import lombok.Setter;
 import org.apache.shardingsphere.sql.parser.sql.common.segment.SQLSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
 
 import java.util.Collection;
+import java.util.HashSet;
 
 /**
  * Pivot segment.
@@ -38,4 +40,31 @@ public final class PivotSegment implements SQLSegment {
     private final ColumnSegment pivotForColumn;
     
     private final Collection<ColumnSegment> pivotInColumns;
+    
+    private final boolean isUnPivot;
+    
+    @Setter
+    private ColumnSegment unpivotColumn;
+    
+    public PivotSegment(final int startIndex, final int stopIndex, final 
ColumnSegment pivotForColumn, final Collection<ColumnSegment> pivotInColumns) {
+        this.startIndex = startIndex;
+        this.stopIndex = stopIndex;
+        this.pivotForColumn = pivotForColumn;
+        this.pivotInColumns = pivotInColumns;
+        this.isUnPivot = false;
+    }
+    
+    /**
+     * Get pivot columns.
+     * 
+     * @return pivot columns
+     */
+    public Collection<ColumnSegment> getPivotColumns() {
+        Collection<ColumnSegment> result = new HashSet<>(pivotInColumns);
+        result.add(pivotForColumn);
+        if (null != unpivotColumn) {
+            result.add(unpivotColumn);
+        }
+        return result;
+    }
 }

Reply via email to