This is an automated email from the ASF dual-hosted git repository.

lancelly pushed a commit to branch remove_duplicate_code_in_column_transformer
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 4662161dca3b63e13f42427cec6883bfbf1ad050
Author: lancelly <[email protected]>
AuthorDate: Sat Jan 25 11:00:36 2025 +0800

    remove duplicate code in ColumnTransformerBuilder
---
 .../relational/ColumnTransformerBuilder.java       | 234 +++++++--------------
 .../rule/TransformCorrelatedScalarSubquery.java    | 189 +++++++++++++++++
 2 files changed, 269 insertions(+), 154 deletions(-)

diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java
index 55de5391938..ce69b23c3bc 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java
@@ -237,16 +237,8 @@ public class ColumnTransformerBuilder
             throw new UnsupportedOperationException(
                 String.format(UNSUPPORTED_EXPRESSION, node.getOperator()));
         }
-        TSDataType tsDataType = InternalTypeManager.getTSDataType(type);
-        IdentityColumnTransformer identity =
-            new IdentityColumnTransformer(
-                type, context.originSize + 
context.commonTransformerList.size());
-        ColumnTransformer columnTransformer = context.hasSeen.get(node);
-        columnTransformer.addReferenceCount();
-        context.commonTransformerList.add(columnTransformer);
-        context.leafList.add(identity);
-        context.inputDataTypes.add(tsDataType);
-        context.cache.put(node, identity);
+        appendIdentityColumnTransformer(
+            node, type, InternalTypeManager.getTSDataType(type), context);
       } else {
         ZoneId zoneId = context.sessionInfo.getZoneId();
         ColumnTransformer left = process(node.getLeft(), context);
@@ -276,9 +268,7 @@ public class ColumnTransformerBuilder
         context.cache.put(node, child);
       }
     }
-    ColumnTransformer res = context.cache.get(node);
-    res.addReferenceCount();
-    return res;
+    return getColumnTransformerFromCacheAndAddReferenceCount(node, context);
   }
 
   @Override
@@ -290,15 +280,7 @@ public class ColumnTransformerBuilder
       case MINUS:
         if (!context.cache.containsKey(node)) {
           if (context.hasSeen.containsKey(node)) {
-            IdentityColumnTransformer identity =
-                new IdentityColumnTransformer(
-                    DOUBLE, context.originSize + 
context.commonTransformerList.size());
-            ColumnTransformer columnTransformer = context.hasSeen.get(node);
-            columnTransformer.addReferenceCount();
-            context.commonTransformerList.add(columnTransformer);
-            context.leafList.add(identity);
-            context.inputDataTypes.add(TSDataType.DOUBLE);
-            context.cache.put(node, identity);
+            appendIdentityColumnTransformer(node, DOUBLE, TSDataType.DOUBLE, 
context);
           } else {
             ColumnTransformer childColumnTransformer = 
process(node.getValue(), context);
             context.cache.put(
@@ -306,9 +288,7 @@ public class ColumnTransformerBuilder
                 
ArithmeticColumnTransformerApi.getNegationTransformer(childColumnTransformer));
           }
         }
-        ColumnTransformer res = context.cache.get(node);
-        res.addReferenceCount();
-        return res;
+        return getColumnTransformerFromCacheAndAddReferenceCount(node, 
context);
       default:
         throw new UnsupportedOperationException("Unknown sign: " + 
node.getSign());
     }
@@ -318,15 +298,7 @@ public class ColumnTransformerBuilder
   protected ColumnTransformer visitBetweenPredicate(BetweenPredicate node, 
Context context) {
     if (!context.cache.containsKey(node)) {
       if (context.hasSeen.containsKey(node)) {
-        IdentityColumnTransformer identity =
-            new IdentityColumnTransformer(
-                BOOLEAN, context.originSize + 
context.commonTransformerList.size());
-        ColumnTransformer columnTransformer = context.hasSeen.get(node);
-        columnTransformer.addReferenceCount();
-        context.commonTransformerList.add(columnTransformer);
-        context.leafList.add(identity);
-        context.inputDataTypes.add(TSDataType.BOOLEAN);
-        context.cache.put(node, identity);
+        appendIdentityColumnTransformer(node, BOOLEAN, TSDataType.BOOLEAN, 
context);
       } else {
         ColumnTransformer value = this.process(node.getValue(), context);
         ColumnTransformer min = this.process(node.getMin(), context);
@@ -334,9 +306,7 @@ public class ColumnTransformerBuilder
         context.cache.put(node, new BetweenColumnTransformer(BOOLEAN, value, 
min, max, false));
       }
     }
-    ColumnTransformer res = context.cache.get(node);
-    res.addReferenceCount();
-    return res;
+    return getColumnTransformerFromCacheAndAddReferenceCount(node, context);
   }
 
   @Override
@@ -345,15 +315,12 @@ public class ColumnTransformerBuilder
     if (!context.cache.containsKey(node)) {
       if (context.hasSeen.containsKey(node)) {
         ColumnTransformer columnTransformer = context.hasSeen.get(node);
-        IdentityColumnTransformer identity =
-            new IdentityColumnTransformer(
-                columnTransformer.getType(),
-                context.originSize + context.commonTransformerList.size());
-        columnTransformer.addReferenceCount();
-        context.commonTransformerList.add(columnTransformer);
-        context.leafList.add(identity);
-        context.inputDataTypes.add(getTSDataType(columnTransformer.getType()));
-        context.cache.put(node, identity);
+        appendIdentityColumnTransformer(
+            node,
+            columnTransformer.getType(),
+            getTSDataType(columnTransformer.getType()),
+            context,
+            columnTransformer);
       } else {
         ColumnTransformer child = this.process(node.getExpression(), context);
         Type type;
@@ -369,9 +336,7 @@ public class ColumnTransformerBuilder
                 : new CastFunctionColumnTransformer(type, child, 
context.sessionInfo.getZoneId()));
       }
     }
-    ColumnTransformer res = context.cache.get(node);
-    res.addReferenceCount();
-    return res;
+    return getColumnTransformerFromCacheAndAddReferenceCount(node, context);
   }
 
   @Override
@@ -642,24 +607,19 @@ public class ColumnTransformerBuilder
     if (!context.cache.containsKey(node)) {
       if (context.hasSeen.containsKey(node)) {
         ColumnTransformer columnTransformer = context.hasSeen.get(node);
-        IdentityColumnTransformer identity =
-            new IdentityColumnTransformer(
-                columnTransformer.getType(),
-                context.originSize + context.commonTransformerList.size());
-        columnTransformer.addReferenceCount();
-        context.commonTransformerList.add(columnTransformer);
-        context.leafList.add(identity);
-        context.inputDataTypes.add(getTSDataType(columnTransformer.getType()));
-        context.cache.put(node, identity);
+        appendIdentityColumnTransformer(
+            node,
+            columnTransformer.getType(),
+            getTSDataType(columnTransformer.getType()),
+            context,
+            columnTransformer);
       } else {
         context.cache.put(
             node,
             getFunctionColumnTransformer(node.getName().getSuffix(), 
node.getArguments(), context));
       }
     }
-    ColumnTransformer res = context.cache.get(node);
-    res.addReferenceCount();
-    return res;
+    return getColumnTransformerFromCacheAndAddReferenceCount(node, context);
   }
 
   private ColumnTransformer getFunctionColumnTransformer(
@@ -1047,15 +1007,7 @@ public class ColumnTransformerBuilder
   protected ColumnTransformer visitInPredicate(InPredicate node, Context 
context) {
     if (!context.cache.containsKey(node)) {
       if (context.hasSeen.containsKey(node)) {
-        IdentityColumnTransformer identity =
-            new IdentityColumnTransformer(
-                BOOLEAN, context.originSize + 
context.commonTransformerList.size());
-        ColumnTransformer columnTransformer = context.hasSeen.get(node);
-        columnTransformer.addReferenceCount();
-        context.commonTransformerList.add(columnTransformer);
-        context.leafList.add(identity);
-        context.inputDataTypes.add(TSDataType.BOOLEAN);
-        context.cache.put(node, identity);
+        appendIdentityColumnTransformer(node, BOOLEAN, TSDataType.BOOLEAN, 
context);
       } else {
         ColumnTransformer childColumnTransformer = process(node.getValue(), 
context);
         TypeEnum childTypeEnum = 
childColumnTransformer.getType().getTypeEnum();
@@ -1076,9 +1028,7 @@ public class ColumnTransformerBuilder
       }
     }
 
-    ColumnTransformer res = context.cache.get(node);
-    res.addReferenceCount();
-    return res;
+    return getColumnTransformerFromCacheAndAddReferenceCount(node, context);
   }
 
   private static InMultiColumnTransformer constructInColumnTransformer(
@@ -1181,38 +1131,20 @@ public class ColumnTransformerBuilder
   protected ColumnTransformer visitNotExpression(NotExpression node, Context 
context) {
     if (!context.cache.containsKey(node)) {
       if (context.hasSeen.containsKey(node)) {
-        IdentityColumnTransformer identity =
-            new IdentityColumnTransformer(
-                BOOLEAN, context.originSize + 
context.commonTransformerList.size());
-        ColumnTransformer columnTransformer = context.hasSeen.get(node);
-        columnTransformer.addReferenceCount();
-        context.commonTransformerList.add(columnTransformer);
-        context.leafList.add(identity);
-        context.inputDataTypes.add(TSDataType.BOOLEAN);
-        context.cache.put(node, identity);
+        appendIdentityColumnTransformer(node, BOOLEAN, TSDataType.BOOLEAN, 
context);
       } else {
         ColumnTransformer childColumnTransformer = process(node.getValue(), 
context);
         context.cache.put(node, new LogicNotColumnTransformer(BOOLEAN, 
childColumnTransformer));
       }
     }
-    ColumnTransformer res = context.cache.get(node);
-    res.addReferenceCount();
-    return res;
+    return getColumnTransformerFromCacheAndAddReferenceCount(node, context);
   }
 
   @Override
   protected ColumnTransformer visitLikePredicate(LikePredicate node, Context 
context) {
     if (!context.cache.containsKey(node)) {
       if (context.hasSeen.containsKey(node)) {
-        IdentityColumnTransformer identity =
-            new IdentityColumnTransformer(
-                BOOLEAN, context.originSize + 
context.commonTransformerList.size());
-        ColumnTransformer columnTransformer = context.hasSeen.get(node);
-        columnTransformer.addReferenceCount();
-        context.commonTransformerList.add(columnTransformer);
-        context.leafList.add(identity);
-        context.inputDataTypes.add(TSDataType.BOOLEAN);
-        context.cache.put(node, identity);
+        appendIdentityColumnTransformer(node, BOOLEAN, TSDataType.BOOLEAN, 
context);
       } else {
         ColumnTransformer likeColumnTransformer = null;
         ColumnTransformer childColumnTransformer = process(node.getValue(), 
context);
@@ -1247,56 +1179,34 @@ public class ColumnTransformerBuilder
         context.cache.put(node, likeColumnTransformer);
       }
     }
-    ColumnTransformer res = context.cache.get(node);
-    res.addReferenceCount();
-    return res;
+    return getColumnTransformerFromCacheAndAddReferenceCount(node, context);
   }
 
   @Override
   protected ColumnTransformer visitIsNotNullPredicate(IsNotNullPredicate node, 
Context context) {
     if (!context.cache.containsKey(node)) {
       if (context.hasSeen.containsKey(node)) {
-        IdentityColumnTransformer identity =
-            new IdentityColumnTransformer(
-                BOOLEAN, context.originSize + 
context.commonTransformerList.size());
-        ColumnTransformer columnTransformer = context.hasSeen.get(node);
-        columnTransformer.addReferenceCount();
-        context.commonTransformerList.add(columnTransformer);
-        context.leafList.add(identity);
-        context.inputDataTypes.add(TSDataType.BOOLEAN);
-        context.cache.put(node, identity);
+        appendIdentityColumnTransformer(node, BOOLEAN, TSDataType.BOOLEAN, 
context);
       } else {
         ColumnTransformer childColumnTransformer = process(node.getValue(), 
context);
         context.cache.put(node, new IsNullColumnTransformer(BOOLEAN, 
childColumnTransformer, true));
       }
     }
-    ColumnTransformer res = context.cache.get(node);
-    res.addReferenceCount();
-    return res;
+    return getColumnTransformerFromCacheAndAddReferenceCount(node, context);
   }
 
   @Override
   protected ColumnTransformer visitIsNullPredicate(IsNullPredicate node, 
Context context) {
     if (!context.cache.containsKey(node)) {
       if (context.hasSeen.containsKey(node)) {
-        IdentityColumnTransformer identity =
-            new IdentityColumnTransformer(
-                BOOLEAN, context.originSize + 
context.commonTransformerList.size());
-        ColumnTransformer columnTransformer = context.hasSeen.get(node);
-        columnTransformer.addReferenceCount();
-        context.commonTransformerList.add(columnTransformer);
-        context.leafList.add(identity);
-        context.inputDataTypes.add(TSDataType.BOOLEAN);
-        context.cache.put(node, identity);
+        appendIdentityColumnTransformer(node, BOOLEAN, TSDataType.BOOLEAN, 
context);
       } else {
         ColumnTransformer childColumnTransformer = process(node.getValue(), 
context);
         context.cache.put(
             node, new IsNullColumnTransformer(BOOLEAN, childColumnTransformer, 
false));
       }
     }
-    ColumnTransformer res = context.cache.get(node);
-    res.addReferenceCount();
-    return res;
+    return getColumnTransformerFromCacheAndAddReferenceCount(node, context);
   }
 
   @Override
@@ -1366,15 +1276,12 @@ public class ColumnTransformerBuilder
     if (!context.cache.containsKey(node)) {
       if (context.hasSeen.containsKey(node)) {
         ColumnTransformer columnTransformer = context.hasSeen.get(node);
-        IdentityColumnTransformer identity =
-            new IdentityColumnTransformer(
-                columnTransformer.getType(),
-                context.originSize + context.commonTransformerList.size());
-        columnTransformer.addReferenceCount();
-        context.commonTransformerList.add(columnTransformer);
-        context.leafList.add(identity);
-        context.inputDataTypes.add(getTSDataType(columnTransformer.getType()));
-        context.cache.put(node, identity);
+        appendIdentityColumnTransformer(
+            node,
+            columnTransformer.getType(),
+            getTSDataType(columnTransformer.getType()),
+            context,
+            columnTransformer);
       } else {
         List<ColumnTransformer> children =
             node.getChildren().stream().map(c -> process(c, 
context)).collect(Collectors.toList());
@@ -1392,15 +1299,12 @@ public class ColumnTransformerBuilder
     if (!context.cache.containsKey(node)) {
       if (context.hasSeen.containsKey(node)) {
         ColumnTransformer columnTransformer = context.hasSeen.get(node);
-        IdentityColumnTransformer identity =
-            new IdentityColumnTransformer(
-                columnTransformer.getType(),
-                context.originSize + context.commonTransformerList.size());
-        columnTransformer.addReferenceCount();
-        context.commonTransformerList.add(columnTransformer);
-        context.leafList.add(identity);
-        
context.inputDataTypes.add(InternalTypeManager.getTSDataType(columnTransformer.getType()));
-        context.cache.put(node, identity);
+        appendIdentityColumnTransformer(
+            node,
+            columnTransformer.getType(),
+            getTSDataType(columnTransformer.getType()),
+            context,
+            columnTransformer);
       } else {
         List<ColumnTransformer> whenList = new ArrayList<>();
         List<ColumnTransformer> thenList = new ArrayList<>();
@@ -1423,9 +1327,7 @@ public class ColumnTransformerBuilder
                 thenList.get(0).getType(), whenList, thenList, 
elseColumnTransformer));
       }
     }
-    ColumnTransformer res = context.cache.get(node);
-    res.addReferenceCount();
-    return res;
+    return getColumnTransformerFromCacheAndAddReferenceCount(node, context);
   }
 
   @Override
@@ -1434,15 +1336,12 @@ public class ColumnTransformerBuilder
     if (!context.cache.containsKey(node)) {
       if (context.hasSeen.containsKey(node)) {
         ColumnTransformer columnTransformer = context.hasSeen.get(node);
-        IdentityColumnTransformer identity =
-            new IdentityColumnTransformer(
-                columnTransformer.getType(),
-                context.originSize + context.commonTransformerList.size());
-        columnTransformer.addReferenceCount();
-        context.commonTransformerList.add(columnTransformer);
-        context.leafList.add(identity);
-        
context.inputDataTypes.add(InternalTypeManager.getTSDataType(columnTransformer.getType()));
-        context.cache.put(node, identity);
+        appendIdentityColumnTransformer(
+            node,
+            columnTransformer.getType(),
+            InternalTypeManager.getTSDataType(columnTransformer.getType()),
+            context,
+            columnTransformer);
       } else {
         List<ColumnTransformer> whenList = new ArrayList<>();
         List<ColumnTransformer> thenList = new ArrayList<>();
@@ -1460,9 +1359,7 @@ public class ColumnTransformerBuilder
                 thenList.get(0).getType(), whenList, thenList, 
elseColumnTransformer));
       }
     }
-    ColumnTransformer res = context.cache.get(node);
-    res.addReferenceCount();
-    return res;
+    return getColumnTransformerFromCacheAndAddReferenceCount(node, context);
   }
 
   @Override
@@ -1480,6 +1377,35 @@ public class ColumnTransformerBuilder
     throw new 
UnsupportedOperationException(String.format(UNSUPPORTED_EXPRESSION, node));
   }
 
+  /**
+   * Reuses the transformer already recorded for {@code expression} in {@code context.hasSeen} as
+   * a common transformer, and caches an identity transformer for {@code expression} that reads
+   * its output column.
+   *
+   * <p>NOTE(review): {@code context.hasSeen.get(expression)} is not null-checked; a missing entry
+   * would surface as a NullPointerException in the overload below. The call sites in this patch
+   * check {@code context.hasSeen.containsKey(node)} before calling.
+   *
+   * @param expression the expression being visited
+   * @param identityReturnType the type returned by the identity transformer
+   * @param inputType the data type recorded for the new input column
+   * @param context the builder context, which is mutated
+   */
+  private void appendIdentityColumnTransformer(
+      Expression expression, Type identityReturnType, TSDataType inputType, 
Context context) {
+    appendIdentityColumnTransformer(
+        expression, identityReturnType, inputType, context, 
context.hasSeen.get(expression));
+  }
+
+  /**
+   * Appends {@code columnTransformer} to the common transformer list and caches, for {@code
+   * expression}, an {@link IdentityColumnTransformer} that reads its output.
+   *
+   * <p>The identity's column index is {@code context.originSize +
+   * context.commonTransformerList.size()}, computed before the append. This is presumably the
+   * position of the common transformer's output after the original input columns. In addition,
+   * this method:
+   *
+   * <ul>
+   *   <li>increments the reference count of {@code columnTransformer};
+   *   <li>adds the identity to {@code context.leafList};
+   *   <li>records {@code inputType} in {@code context.inputDataTypes};
+   *   <li>stores the identity in {@code context.cache} under {@code expression}.
+   * </ul>
+   *
+   * @param expression the expression whose cache entry is set to the identity transformer
+   * @param identityReturnType the type returned by the identity transformer
+   * @param inputType the data type recorded for the new input column
+   * @param context the builder context, which is mutated
+   * @param columnTransformer the previously built transformer being reused
+   */
+  private void appendIdentityColumnTransformer(
+      Expression expression,
+      Type identityReturnType,
+      TSDataType inputType,
+      Context context,
+      ColumnTransformer columnTransformer) {
+    // Compute the index before appending, so it points at the slot this transformer is about to take.
+    IdentityColumnTransformer identity =
+        new IdentityColumnTransformer(
+            identityReturnType, context.originSize + 
context.commonTransformerList.size());
+    columnTransformer.addReferenceCount();
+    context.commonTransformerList.add(columnTransformer);
+    context.leafList.add(identity);
+    context.inputDataTypes.add(inputType);
+    context.cache.put(expression, identity);
+  }
+
+  /**
+   * Returns the transformer cached for {@code expression} after incrementing its reference count.
+   *
+   * <p>The callers in this patch populate {@code context.cache} for the expression before calling.
+   * A missing entry would cause a NullPointerException here.
+   *
+   * @param expression the expression whose cached transformer is returned
+   * @param context the builder context holding the cache
+   * @return the cached transformer, with its reference count incremented
+   */
+  private ColumnTransformer getColumnTransformerFromCacheAndAddReferenceCount(
+      Expression expression, Context context) {
+    ColumnTransformer columnTransformer = context.cache.get(expression);
+    columnTransformer.addReferenceCount();
+    return columnTransformer;
+  }
+
   public static boolean isLongLiteral(Expression expression) {
     return expression instanceof LongLiteral;
   }
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedScalarSubquery.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedScalarSubquery.java
new file mode 100644
index 00000000000..3ed74b12bc8
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/TransformCorrelatedScalarSubquery.java
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule;
+
+import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
+import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule;
+import 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.AssignUniqueId;
+import 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.CorrelatedJoinNode;
+import 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.EnforceSingleRowNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.FilterNode;
+import 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.MarkDistinctNode;
+import 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode;
+import 
org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.Cardinality;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Cast;
+import 
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SimpleCaseExpression;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.WhenClause;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.Optional;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static java.util.Objects.requireNonNull;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.PlanNodeSearcher.searchFrom;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.INNER;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.LEFT;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.correlation;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.CorrelatedJoin.filter;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.correlatedJoin;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.QueryCardinalityUtil.extractCardinality;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral.TRUE_LITERAL;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.type.TypeSignatureTranslator.toSqlType;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern.nonEmpty;
+import static org.apache.tsfile.read.common.type.BooleanType.BOOLEAN;
+import static org.apache.tsfile.read.common.type.LongType.INT64;
+
+/**
+ * Scalar filter scan query is something like:
+ *
+ * <pre>
+ *     SELECT a,b,c FROM rel WHERE a = correlated1 AND b = correlated2
+ * </pre>
+ *
+ * <p>This optimizer can rewrite it to a mark distinct and a filter over a left outer join.
+ *
+ * <p>From:
+ *
+ * <pre>
+ * - CorrelatedJoin (with correlation list: [C])
+ *   - (input) plan which produces symbols: [A, B, C]
+ *   - (scalar subquery) Project F
+ *     - Filter(D = C AND E > 5)
+ *       - plan which produces symbols: [D, E, F]
+ * </pre>
+ *
+ * to:
+ *
+ * <pre>
+ * - Filter(CASE isDistinct WHEN true THEN true ELSE fail('Scalar sub-query has returned multiple rows'))
+ *   - MarkDistinct(isDistinct)
+ *     - CorrelatedJoin (with correlation list: [C])
+ *       - AssignUniqueId(adds symbol U)
+ *         - (input) plan which produces symbols: [A, B, C]
+ *       - non scalar subquery
+ * </pre>
+ *
+ * <p>The rule only fires when an EnforceSingleRow node can be reached from the subquery root
+ * through Project nodes alone. That EnforceSingleRow node is removed from the subquery. The
+ * result then depends on the cardinality of the remaining subquery:
+ *
+ * <ul>
+ *   <li>If it is known to produce at most one row, the plan becomes a plain CorrelatedJoin. The
+ *       join is INNER when exactly one row is guaranteed and LEFT otherwise, and no runtime
+ *       check is added.
+ *   <li>Otherwise the Filter shown above is produced, wrapped in a Project that restores the
+ *       original correlated join's output symbols.
+ * </ul>
+ *
+ * <p>This must be run after aggregation decorrelation rules.
+ */
+public class TransformCorrelatedScalarSubquery implements 
Rule<CorrelatedJoinNode> {
+  // Matches correlated joins that have a non-empty correlation list and a TRUE join filter.
+  private static final Pattern<CorrelatedJoinNode> PATTERN =
+      
correlatedJoin().with(nonEmpty(correlation())).with(filter().equalTo(TRUE_LITERAL));
+
+  // Only used below to build the fail(...) call for the multiple-rows check.
+  private final Metadata metadata;
+
+  public TransformCorrelatedScalarSubquery(Metadata metadata) {
+    this.metadata = requireNonNull(metadata, "metadata is null");
+  }
+
+  @Override
+  public Pattern<CorrelatedJoinNode> getPattern() {
+    return PATTERN;
+  }
+
+  @Override
+  public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures 
captures, Context context) {
+    // lateral references are only allowed for INNER or LEFT correlated join
+    checkArgument(
+        correlatedJoinNode.getJoinType() == INNER || 
correlatedJoinNode.getJoinType() == LEFT,
+        "unexpected correlated join type: %s",
+        correlatedJoinNode.getJoinType());
+    PlanNode subquery = 
context.getLookup().resolve(correlatedJoinNode.getSubquery());
+
+    // Bail out unless an EnforceSingleRowNode is reachable from the subquery root via
+    // ProjectNodes only.
+    if (!searchFrom(subquery, context.getLookup())
+        .where(EnforceSingleRowNode.class::isInstance)
+        .recurseOnlyWhen(ProjectNode.class::isInstance)
+        .matches()) {
+      return Result.empty();
+    }
+
+    // Repeat the same search, this time removing the first matching EnforceSingleRowNode.
+    PlanNode rewrittenSubquery =
+        searchFrom(subquery, context.getLookup())
+            .where(EnforceSingleRowNode.class::isInstance)
+            .recurseOnlyWhen(ProjectNode.class::isInstance)
+            .removeFirst();
+
+    // If the subquery provably yields at most one row, no runtime multiple-rows check is needed.
+    Cardinality subqueryCardinality = extractCardinality(rewrittenSubquery, 
context.getLookup());
+    boolean producesAtMostOneRow = subqueryCardinality.isAtMostScalar();
+    if (producesAtMostOneRow) {
+      boolean producesSingleRow = subqueryCardinality.isScalar();
+      return Result.ofPlanNode(
+          new CorrelatedJoinNode(
+              context.getIdAllocator().genPlanNodeId(),
+              correlatedJoinNode.getInput(),
+              rewrittenSubquery,
+              correlatedJoinNode.getCorrelation(),
+              // EnforceSingleRowNode guarantees that exactly one matching row is produced
+              // for every input row (independently of correlated join type). The decorrelated
+              // plan must preserve these semantics.
+              producesSingleRow ? INNER : LEFT,
+              correlatedJoinNode.getFilter(),
+              correlatedJoinNode.getOriginSubquery()));
+    }
+
+    // Otherwise, wrap the input in AssignUniqueId, which adds the "unique" symbol (see the
+    // class Javadoc), and LEFT-join it with the subquery.
+    Symbol unique = context.getSymbolAllocator().newSymbol("unique", INT64);
+
+    CorrelatedJoinNode rewrittenCorrelatedJoinNode =
+        new CorrelatedJoinNode(
+            context.getIdAllocator().genPlanNodeId(),
+            new AssignUniqueId(
+                context.getIdAllocator().genPlanNodeId(), 
correlatedJoinNode.getInput(), unique),
+            rewrittenSubquery,
+            correlatedJoinNode.getCorrelation(),
+            LEFT,
+            correlatedJoinNode.getFilter(),
+            correlatedJoinNode.getOriginSubquery());
+
+    // Mark distinctness over the join input's output symbols (presumably including "unique").
+    // An input row that the subquery matched more than once is therefore not distinct.
+    Symbol isDistinct = context.getSymbolAllocator().newSymbol("is_distinct", 
BOOLEAN);
+    MarkDistinctNode markDistinctNode =
+        new MarkDistinctNode(
+            context.getIdAllocator().genPlanNodeId(),
+            rewrittenCorrelatedJoinNode,
+            isDistinct,
+            rewrittenCorrelatedJoinNode.getInput().getOutputSymbols(),
+            Optional.empty());
+
+    // Keep distinct rows. For any other row, evaluate CAST(fail(...) AS BOOLEAN) to raise the
+    // "multiple rows" error.
+    // NOTE(review): `failFunction` and `SUBQUERY_MULTIPLE_ROWS` are not imported or declared in
+    // this file, so it most likely does not compile as shown. TODO: add the correct imports, or
+    // replace them with an IoTDB equivalent.
+    FilterNode filterNode =
+        new FilterNode(
+            context.getIdAllocator().genPlanNodeId(),
+            markDistinctNode,
+            new SimpleCaseExpression(
+                isDistinct.toSymbolReference(),
+                ImmutableList.of(new WhenClause(TRUE_LITERAL, TRUE_LITERAL)),
+                Optional.of(
+                    new Cast(
+                        failFunction(
+                            metadata,
+                            SUBQUERY_MULTIPLE_ROWS,
+                            "Scalar sub-query has returned multiple rows"),
+                        toSqlType(BOOLEAN)))));
+
+    // Project back to the original correlated join's output symbols, dropping the helper
+    // symbols "unique" and "is_distinct".
+    return Result.ofPlanNode(
+        new ProjectNode(
+            context.getIdAllocator().genPlanNodeId(),
+            filterNode,
+            Assignments.identity(correlatedJoinNode.getOutputSymbols())));
+  }
+}

Reply via email to