This is an automated email from the ASF dual-hosted git repository.
duanzhengqiang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 64b5c3a8507 Fix subquery expression projection in merge bind ex (#28392)
64b5c3a8507 is described below
commit 64b5c3a85072ba497aebb4f293119d66b7a3314c
Author: Chuxin Chen <[email protected]>
AuthorDate: Fri Sep 8 16:12:41 2023 +0800
Fix subquery expression projection in merge bind ex (#28392)
* Fix subquery expression projection in merge bind ex
* Fix subquery expression projection in merge bind ex
---
.../expression/impl/ColumnSegmentBinder.java | 13 +++++----
.../binder/statement/MergeStatementBinderTest.java | 34 ++++++++++++++++++++++
2 files changed, 42 insertions(+), 5 deletions(-)
diff --git
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/expression/impl/ColumnSegmentBinder.java
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/expression/impl/ColumnSegmentBinder.java
index dad579cf30f..4502311fa6e 100644
---
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/expression/impl/ColumnSegmentBinder.java
+++
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/expression/impl/ColumnSegmentBinder.java
@@ -130,8 +130,11 @@ public final class ColumnSegmentBinder {
}
}
if (!isFindInputColumn) {
- result = findInputColumnSegmentFromExternalTables(segment,
statementBinderContext.getExternalTableBinderContexts()).orElse(null);
- isFindInputColumn = result != null;
+ Optional<ProjectionSegment> projectionSegment =
findInputColumnSegmentFromExternalTables(segment,
statementBinderContext.getExternalTableBinderContexts());
+ isFindInputColumn = projectionSegment.isPresent();
+ if (projectionSegment.isPresent() && projectionSegment.get()
instanceof ColumnProjectionSegment) {
+ result = ((ColumnProjectionSegment)
projectionSegment.get()).getColumn();
+ }
}
if (!isFindInputColumn) {
result = findInputColumnSegmentByVariables(segment,
statementBinderContext.getVariableNames()).orElse(null);
@@ -142,11 +145,11 @@ public final class ColumnSegmentBinder {
return Optional.ofNullable(result);
}
- private static Optional<ColumnSegment>
findInputColumnSegmentFromExternalTables(final ColumnSegment segment, final
Map<String, TableSegmentBinderContext> externalTableBinderContexts) {
+ private static Optional<ProjectionSegment>
findInputColumnSegmentFromExternalTables(final ColumnSegment segment, final
Map<String, TableSegmentBinderContext> externalTableBinderContexts) {
for (TableSegmentBinderContext each :
externalTableBinderContexts.values()) {
ProjectionSegment projectionSegment =
each.getProjectionSegmentByColumnLabel(segment.getIdentifier().getValue());
- if (projectionSegment instanceof ColumnProjectionSegment) {
- return Optional.of(((ColumnProjectionSegment)
projectionSegment).getColumn());
+ if (null != projectionSegment) {
+ return Optional.of(projectionSegment);
}
}
return Optional.empty();
diff --git
a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/MergeStatementBinderTest.java
b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/MergeStatementBinderTest.java
index 0ca795d4a74..e5671c97ad2 100644
---
a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/MergeStatementBinderTest.java
+++
b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/MergeStatementBinderTest.java
@@ -27,15 +27,20 @@ import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.Se
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ExpressionProjectionSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionsSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.AliasSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.OwnerSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SubqueryTableSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableNameSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.MergeStatement;
import
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement;
import
org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
import
org.apache.shardingsphere.sql.parser.sql.dialect.statement.oracle.dml.OracleMergeStatement;
+import
org.apache.shardingsphere.sql.parser.sql.dialect.statement.oracle.dml.OracleSelectStatement;
import
org.apache.shardingsphere.sql.parser.sql.dialect.statement.oracle.dml.OracleUpdateStatement;
import org.junit.jupiter.api.Test;
@@ -105,4 +110,33 @@ class MergeStatementBinderTest {
when(result.getDatabase(DefaultDatabase.LOGIC_NAME).getSchema(DefaultDatabase.LOGIC_NAME).containsTable("t_order_item")).thenReturn(true);
return result;
}
+
+ @Test
+ void assertBindWithSubQuery() {
+ MergeStatement mergeStatement = new OracleMergeStatement();
+ SimpleTableSegment targetTable = new SimpleTableSegment(new
TableNameSegment(0, 0, new IdentifierValue("t_order")));
+ targetTable.setAlias(new AliasSegment(0, 0, new IdentifierValue("a")));
+ mergeStatement.setTarget(targetTable);
+ ProjectionsSegment projectionsSegment = new ProjectionsSegment(0, 0);
+ ExpressionProjectionSegment expressionProjectionSegment = new
ExpressionProjectionSegment(0, 0, "status + 1", new
BinaryOperationExpression(0, 0,
+ new ColumnSegment(0, 0, new IdentifierValue("status")), new
LiteralExpressionSegment(0, 0, 1), "+", "status + 1"));
+ expressionProjectionSegment.setAlias(new AliasSegment(0, 0, new
IdentifierValue("new_status")));
+ projectionsSegment.getProjections().add(expressionProjectionSegment);
+ OracleSelectStatement oracleSelectStatement = new
OracleSelectStatement();
+ oracleSelectStatement.setProjections(projectionsSegment);
+ oracleSelectStatement.setFrom(new SimpleTableSegment(new
TableNameSegment(0, 0, new IdentifierValue("t_order_item"))));
+ SubqueryTableSegment subqueryTableSegment = new
SubqueryTableSegment(new SubquerySegment(0, 0, oracleSelectStatement));
+ subqueryTableSegment.setAlias(new AliasSegment(0, 0, new
IdentifierValue("b")));
+ mergeStatement.setSource(subqueryTableSegment);
+ UpdateStatement updateStatement = new OracleUpdateStatement();
+ ColumnSegment targetTableColumn = new ColumnSegment(0, 0, new
IdentifierValue("status"));
+ targetTableColumn.setOwner(new OwnerSegment(0, 0, new
IdentifierValue("a")));
+ ColumnSegment sourceTableColumn = new ColumnSegment(0, 0, new
IdentifierValue("new_status"));
+ SetAssignmentSegment setAssignmentSegment = new
SetAssignmentSegment(0, 0,
+ Collections.singletonList(new ColumnAssignmentSegment(0, 0,
Collections.singletonList(targetTableColumn), sourceTableColumn)));
+ updateStatement.setSetAssignment(setAssignmentSegment);
+ mergeStatement.setUpdate(updateStatement);
+ MergeStatement actual = new
MergeStatementBinder().bind(mergeStatement, createMetaData(),
DefaultDatabase.LOGIC_NAME);
+ assertThat(actual, not(mergeStatement));
+ }
}