This is an automated email from the ASF dual-hosted git repository. lancelly pushed a commit to branch remove_duplicate_code_in_column_transformer in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 4662161dca3b63e13f42427cec6883bfbf1ad050 Author: lancelly <[email protected]> AuthorDate: Sat Jan 25 11:00:36 2025 +0800 remove duplicate code in ColumnTransformerBuilder --- .../relational/ColumnTransformerBuilder.java | 234 +++++++-------------- .../rule/TransformCorrelatedScalarSubquery.java | 189 +++++++++++++++++ 2 files changed, 269 insertions(+), 154 deletions(-) diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java index 55de5391938..ce69b23c3bc 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java @@ -237,16 +237,8 @@ public class ColumnTransformerBuilder throw new UnsupportedOperationException( String.format(UNSUPPORTED_EXPRESSION, node.getOperator())); } - TSDataType tsDataType = InternalTypeManager.getTSDataType(type); - IdentityColumnTransformer identity = - new IdentityColumnTransformer( - type, context.originSize + context.commonTransformerList.size()); - ColumnTransformer columnTransformer = context.hasSeen.get(node); - columnTransformer.addReferenceCount(); - context.commonTransformerList.add(columnTransformer); - context.leafList.add(identity); - context.inputDataTypes.add(tsDataType); - context.cache.put(node, identity); + appendIdentityColumnTransformer( + node, type, InternalTypeManager.getTSDataType(type), context); } else { ZoneId zoneId = context.sessionInfo.getZoneId(); ColumnTransformer left = process(node.getLeft(), context); @@ -276,9 +268,7 @@ public class ColumnTransformerBuilder context.cache.put(node, child); } } - ColumnTransformer res = context.cache.get(node); - res.addReferenceCount(); - return res; + return 
getColumnTransformerFromCacheAndAddReferenceCount(node, context); } @Override @@ -290,15 +280,7 @@ public class ColumnTransformerBuilder case MINUS: if (!context.cache.containsKey(node)) { if (context.hasSeen.containsKey(node)) { - IdentityColumnTransformer identity = - new IdentityColumnTransformer( - DOUBLE, context.originSize + context.commonTransformerList.size()); - ColumnTransformer columnTransformer = context.hasSeen.get(node); - columnTransformer.addReferenceCount(); - context.commonTransformerList.add(columnTransformer); - context.leafList.add(identity); - context.inputDataTypes.add(TSDataType.DOUBLE); - context.cache.put(node, identity); + appendIdentityColumnTransformer(node, DOUBLE, TSDataType.DOUBLE, context); } else { ColumnTransformer childColumnTransformer = process(node.getValue(), context); context.cache.put( @@ -306,9 +288,7 @@ public class ColumnTransformerBuilder ArithmeticColumnTransformerApi.getNegationTransformer(childColumnTransformer)); } } - ColumnTransformer res = context.cache.get(node); - res.addReferenceCount(); - return res; + return getColumnTransformerFromCacheAndAddReferenceCount(node, context); default: throw new UnsupportedOperationException("Unknown sign: " + node.getSign()); } @@ -318,15 +298,7 @@ public class ColumnTransformerBuilder protected ColumnTransformer visitBetweenPredicate(BetweenPredicate node, Context context) { if (!context.cache.containsKey(node)) { if (context.hasSeen.containsKey(node)) { - IdentityColumnTransformer identity = - new IdentityColumnTransformer( - BOOLEAN, context.originSize + context.commonTransformerList.size()); - ColumnTransformer columnTransformer = context.hasSeen.get(node); - columnTransformer.addReferenceCount(); - context.commonTransformerList.add(columnTransformer); - context.leafList.add(identity); - context.inputDataTypes.add(TSDataType.BOOLEAN); - context.cache.put(node, identity); + appendIdentityColumnTransformer(node, BOOLEAN, TSDataType.BOOLEAN, context); } else { 
ColumnTransformer value = this.process(node.getValue(), context); ColumnTransformer min = this.process(node.getMin(), context); @@ -334,9 +306,7 @@ public class ColumnTransformerBuilder context.cache.put(node, new BetweenColumnTransformer(BOOLEAN, value, min, max, false)); } } - ColumnTransformer res = context.cache.get(node); - res.addReferenceCount(); - return res; + return getColumnTransformerFromCacheAndAddReferenceCount(node, context); } @Override @@ -345,15 +315,12 @@ public class ColumnTransformerBuilder if (!context.cache.containsKey(node)) { if (context.hasSeen.containsKey(node)) { ColumnTransformer columnTransformer = context.hasSeen.get(node); - IdentityColumnTransformer identity = - new IdentityColumnTransformer( - columnTransformer.getType(), - context.originSize + context.commonTransformerList.size()); - columnTransformer.addReferenceCount(); - context.commonTransformerList.add(columnTransformer); - context.leafList.add(identity); - context.inputDataTypes.add(getTSDataType(columnTransformer.getType())); - context.cache.put(node, identity); + appendIdentityColumnTransformer( + node, + columnTransformer.getType(), + getTSDataType(columnTransformer.getType()), + context, + columnTransformer); } else { ColumnTransformer child = this.process(node.getExpression(), context); Type type; @@ -369,9 +336,7 @@ public class ColumnTransformerBuilder : new CastFunctionColumnTransformer(type, child, context.sessionInfo.getZoneId())); } } - ColumnTransformer res = context.cache.get(node); - res.addReferenceCount(); - return res; + return getColumnTransformerFromCacheAndAddReferenceCount(node, context); } @Override @@ -642,24 +607,19 @@ public class ColumnTransformerBuilder if (!context.cache.containsKey(node)) { if (context.hasSeen.containsKey(node)) { ColumnTransformer columnTransformer = context.hasSeen.get(node); - IdentityColumnTransformer identity = - new IdentityColumnTransformer( - columnTransformer.getType(), - context.originSize + 
context.commonTransformerList.size()); - columnTransformer.addReferenceCount(); - context.commonTransformerList.add(columnTransformer); - context.leafList.add(identity); - context.inputDataTypes.add(getTSDataType(columnTransformer.getType())); - context.cache.put(node, identity); + appendIdentityColumnTransformer( + node, + columnTransformer.getType(), + getTSDataType(columnTransformer.getType()), + context, + columnTransformer); } else { context.cache.put( node, getFunctionColumnTransformer(node.getName().getSuffix(), node.getArguments(), context)); } } - ColumnTransformer res = context.cache.get(node); - res.addReferenceCount(); - return res; + return getColumnTransformerFromCacheAndAddReferenceCount(node, context); } private ColumnTransformer getFunctionColumnTransformer( @@ -1047,15 +1007,7 @@ public class ColumnTransformerBuilder protected ColumnTransformer visitInPredicate(InPredicate node, Context context) { if (!context.cache.containsKey(node)) { if (context.hasSeen.containsKey(node)) { - IdentityColumnTransformer identity = - new IdentityColumnTransformer( - BOOLEAN, context.originSize + context.commonTransformerList.size()); - ColumnTransformer columnTransformer = context.hasSeen.get(node); - columnTransformer.addReferenceCount(); - context.commonTransformerList.add(columnTransformer); - context.leafList.add(identity); - context.inputDataTypes.add(TSDataType.BOOLEAN); - context.cache.put(node, identity); + appendIdentityColumnTransformer(node, BOOLEAN, TSDataType.BOOLEAN, context); } else { ColumnTransformer childColumnTransformer = process(node.getValue(), context); TypeEnum childTypeEnum = childColumnTransformer.getType().getTypeEnum(); @@ -1076,9 +1028,7 @@ public class ColumnTransformerBuilder } } - ColumnTransformer res = context.cache.get(node); - res.addReferenceCount(); - return res; + return getColumnTransformerFromCacheAndAddReferenceCount(node, context); } private static InMultiColumnTransformer constructInColumnTransformer( @@ -1181,38 
+1131,20 @@ public class ColumnTransformerBuilder protected ColumnTransformer visitNotExpression(NotExpression node, Context context) { if (!context.cache.containsKey(node)) { if (context.hasSeen.containsKey(node)) { - IdentityColumnTransformer identity = - new IdentityColumnTransformer( - BOOLEAN, context.originSize + context.commonTransformerList.size()); - ColumnTransformer columnTransformer = context.hasSeen.get(node); - columnTransformer.addReferenceCount(); - context.commonTransformerList.add(columnTransformer); - context.leafList.add(identity); - context.inputDataTypes.add(TSDataType.BOOLEAN); - context.cache.put(node, identity); + appendIdentityColumnTransformer(node, BOOLEAN, TSDataType.BOOLEAN, context); } else { ColumnTransformer childColumnTransformer = process(node.getValue(), context); context.cache.put(node, new LogicNotColumnTransformer(BOOLEAN, childColumnTransformer)); } } - ColumnTransformer res = context.cache.get(node); - res.addReferenceCount(); - return res; + return getColumnTransformerFromCacheAndAddReferenceCount(node, context); } @Override protected ColumnTransformer visitLikePredicate(LikePredicate node, Context context) { if (!context.cache.containsKey(node)) { if (context.hasSeen.containsKey(node)) { - IdentityColumnTransformer identity = - new IdentityColumnTransformer( - BOOLEAN, context.originSize + context.commonTransformerList.size()); - ColumnTransformer columnTransformer = context.hasSeen.get(node); - columnTransformer.addReferenceCount(); - context.commonTransformerList.add(columnTransformer); - context.leafList.add(identity); - context.inputDataTypes.add(TSDataType.BOOLEAN); - context.cache.put(node, identity); + appendIdentityColumnTransformer(node, BOOLEAN, TSDataType.BOOLEAN, context); } else { ColumnTransformer likeColumnTransformer = null; ColumnTransformer childColumnTransformer = process(node.getValue(), context); @@ -1247,56 +1179,34 @@ public class ColumnTransformerBuilder context.cache.put(node, 
likeColumnTransformer); } } - ColumnTransformer res = context.cache.get(node); - res.addReferenceCount(); - return res; + return getColumnTransformerFromCacheAndAddReferenceCount(node, context); } @Override protected ColumnTransformer visitIsNotNullPredicate(IsNotNullPredicate node, Context context) { if (!context.cache.containsKey(node)) { if (context.hasSeen.containsKey(node)) { - IdentityColumnTransformer identity = - new IdentityColumnTransformer( - BOOLEAN, context.originSize + context.commonTransformerList.size()); - ColumnTransformer columnTransformer = context.hasSeen.get(node); - columnTransformer.addReferenceCount(); - context.commonTransformerList.add(columnTransformer); - context.leafList.add(identity); - context.inputDataTypes.add(TSDataType.BOOLEAN); - context.cache.put(node, identity); + appendIdentityColumnTransformer(node, BOOLEAN, TSDataType.BOOLEAN, context); } else { ColumnTransformer childColumnTransformer = process(node.getValue(), context); context.cache.put(node, new IsNullColumnTransformer(BOOLEAN, childColumnTransformer, true)); } } - ColumnTransformer res = context.cache.get(node); - res.addReferenceCount(); - return res; + return getColumnTransformerFromCacheAndAddReferenceCount(node, context); } @Override protected ColumnTransformer visitIsNullPredicate(IsNullPredicate node, Context context) { if (!context.cache.containsKey(node)) { if (context.hasSeen.containsKey(node)) { - IdentityColumnTransformer identity = - new IdentityColumnTransformer( - BOOLEAN, context.originSize + context.commonTransformerList.size()); - ColumnTransformer columnTransformer = context.hasSeen.get(node); - columnTransformer.addReferenceCount(); - context.commonTransformerList.add(columnTransformer); - context.leafList.add(identity); - context.inputDataTypes.add(TSDataType.BOOLEAN); - context.cache.put(node, identity); + appendIdentityColumnTransformer(node, BOOLEAN, TSDataType.BOOLEAN, context); } else { ColumnTransformer childColumnTransformer = 
process(node.getValue(), context); context.cache.put( node, new IsNullColumnTransformer(BOOLEAN, childColumnTransformer, false)); } } - ColumnTransformer res = context.cache.get(node); - res.addReferenceCount(); - return res; + return getColumnTransformerFromCacheAndAddReferenceCount(node, context); } @Override @@ -1366,15 +1276,12 @@ public class ColumnTransformerBuilder if (!context.cache.containsKey(node)) { if (context.hasSeen.containsKey(node)) { ColumnTransformer columnTransformer = context.hasSeen.get(node); - IdentityColumnTransformer identity = - new IdentityColumnTransformer( - columnTransformer.getType(), - context.originSize + context.commonTransformerList.size()); - columnTransformer.addReferenceCount(); - context.commonTransformerList.add(columnTransformer); - context.leafList.add(identity); - context.inputDataTypes.add(getTSDataType(columnTransformer.getType())); - context.cache.put(node, identity); + appendIdentityColumnTransformer( + node, + columnTransformer.getType(), + getTSDataType(columnTransformer.getType()), + context, + columnTransformer); } else { List<ColumnTransformer> children = node.getChildren().stream().map(c -> process(c, context)).collect(Collectors.toList()); @@ -1392,15 +1299,12 @@ public class ColumnTransformerBuilder if (!context.cache.containsKey(node)) { if (context.hasSeen.containsKey(node)) { ColumnTransformer columnTransformer = context.hasSeen.get(node); - IdentityColumnTransformer identity = - new IdentityColumnTransformer( - columnTransformer.getType(), - context.originSize + context.commonTransformerList.size()); - columnTransformer.addReferenceCount(); - context.commonTransformerList.add(columnTransformer); - context.leafList.add(identity); - context.inputDataTypes.add(InternalTypeManager.getTSDataType(columnTransformer.getType())); - context.cache.put(node, identity); + appendIdentityColumnTransformer( + node, + columnTransformer.getType(), + getTSDataType(columnTransformer.getType()), + context, + 
columnTransformer); } else { List<ColumnTransformer> whenList = new ArrayList<>(); List<ColumnTransformer> thenList = new ArrayList<>(); @@ -1423,9 +1327,7 @@ public class ColumnTransformerBuilder thenList.get(0).getType(), whenList, thenList, elseColumnTransformer)); } } - ColumnTransformer res = context.cache.get(node); - res.addReferenceCount(); - return res; + return getColumnTransformerFromCacheAndAddReferenceCount(node, context); } @Override @@ -1434,15 +1336,12 @@ public class ColumnTransformerBuilder if (!context.cache.containsKey(node)) { if (context.hasSeen.containsKey(node)) { ColumnTransformer columnTransformer = context.hasSeen.get(node); - IdentityColumnTransformer identity = - new IdentityColumnTransformer( - columnTransformer.getType(), - context.originSize + context.commonTransformerList.size()); - columnTransformer.addReferenceCount(); - context.commonTransformerList.add(columnTransformer); - context.leafList.add(identity); - context.inputDataTypes.add(InternalTypeManager.getTSDataType(columnTransformer.getType())); - context.cache.put(node, identity); + appendIdentityColumnTransformer( + node, + columnTransformer.getType(), + InternalTypeManager.getTSDataType(columnTransformer.getType()), + context, + columnTransformer); } else { List<ColumnTransformer> whenList = new ArrayList<>(); List<ColumnTransformer> thenList = new ArrayList<>(); @@ -1460,9 +1359,7 @@ public class ColumnTransformerBuilder thenList.get(0).getType(), whenList, thenList, elseColumnTransformer)); } } - ColumnTransformer res = context.cache.get(node); - res.addReferenceCount(); - return res; + return getColumnTransformerFromCacheAndAddReferenceCount(node, context); } @Override @@ -1480,6 +1377,35 @@ public class ColumnTransformerBuilder throw new UnsupportedOperationException(String.format(UNSUPPORTED_EXPRESSION, node)); } + private void appendIdentityColumnTransformer( + Expression expression, Type identityReturnType, TSDataType inputType, Context context) { + 
appendIdentityColumnTransformer( + expression, identityReturnType, inputType, context, context.hasSeen.get(expression)); + } + + private void appendIdentityColumnTransformer( + Expression expression, + Type identityReturnType, + TSDataType inputType, + Context context, + ColumnTransformer columnTransformer) { + IdentityColumnTransformer identity = + new IdentityColumnTransformer( + identityReturnType, context.originSize + context.commonTransformerList.size()); + columnTransformer.addReferenceCount(); + context.commonTransformerList.add(columnTransformer); + context.leafList.add(identity); + context.inputDataTypes.add(inputType); + context.cache.put(expression, identity); + } + + private ColumnTransformer getColumnTransformerFromCacheAndAddReferenceCount( + Expression expression, Context context) { + ColumnTransformer columnTransformer = context.cache.get(expression); + columnTransformer.addReferenceCount(); + return columnTransformer; + } + public static boolean isLongLiteral(Expression expression) { return expression instanceof LongLiteral; } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedScalarSubquery.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedScalarSubquery.java new file mode 100644 index 00000000000..3ed74b12bc8 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedScalarSubquery.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule; + +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode; +import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata; +import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments; +import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AssignUniqueId; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.CorrelatedJoinNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.EnforceSingleRowNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.FilterNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.MarkDistinctNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.Cardinality; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Cast; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SimpleCaseExpression; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.WhenClause; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern; + +import com.google.common.collect.ImmutableList; + +import java.util.Optional; + +import static 
com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.PlanNodeSearcher.searchFrom; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.INNER; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.LEFT; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.correlation; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.filter; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.correlatedJoin; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.QueryCardinalityUtil.extractCardinality; +import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral.TRUE_LITERAL; +import static org.apache.iotdb.db.queryengine.plan.relational.type.TypeSignatureTranslator.toSqlType; +import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern.nonEmpty; +import static org.apache.tsfile.read.common.type.BooleanType.BOOLEAN; +import static org.apache.tsfile.read.common.type.LongType.INT64; + +/** + * Scalar filter scan query is something like: + * + * <pre> + * SELECT a,b,c FROM rel WHERE a = correlated1 AND b = correlated2 + * </pre> + * + * <p>This optimizer can rewrite to mark distinct and filter over a left outer join: + * + * <p>From: + * + * <pre> + * - CorrelatedJoin (with correlation list: [C]) + * - (input) plan which produces symbols: [A, B, C] + * - (scalar subquery) Project F + * - Filter(D = C AND E > 5) + * - plan which produces symbols: [D, E, F] + * </pre> + * + * to: + * + * <pre> + * - Filter(CASE isDistinct WHEN true THEN true ELSE fail('Scalar sub-query has returned multiple rows')) + * - MarkDistinct(isDistinct) + * - CorrelatedJoin (with correlation list: [C]) + * - 
AssignUniqueId(adds symbol U) + * - (input) plan which produces symbols: [A, B, C] + * - non scalar subquery + * </pre> + * + * <p>This must be run after aggregation decorrelation rules. + */ +public class TransformCorrelatedScalarSubquery implements Rule<CorrelatedJoinNode> { + private static final Pattern<CorrelatedJoinNode> PATTERN = + correlatedJoin().with(nonEmpty(correlation())).with(filter().equalTo(TRUE_LITERAL)); + + private final Metadata metadata; + + public TransformCorrelatedScalarSubquery(Metadata metadata) { + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + @Override + public Pattern<CorrelatedJoinNode> getPattern() { + return PATTERN; + } + + @Override + public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) { + // lateral references are only allowed for INNER or LEFT correlated join + checkArgument( + correlatedJoinNode.getJoinType() == INNER || correlatedJoinNode.getJoinType() == LEFT, + "unexpected correlated join type: %s", + correlatedJoinNode.getJoinType()); + PlanNode subquery = context.getLookup().resolve(correlatedJoinNode.getSubquery()); + + if (!searchFrom(subquery, context.getLookup()) + .where(EnforceSingleRowNode.class::isInstance) + .recurseOnlyWhen(ProjectNode.class::isInstance) + .matches()) { + return Result.empty(); + } + + PlanNode rewrittenSubquery = + searchFrom(subquery, context.getLookup()) + .where(EnforceSingleRowNode.class::isInstance) + .recurseOnlyWhen(ProjectNode.class::isInstance) + .removeFirst(); + + Cardinality subqueryCardinality = extractCardinality(rewrittenSubquery, context.getLookup()); + boolean producesAtMostOneRow = subqueryCardinality.isAtMostScalar(); + if (producesAtMostOneRow) { + boolean producesSingleRow = subqueryCardinality.isScalar(); + return Result.ofPlanNode( + new CorrelatedJoinNode( + context.getIdAllocator().genPlanNodeId(), + correlatedJoinNode.getInput(), + rewrittenSubquery, + correlatedJoinNode.getCorrelation(), + // 
EnforceSingleRowNode guarantees that exactly single matching row is produced + // for every input row (independently of correlated join type). Decorrelated plan + // must preserve this semantics. + producesSingleRow ? INNER : LEFT, + correlatedJoinNode.getFilter(), + correlatedJoinNode.getOriginSubquery())); + } + + Symbol unique = context.getSymbolAllocator().newSymbol("unique", INT64); + + CorrelatedJoinNode rewrittenCorrelatedJoinNode = + new CorrelatedJoinNode( + context.getIdAllocator().genPlanNodeId(), + new AssignUniqueId( + context.getIdAllocator().genPlanNodeId(), correlatedJoinNode.getInput(), unique), + rewrittenSubquery, + correlatedJoinNode.getCorrelation(), + LEFT, + correlatedJoinNode.getFilter(), + correlatedJoinNode.getOriginSubquery()); + + Symbol isDistinct = context.getSymbolAllocator().newSymbol("is_distinct", BOOLEAN); + MarkDistinctNode markDistinctNode = + new MarkDistinctNode( + context.getIdAllocator().genPlanNodeId(), + rewrittenCorrelatedJoinNode, + isDistinct, + rewrittenCorrelatedJoinNode.getInput().getOutputSymbols(), + Optional.empty()); + + FilterNode filterNode = + new FilterNode( + context.getIdAllocator().genPlanNodeId(), + markDistinctNode, + new SimpleCaseExpression( + isDistinct.toSymbolReference(), + ImmutableList.of(new WhenClause(TRUE_LITERAL, TRUE_LITERAL)), + Optional.of( + new Cast( + failFunction( + metadata, + SUBQUERY_MULTIPLE_ROWS, + "Scalar sub-query has returned multiple rows"), + toSqlType(BOOLEAN))))); + + return Result.ofPlanNode( + new ProjectNode( + context.getIdAllocator().genPlanNodeId(), + filterNode, + Assignments.identity(correlatedJoinNode.getOutputSymbols()))); + } +}
