This is an automated email from the ASF dual-hosted git repository.

caogaofei pushed a commit to branch beyyes/join
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/beyyes/join by this push:
     new 459d5717aef add more join pushdown predicate optimize
459d5717aef is described below

commit 459d5717aefe2b11b4d87b29d92e15e38638e10f
Author: Beyyes <[email protected]>
AuthorDate: Mon Aug 19 18:55:49 2024 +0800

    add more join pushdown predicate optimize
---
 .../plan/relational/planner/EqualityInference.java | 422 +++++++++++++++++
 .../plan/relational/planner/ir/AstUtils.java       | 102 +++++
 .../planner/ir/ExpressionNodeInliner.java          |  24 +
 .../plan/relational/planner/ir/IrUtils.java        |   4 +
 .../planner/ir/SubExpressionExtractor.java         |  19 +
 .../planner/iterative/IterativeOptimizer.java      |   2 +-
 .../plan/relational/planner/node/JoinNode.java     |   1 +
 .../planner/optimizations/PlanOptimizer.java       |   2 +-
 .../optimizations/PushPredicateIntoTableScan.java  | 506 ++++++++++++++++++++-
 .../plan/relational/utils/DisjointSet.java         | 114 +++++
 10 files changed, 1192 insertions(+), 4 deletions(-)

diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/EqualityInference.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/EqualityInference.java
new file mode 100644
index 00000000000..1ca72118c5c
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/EqualityInference.java
@@ -0,0 +1,422 @@
+package org.apache.iotdb.db.queryengine.plan.relational.planner;
+
+import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata;
+import 
org.apache.iotdb.db.queryengine.plan.relational.planner.ir.SubExpressionExtractor;
+import 
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.DisjointSet;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.ImmutableSetMultimap;
+import com.google.common.collect.Multimap;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.function.Predicate;
+import java.util.function.ToIntFunction;
+import java.util.stream.Stream;
+
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static java.util.Objects.requireNonNull;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.PredicateUtils.extractConjuncts;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.ir.DeterminismEvaluator.isDeterministic;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.ir.ExpressionNodeInliner.replaceExpression;
+
+/**
+ * Makes equality based inferences to rewrite Expressions and generate 
equality sets in terms of
+ * specified symbol scopes
+ */
+public class EqualityInference {
+  // Comparator used to determine Expression preference when determining 
canonicals
+  private final Comparator<Expression> canonicalComparator;
+  private final Multimap<Expression, Expression> equalitySets; // Indexed by 
canonical expression
+  private final Map<Expression, Expression>
+      canonicalMap; // Map each known expression to canonical expression
+  private final Set<Expression> derivedExpressions;
+  private final Map<Expression, List<Expression>> expressionCache = new 
HashMap<>();
+  private final Map<Expression, List<Symbol>> symbolsCache = new HashMap<>();
+  private final Map<Expression, Set<Symbol>> uniqueSymbolsCache = new 
HashMap<>();
+
+  public EqualityInference(Metadata metadata, Expression... expressions) {
+    this(metadata, Arrays.asList(expressions));
+  }
+
+  public EqualityInference(Metadata metadata, Collection<Expression> 
expressions) {
+    DisjointSet<Expression> equalities = new DisjointSet<>();
+    expressions.stream()
+        .flatMap(expression -> extractConjuncts(expression).stream())
+        .filter(expression -> isInferenceCandidate(metadata, expression))
+        .forEach(
+            expression -> {
+              ComparisonExpression comparison = (ComparisonExpression) 
expression;
+              Expression expression1 = comparison.getLeft();
+              Expression expression2 = comparison.getRight();
+
+              equalities.findAndUnion(expression1, expression2);
+            });
+
+    Collection<Set<Expression>> equivalentClasses = 
equalities.getEquivalentClasses();
+
+    // Map every expression to the set of equivalent expressions
+    Map<Expression, Set<Expression>> byExpression = new LinkedHashMap<>();
+    for (Set<Expression> equivalence : equivalentClasses) {
+      equivalence.forEach(expression -> byExpression.put(expression, 
equivalence));
+    }
+
+    // For every non-derived expression, extract the sub-expressions and see 
if they can be
+    // rewritten as other expressions. If so,
+    // use this new information to update the known equalities.
+    Set<Expression> derivedExpressions = new LinkedHashSet<>();
+    for (Expression expression : byExpression.keySet()) {
+      if (derivedExpressions.contains(expression)) {
+        continue;
+      }
+
+      extractSubExpressions(expression).stream()
+          .filter(e -> !e.equals(expression))
+          .forEach(
+              subExpression ->
+                  byExpression.getOrDefault(subExpression, 
ImmutableSet.of()).stream()
+                      .filter(e -> !e.equals(subExpression))
+                      .forEach(
+                          equivalentSubExpression -> {
+                            Expression rewritten =
+                                replaceExpression(
+                                    expression,
+                                    ImmutableMap.of(subExpression, 
equivalentSubExpression));
+                            equalities.findAndUnion(expression, rewritten);
+                            derivedExpressions.add(rewritten);
+                          }));
+    }
+
+    Comparator<Expression> canonicalComparator =
+        Comparator
+            // Current cost heuristic:
+            // 1) Prefer fewer input symbols
+            // 2) Prefer smaller expression trees
+            // 3) Sort the expressions alphabetically - creates a stable 
consistent ordering
+            // (extremely useful for unit testing)
+            // TODO: be more precise in determining the cost of an expression
+            .comparingInt(
+                (ToIntFunction<Expression>) expression -> 
extractAllSymbols(expression).size())
+            .thenComparingLong(expression -> 
extractSubExpressions(expression).size())
+            .thenComparing(Expression::toString);
+
+    Multimap<Expression, Expression> equalitySets =
+        makeEqualitySets(equalities, canonicalComparator);
+
+    ImmutableMap.Builder<Expression, Expression> canonicalMappings = 
ImmutableMap.builder();
+    for (Map.Entry<Expression, Expression> entry : equalitySets.entries()) {
+      Expression canonical = entry.getKey();
+      Expression expression = entry.getValue();
+      canonicalMappings.put(expression, canonical);
+    }
+
+    this.equalitySets = equalitySets;
+    this.canonicalMap = canonicalMappings.buildOrThrow();
+    this.derivedExpressions = derivedExpressions;
+    this.canonicalComparator = canonicalComparator;
+  }
+
+  /**
+   * Attempts to rewrite an Expression in terms of the symbols allowed by the 
symbol scope given the
+   * known equalities. Returns null if unsuccessful.
+   */
+  public Expression rewrite(Expression expression, Set<Symbol> scope) {
+    return rewrite(expression, scope::contains, true);
+  }
+
+  /**
+   * Dumps the inference equalities as equality expressions that are 
partitioned by the symbolScope.
+   * All stored equalities are returned in a compact set and will be 
classified into three groups as
+   * determined by the symbol scope:
+   *
+   * <ol>
+   *   <li>equalities that fit entirely within the symbol scope
+   *   <li>equalities that fit entirely outside of the symbol scope
+   *   <li>equalities that straddle the symbol scope
+   * </ol>
+   *
+   * <pre>
+   * Example:
+   *   Stored Equalities:
+   *     a = b = c
+   *     d = e = f = g
+   *
+   *   Symbol Scope:
+   *     a, b, d, e
+   *
+   *   Output EqualityPartition:
+   *     Scope Equalities:
+   *       a = b
+   *       d = e
+   *     Complement Scope Equalities
+   *       f = g
+   *     Scope Straddling Equalities
+   *       a = c
+   *       d = f
+   * </pre>
+   */
+  public EqualityPartition generateEqualitiesPartitionedBy(Set<Symbol> scope) {
+    ImmutableSet.Builder<Expression> scopeEqualities = ImmutableSet.builder();
+    ImmutableSet.Builder<Expression> scopeComplementEqualities = 
ImmutableSet.builder();
+    ImmutableSet.Builder<Expression> scopeStraddlingEqualities = 
ImmutableSet.builder();
+
+    for (Collection<Expression> equalitySet : equalitySets.asMap().values()) {
+      Set<Expression> scopeExpressions = new LinkedHashSet<>();
+      Set<Expression> scopeComplementExpressions = new LinkedHashSet<>();
+      Set<Expression> scopeStraddlingExpressions = new LinkedHashSet<>();
+
+      // Try to push each non-derived expression into one side of the scope
+      equalitySet.stream()
+          .filter(candidate -> !derivedExpressions.contains(candidate))
+          .forEach(
+              candidate -> {
+                Expression scopeRewritten = rewrite(candidate, 
scope::contains, false);
+                if (scopeRewritten != null) {
+                  scopeExpressions.add(scopeRewritten);
+                }
+                Expression scopeComplementRewritten =
+                    rewrite(candidate, symbol -> !scope.contains(symbol), 
false);
+                if (scopeComplementRewritten != null) {
+                  scopeComplementExpressions.add(scopeComplementRewritten);
+                }
+                if (scopeRewritten == null && scopeComplementRewritten == 
null) {
+                  scopeStraddlingExpressions.add(candidate);
+                }
+              });
+      // Compile the equality expressions on each side of the scope
+      Expression matchingCanonical = getCanonical(scopeExpressions.stream());
+      if (scopeExpressions.size() >= 2) {
+        scopeExpressions.stream()
+            .filter(expression -> !expression.equals(matchingCanonical))
+            .map(
+                expression ->
+                    new ComparisonExpression(
+                        ComparisonExpression.Operator.EQUAL, 
matchingCanonical, expression))
+            .forEach(scopeEqualities::add);
+      }
+      Expression complementCanonical = 
getCanonical(scopeComplementExpressions.stream());
+      if (scopeComplementExpressions.size() >= 2) {
+        scopeComplementExpressions.stream()
+            .filter(expression -> !expression.equals(complementCanonical))
+            .map(
+                expression ->
+                    new ComparisonExpression(
+                        ComparisonExpression.Operator.EQUAL, 
complementCanonical, expression))
+            .forEach(scopeComplementEqualities::add);
+      }
+
+      // Compile single equality between matching and complement scope.
+      // Only consider expressions that don't have a derived expression in the other 
scope.
+      // Otherwise, redundant equality would be generated.
+      Optional<Expression> matchingConnecting =
+          scopeExpressions.stream()
+              .filter(
+                  expression ->
+                      SymbolsExtractor.extractAll(expression).isEmpty()
+                          || rewrite(expression, symbol -> 
!scope.contains(symbol), false) == null)
+              .min(canonicalComparator);
+      Optional<Expression> complementConnecting =
+          scopeComplementExpressions.stream()
+              .filter(
+                  expression ->
+                      SymbolsExtractor.extractAll(expression).isEmpty()
+                          || rewrite(expression, scope::contains, false) == 
null)
+              .min(canonicalComparator);
+      if (matchingConnecting.isPresent()
+          && complementConnecting.isPresent()
+          && !matchingConnecting.equals(complementConnecting)) {
+        scopeStraddlingEqualities.add(
+            new ComparisonExpression(
+                ComparisonExpression.Operator.EQUAL,
+                matchingConnecting.get(),
+                complementConnecting.get()));
+      }
+
+      // Compile the scope straddling equality expressions.
+      // scopeStraddlingExpressions couldn't be pushed to either side,
+      // therefore there needs to be an equality generated with
+      // one of the scopes (either matching or complement).
+      List<Expression> straddlingExpressions = new ArrayList<>();
+      if (matchingCanonical != null) {
+        straddlingExpressions.add(matchingCanonical);
+      } else if (complementCanonical != null) {
+        straddlingExpressions.add(complementCanonical);
+      }
+      straddlingExpressions.addAll(scopeStraddlingExpressions);
+      Expression connectingCanonical = 
getCanonical(straddlingExpressions.stream());
+      if (connectingCanonical != null) {
+        straddlingExpressions.stream()
+            .filter(expression -> !expression.equals(connectingCanonical))
+            .map(
+                expression ->
+                    new ComparisonExpression(
+                        ComparisonExpression.Operator.EQUAL, 
connectingCanonical, expression))
+            .forEach(scopeStraddlingEqualities::add);
+      }
+    }
+
+    return new EqualityPartition(
+        scopeEqualities.build(),
+        scopeComplementEqualities.build(),
+        scopeStraddlingEqualities.build());
+  }
+
+  /** Determines whether an Expression may be successfully applied to the 
equality inference */
+  public static boolean isInferenceCandidate(Metadata metadata, Expression 
expression) {
+    if (expression instanceof ComparisonExpression && 
isDeterministic(expression)
+    // && !mayReturnNullOnNonNullInput(expression)
+    ) {
+      ComparisonExpression comparison = (ComparisonExpression) expression;
+      if (comparison.getOperator() == ComparisonExpression.Operator.EQUAL) {
+        // We should only consider equalities that have distinct left and 
right components
+        return !comparison.getLeft().equals(comparison.getRight());
+      }
+    }
+    return false;
+  }
+
+  /**
+   * Provides a convenience Stream of Expression conjuncts which have not been 
added to the
+   * inference
+   */
+  public static Stream<Expression> nonInferrableConjuncts(
+      Metadata metadata, Expression expression) {
+    return extractConjuncts(expression).stream().filter(e -> 
!isInferenceCandidate(metadata, e));
+  }
+
+  private Expression rewrite(
+      Expression expression, Predicate<Symbol> symbolScope, boolean 
allowFullReplacement) {
+    Map<Expression, Expression> expressionRemap = new HashMap<>();
+    extractSubExpressions(expression).stream()
+        .filter(
+            allowFullReplacement
+                ? subExpression -> true
+                : subExpression -> !subExpression.equals(expression))
+        .forEach(
+            subExpression -> {
+              Expression canonical = getScopedCanonical(subExpression, 
symbolScope);
+              if (canonical != null) {
+                expressionRemap.putIfAbsent(subExpression, canonical);
+              }
+            });
+
+    // Perform a naive single-pass traversal to try to rewrite non-compliant 
portions of the tree.
+    // Prefers to replace
+    // larger subtrees over smaller subtrees
+    // TODO: this rewrite can probably be made more sophisticated
+    Expression rewritten = replaceExpression(expression, expressionRemap);
+    if (!isScoped(rewritten, symbolScope)) {
+      // If the rewritten is still not compliant with the symbol scope, just 
give up
+      return null;
+    }
+    return rewritten;
+  }
+
+  /** Returns the most preferable expression to be used as the canonical 
expression */
+  private Expression getCanonical(Stream<Expression> expressions) {
+    return expressions.min(canonicalComparator).orElse(null);
+  }
+
+  /**
+   * Returns a canonical expression that is fully contained by the symbolScope 
and that is
+   * equivalent to the specified expression. Returns null if unable to find a 
canonical.
+   */
+  @VisibleForTesting
+  Expression getScopedCanonical(Expression expression, Predicate<Symbol> 
symbolScope) {
+    Expression canonicalIndex = canonicalMap.get(expression);
+    if (canonicalIndex == null) {
+      return null;
+    }
+
+    Collection<Expression> equivalences = equalitySets.get(canonicalIndex);
+    if (expression instanceof SymbolReference) {
+      boolean inScope =
+          equivalences.stream()
+              .filter(SymbolReference.class::isInstance)
+              .map(Symbol::from)
+              .anyMatch(symbolScope);
+
+      if (!inScope) {
+        return null;
+      }
+    }
+
+    return getCanonical(equivalences.stream().filter(e -> isScoped(e, 
symbolScope)));
+  }
+
+  private boolean isScoped(Expression expression, Predicate<Symbol> 
symbolScope) {
+    return extractUniqueSymbols(expression).stream().allMatch(symbolScope);
+  }
+
+  private static Multimap<Expression, Expression> makeEqualitySets(
+      DisjointSet<Expression> equalities, Comparator<Expression> 
canonicalComparator) {
+    ImmutableSetMultimap.Builder<Expression, Expression> builder = 
ImmutableSetMultimap.builder();
+    for (Set<Expression> equalityGroup : equalities.getEquivalentClasses()) {
+      if (!equalityGroup.isEmpty()) {
+        builder.putAll(equalityGroup.stream().min(canonicalComparator).get(), 
equalityGroup);
+      }
+    }
+    return builder.build();
+  }
+
+  private List<Expression> extractSubExpressions(Expression expression) {
+    return expressionCache.computeIfAbsent(
+        expression, e -> 
SubExpressionExtractor.extract(e).collect(toImmutableList()));
+  }
+
+  private Set<Symbol> extractUniqueSymbols(Expression expression) {
+    return uniqueSymbolsCache.computeIfAbsent(
+        expression, e -> ImmutableSet.copyOf(extractAllSymbols(expression)));
+  }
+
+  private List<Symbol> extractAllSymbols(Expression expression) {
+    return symbolsCache.computeIfAbsent(expression, 
SymbolsExtractor::extractAll);
+  }
+
+  public static class EqualityPartition {
+    private final List<Expression> scopeEqualities;
+    private final List<Expression> scopeComplementEqualities;
+    private final List<Expression> scopeStraddlingEqualities;
+
+    public EqualityPartition(
+        Iterable<Expression> scopeEqualities,
+        Iterable<Expression> scopeComplementEqualities,
+        Iterable<Expression> scopeStraddlingEqualities) {
+      this.scopeEqualities =
+          ImmutableList.copyOf(requireNonNull(scopeEqualities, 
"scopeEqualities is null"));
+      this.scopeComplementEqualities =
+          ImmutableList.copyOf(
+              requireNonNull(scopeComplementEqualities, 
"scopeComplementEqualities is null"));
+      this.scopeStraddlingEqualities =
+          ImmutableList.copyOf(
+              requireNonNull(scopeStraddlingEqualities, 
"scopeStraddlingEqualities is null"));
+    }
+
+    public List<Expression> getScopeEqualities() {
+      return scopeEqualities;
+    }
+
+    public List<Expression> getScopeComplementEqualities() {
+      return scopeComplementEqualities;
+    }
+
+    public List<Expression> getScopeStraddlingEqualities() {
+      return scopeStraddlingEqualities;
+    }
+  }
+}
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/AstUtils.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/AstUtils.java
new file mode 100644
index 00000000000..7e47072947c
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/AstUtils.java
@@ -0,0 +1,102 @@
+package org.apache.iotdb.db.queryengine.plan.relational.planner.ir;
+
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
+import 
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LogicalExpression;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Node;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.graph.SuccessorsFunction;
+import com.google.common.graph.Traverser;
+
+import java.util.List;
+import java.util.OptionalInt;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+import java.util.stream.Stream;
+
+import static com.google.common.collect.Streams.stream;
+import static java.util.Objects.requireNonNull;
+
+public final class AstUtils {
+  public static Stream<Node> preOrder(Node node) {
+    return stream(
+        Traverser.forTree((SuccessorsFunction<Node>) Node::getChildren)
+            .depthFirstPreOrder(requireNonNull(node, "node is null")));
+  }
+
+  /**
+   * Compares two AST trees recursively by applying the provided comparator to 
each pair of nodes.
+   *
+   * <p>The comparator can perform a hybrid shallow/deep comparison. If it 
returns true or false,
+   * the nodes and any subtrees are considered equal or different, 
respectively. If it returns null,
+   * the nodes are considered shallowly-equal and their children will be 
compared recursively.
+   */
+  public static boolean treeEqual(
+      Node left, Node right, BiFunction<Node, Node, Boolean> 
subtreeComparator) {
+    Boolean equal = subtreeComparator.apply(left, right);
+
+    if (equal != null) {
+      return equal;
+    }
+
+    List<? extends Node> leftChildren = left.getChildren();
+    List<? extends Node> rightChildren = right.getChildren();
+
+    if (leftChildren.size() != rightChildren.size()) {
+      return false;
+    }
+
+    for (int i = 0; i < leftChildren.size(); i++) {
+      if (!treeEqual(leftChildren.get(i), rightChildren.get(i), 
subtreeComparator)) {
+        return false;
+      }
+    }
+
+    return true;
+  }
+
+  /**
+   * Computes a hash of the given AST by applying the provided subtree hasher 
at each level.
+   *
+   * <p>If the hasher returns a non-empty {@link OptionalInt}, the value is 
treated as the hash for
+   * the subtree at that node. Otherwise, the hashes of its children are 
computed and combined.
+   */
+  public static int treeHash(Node node, Function<Node, OptionalInt> 
subtreeHasher) {
+    OptionalInt hash = subtreeHasher.apply(node);
+
+    if (hash.isPresent()) {
+      return hash.getAsInt();
+    }
+
+    List<? extends Node> children = node.getChildren();
+
+    int result = node.getClass().hashCode();
+    for (Node element : children) {
+      result = 31 * result + treeHash(element, subtreeHasher);
+    }
+
+    return result;
+  }
+
+  public static List<Expression> extractConjuncts(Expression expression) {
+    ImmutableList.Builder<Expression> resultBuilder = ImmutableList.builder();
+    extractPredicates(LogicalExpression.Operator.AND, expression, 
resultBuilder);
+    return resultBuilder.build();
+  }
+
+  private static void extractPredicates(
+      LogicalExpression.Operator operator,
+      Expression expression,
+      ImmutableList.Builder<Expression> resultBuilder) {
+    if (expression instanceof LogicalExpression
+        && ((LogicalExpression) expression).getOperator() == operator) {
+      for (Expression term : ((LogicalExpression) expression).getTerms()) {
+        extractPredicates(operator, term, resultBuilder);
+      }
+    } else {
+      resultBuilder.add(expression);
+    }
+  }
+
+  private AstUtils() {}
+}
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/ExpressionNodeInliner.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/ExpressionNodeInliner.java
new file mode 100644
index 00000000000..813d22cfdba
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/ExpressionNodeInliner.java
@@ -0,0 +1,24 @@
+package org.apache.iotdb.db.queryengine.plan.relational.planner.ir;
+
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
+
+import java.util.Map;
+
+public class ExpressionNodeInliner extends ExpressionRewriter<Void> {
+  public static Expression replaceExpression(
+      Expression expression, Map<? extends Expression, ? extends Expression> 
mappings) {
+    return ExpressionTreeRewriter.rewriteWith(new 
ExpressionNodeInliner(mappings), expression);
+  }
+
+  private final Map<? extends Expression, ? extends Expression> mappings;
+
+  public ExpressionNodeInliner(Map<? extends Expression, ? extends Expression> 
mappings) {
+    this.mappings = mappings;
+  }
+
+  @Override
+  protected Expression rewriteExpression(
+      Expression node, Void context, ExpressionTreeRewriter<Void> 
treeRewriter) {
+    return mappings.get(node);
+  }
+}
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/IrUtils.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/IrUtils.java
index e0012a2db8b..ec0ad3fbb48 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/IrUtils.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/IrUtils.java
@@ -152,6 +152,10 @@ public final class IrUtils {
     return combineConjuncts(Arrays.asList(expressions));
   }
 
+  public static Expression filterDeterministicConjuncts(Expression expression) 
{
+    return filterConjuncts(expression, DeterminismEvaluator::isDeterministic);
+  }
+
   public static Expression combineConjuncts(Collection<Expression> 
expressions) {
     requireNonNull(expressions, "expressions is null");
 
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/SubExpressionExtractor.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/SubExpressionExtractor.java
new file mode 100644
index 00000000000..13842e5d883
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/SubExpressionExtractor.java
@@ -0,0 +1,19 @@
+package org.apache.iotdb.db.queryengine.plan.relational.planner.ir;
+
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
+
+import java.util.stream.Stream;
+
+/**
+ * Extracts and returns the stream of all expression subtrees within an 
Expression, including
+ * the Expression itself
+ */
+public final class SubExpressionExtractor {
+  private SubExpressionExtractor() {}
+
+  public static Stream<Expression> extract(Expression expression) {
+    return AstUtils.preOrder(expression)
+        .filter(Expression.class::isInstance)
+        .map(Expression.class::cast);
+  }
+}
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/IterativeOptimizer.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/IterativeOptimizer.java
index 89de7c169c1..04327b6ef0b 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/IterativeOptimizer.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/IterativeOptimizer.java
@@ -106,7 +106,7 @@ public class IterativeOptimizer implements 
AdaptivePlanOptimizer {
             memo,
             lookup,
             context.idAllocator(),
-            context.symbolAllocator(),
+            context.getSymbolAllocator(),
             nanoTime(),
             timeout.toMillis(),
             context.sessionInfo(),
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/JoinNode.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/JoinNode.java
index 25eafc4c626..45358d7d9e9 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/JoinNode.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/JoinNode.java
@@ -33,6 +33,7 @@ public class JoinNode extends TwoChildProcessNode {
   private final List<Symbol> leftOutputSymbols;
   private final List<Symbol> rightOutputSymbols;
   private final boolean maySkipOutputDuplicates;
+  // Optional join filter, e.g. a non-equi condition such as 'a.xx_column < b.yy_column'
   private final Optional<Expression> filter;
   private final Optional<Symbol> leftHashSymbol;
   private final Optional<Symbol> rightHashSymbol;
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PlanOptimizer.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PlanOptimizer.java
index 003e7d12277..31641432b3a 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PlanOptimizer.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PlanOptimizer.java
@@ -82,7 +82,7 @@ public interface PlanOptimizer {
       return typeProvider;
     }
 
-    public SymbolAllocator symbolAllocator() {
+    public SymbolAllocator getSymbolAllocator() {
       return symbolAllocator;
     }
 
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PushPredicateIntoTableScan.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PushPredicateIntoTableScan.java
index 5a0f9e3bc95..9fc49530418 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PushPredicateIntoTableScan.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PushPredicateIntoTableScan.java
@@ -34,31 +34,56 @@ import 
org.apache.iotdb.db.queryengine.plan.relational.analyzer.predicate.Predic
 import org.apache.iotdb.db.queryengine.plan.relational.metadata.ColumnSchema;
 import org.apache.iotdb.db.queryengine.plan.relational.metadata.DeviceEntry;
 import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments;
+import 
org.apache.iotdb.db.queryengine.plan.relational.planner.EqualityInference;
 import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.SymbolAllocator;
 import org.apache.iotdb.db.queryengine.plan.relational.planner.node.FilterNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode;
+import 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode;
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.TableScanNode;
+import 
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression;
 import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
 import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.FunctionCall;
 import 
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LogicalExpression;
 import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Node;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference;
 
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
 import org.apache.tsfile.read.filter.basic.Filter;
 import org.apache.tsfile.utils.Pair;
 
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
+import java.util.EnumSet;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
 
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.collect.ImmutableMap.toImmutableMap;
 import static java.util.Objects.requireNonNull;
 import static 
org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory.ATTRIBUTE;
 import static 
org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory.MEASUREMENT;
 import static 
org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory.TIME;
 import static 
org.apache.iotdb.db.queryengine.plan.analyze.AnalyzeVisitor.getTimePartitionSlotList;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.PredicateUtils.combineConjuncts;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.PredicateUtils.extractConjuncts;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.SymbolsExtractor.extractUnique;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.ir.DeterminismEvaluator.isDeterministic;
 import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.ir.GlobalTimePredicateExtractVisitor.extractGlobalTimeFilter;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.ir.IrUtils.filterDeterministicConjuncts;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.FULL;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.INNER;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.LEFT;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.RIGHT;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral.TRUE_LITERAL;
 
 /**
  * <b>Optimization phase:</b> Logical plan planning.
@@ -88,7 +113,11 @@ public class PushPredicateIntoTableScan implements 
PlanOptimizer {
   @Override
   public PlanNode optimize(PlanNode plan, Context context) {
     return plan.accept(
-        new Rewriter(context.getQueryContext(), context.getAnalysis(), 
context.getMetadata()),
+        new Rewriter(
+            context.getQueryContext(),
+            context.getAnalysis(),
+            context.getMetadata(),
+            context.getSymbolAllocator()),
         null);
   }
 
@@ -97,11 +126,17 @@ public class PushPredicateIntoTableScan implements 
PlanOptimizer {
     private final Analysis analysis;
     private final Metadata metadata;
     private Expression predicate;
+    private SymbolAllocator symbolAllocator;
 
-    Rewriter(MPPQueryContext queryContext, Analysis analysis, Metadata 
metadata) {
+    Rewriter(
+        MPPQueryContext queryContext,
+        Analysis analysis,
+        Metadata metadata,
+        SymbolAllocator symbolAllocator) {
       this.queryContext = queryContext;
       this.analysis = analysis;
       this.metadata = metadata;
+      this.symbolAllocator = symbolAllocator;
     }
 
     @Override
@@ -150,6 +185,8 @@ public class PushPredicateIntoTableScan implements 
PlanOptimizer {
         if (node.getChild() instanceof TableScanNode) {
           // child of FilterNode is TableScanNode, means FilterNode must get 
from where clause
           return combineFilterAndScan((TableScanNode) node.getChild());
+        } else if (node.getChild() instanceof JoinNode) {
+          return visitJoin((JoinNode) node.getChild(), context);
         } else {
           // FilterNode may get from having or subquery
           node.setChild(node.getChild().accept(this, context));
@@ -256,6 +293,437 @@ public class PushPredicateIntoTableScan implements 
PlanOptimizer {
           metadataExpressions, expressionsCanPushDown, 
expressionsCannotPushDown);
     }
 
+    /**
+     * Distributes the inherited predicate and the join's own condition across an INNER join.
+     *
+     * <p>Outer joins are first passed through {@code tryNormalizeToOuterToInnerJoin}; a join that
+     * is still not INNER afterwards is rejected. The conjuncts produced by {@code
+     * processInnerJoin} are split into equi-join clauses and a residual join filter, and the join
+     * is rebuilt on top of projections of both children when its criteria or filter changed.
+     *
+     * @param node the join to optimize
+     * @param context unused
+     * @return the rewritten plan: the join, optionally wrapped in a {@link FilterNode} holding the
+     *     post-join predicate and in a {@link ProjectNode} restoring the original output symbols
+     * @throws IllegalStateException if the join type is not INNER after normalization
+     */
+    @Override
+    public PlanNode visitJoin(JoinNode node, Void context) {
+      Expression inheritedPredicate = predicate;
+
+      // See if we can rewrite outer joins in terms of a plain inner join
+      node = tryNormalizeToOuterToInnerJoin(node, inheritedPredicate);
+
+      Expression leftEffectivePredicate = TRUE_LITERAL;
+      // effectivePredicateExtractor.extract(session, node.getLeftChild(), types, typeAnalyzer);
+      Expression rightEffectivePredicate = TRUE_LITERAL;
+      // effectivePredicateExtractor.extract(session, node.getRightChild(), types, typeAnalyzer);
+      Expression joinPredicate = extractJoinPredicate(node);
+
+      // NOTE(review): leftPredicate and rightPredicate are computed but not yet applied to the
+      // children, because the child rewrite further down is still commented out.
+      Expression leftPredicate;
+      Expression rightPredicate;
+      Expression postJoinPredicate;
+      Expression newJoinPredicate;
+
+      switch (node.getJoinType()) {
+        case INNER:
+          InnerJoinPushDownResult innerJoinPushDownResult =
+              processInnerJoin(
+                  inheritedPredicate,
+                  leftEffectivePredicate,
+                  rightEffectivePredicate,
+                  joinPredicate,
+                  node.getLeftChild().getOutputSymbols(),
+                  node.getRightChild().getOutputSymbols());
+          leftPredicate = innerJoinPushDownResult.getLeftPredicate();
+          rightPredicate = innerJoinPushDownResult.getRightPredicate();
+          postJoinPredicate = innerJoinPushDownResult.getPostJoinPredicate();
+          newJoinPredicate = innerJoinPushDownResult.getJoinPredicate();
+          break;
+        default:
+          throw new IllegalStateException("Only support INNER JOIN in current version");
+      }
+
+      // newJoinPredicate = simplifyExpression(newJoinPredicate);
+
+      // Create identity projections for all existing symbols
+      Assignments.Builder leftProjections = Assignments.builder();
+      leftProjections.putAll(
+          node.getLeftChild().getOutputSymbols().stream()
+              .collect(toImmutableMap(key -> key, Symbol::toSymbolReference)));
+
+      Assignments.Builder rightProjections = Assignments.builder();
+      rightProjections.putAll(
+          node.getRightChild().getOutputSymbols().stream()
+              .collect(toImmutableMap(key -> key, Symbol::toSymbolReference)));
+
+      // Create new projections for the new join clauses
+      List<JoinNode.EquiJoinClause> equiJoinClauses = new ArrayList<>();
+      ImmutableList.Builder<Expression> joinFilterBuilder = ImmutableList.builder();
+      for (Expression conjunct : extractConjuncts(newJoinPredicate)) {
+        if (joinEqualityExpression(
+            conjunct,
+            node.getLeftChild().getOutputSymbols(),
+            node.getRightChild().getOutputSymbols())) {
+          ComparisonExpression equality = (ComparisonExpression) conjunct;
+
+          // The equality may be written as "right = left"; orient it so that the left operand
+          // only references symbols of the left child.
+          boolean alignedComparison =
+              node.getLeftChild().getOutputSymbols().containsAll(extractUnique(equality.getLeft()));
+          Expression leftExpression = alignedComparison ? equality.getLeft() : equality.getRight();
+          Expression rightExpression = alignedComparison ? equality.getRight() : equality.getLeft();
+
+          Symbol leftSymbol = symbolForExpression(leftExpression);
+          if (!node.getLeftChild().getOutputSymbols().contains(leftSymbol)) {
+            leftProjections.put(leftSymbol, leftExpression);
+          }
+
+          Symbol rightSymbol = symbolForExpression(rightExpression);
+          if (!node.getRightChild().getOutputSymbols().contains(rightSymbol)) {
+            rightProjections.put(rightSymbol, rightExpression);
+          }
+
+          equiJoinClauses.add(new JoinNode.EquiJoinClause(leftSymbol, rightSymbol));
+        } else {
+          joinFilterBuilder.add(conjunct);
+        }
+      }
+
+      List<Expression> joinFilter = joinFilterBuilder.build();
+      //      DynamicFiltersResult dynamicFiltersResult = createDynamicFilters(node,
+      // equiJoinClauses, joinFilter, session, idAllocator);
+      //      Map<DynamicFilterId, Symbol> dynamicFilters =
+      // dynamicFiltersResult.getDynamicFilters();
+      // leftPredicate = combineConjuncts(metadata, leftPredicate, combineConjuncts(metadata,
+      // dynamicFiltersResult.getPredicates()));
+
+      PlanNode leftSource = node.getLeftChild();
+      PlanNode rightSource = node.getRightChild();
+      boolean equiJoinClausesUnmodified =
+          ImmutableSet.copyOf(equiJoinClauses).equals(ImmutableSet.copyOf(node.getCriteria()));
+      if (!equiJoinClausesUnmodified) {
+        // leftSource = context.rewrite(new ProjectNode(queryContext.getQueryId().genPlanNodeId(),
+        // node.getLeftChild(), leftProjections.build()), leftPredicate);
+        // rightSource = context.rewrite(new ProjectNode(queryContext.getQueryId().genPlanNodeId(),
+        // node.getRightChild(), rightProjections.build()), rightPredicate);
+      } else {
+        // leftSource = context.rewrite(node.getLeftChild(), leftPredicate);
+        // rightSource = context.rewrite(node.getRightChild(), rightPredicate);
+        // TODO rewrite
+      }
+
+      Optional<Expression> newJoinFilter = Optional.of(combineConjuncts(joinFilter));
+      if (newJoinFilter.get().equals(TRUE_LITERAL)) {
+        newJoinFilter = Optional.empty();
+      }
+
+      if (node.getJoinType() == INNER && newJoinFilter.isPresent() && equiJoinClauses.isEmpty()) {
+        // if we do not have any equi conjunct we do not pushdown non-equality condition into
+        // inner join, so we plan execution as nested-loops-join followed by filter instead
+        // hash join.
+        // The residual condition has to survive as a post-join filter: clearing newJoinFilter
+        // without re-adding it here would silently drop the predicate from the plan.
+        // todo: remove the code when we have support for filter function in nested loop join
+        postJoinPredicate =
+            combineConjuncts(ImmutableList.of(postJoinPredicate, newJoinFilter.get()));
+        newJoinFilter = Optional.empty();
+      }
+
+      // Without an expression-equivalence check, two present filters are always treated as
+      // different; only "both absent" counts as equivalent.
+      boolean filtersEquivalent =
+          newJoinFilter.isPresent() == node.getFilter().isPresent() && (!newJoinFilter.isPresent());
+      // areExpressionsEquivalent(newJoinFilter.get(), node.getFilter().get());
+
+      PlanNode output = node;
+      if (leftSource != node.getLeftChild()
+          || rightSource != node.getRightChild()
+          || !filtersEquivalent
+          ||
+          // !dynamicFilters.equals(node.getDynamicFilters()) ||
+          !equiJoinClausesUnmodified) {
+        leftSource =
+            new ProjectNode(
+                queryContext.getQueryId().genPlanNodeId(), leftSource, leftProjections.build());
+        rightSource =
+            new ProjectNode(
+                queryContext.getQueryId().genPlanNodeId(), rightSource, rightProjections.build());
+
+        output =
+            new JoinNode(
+                node.getPlanNodeId(),
+                node.getJoinType(),
+                leftSource,
+                rightSource,
+                equiJoinClauses,
+                leftSource.getOutputSymbols(),
+                rightSource.getOutputSymbols(),
+                node.isMaySkipOutputDuplicates(),
+                newJoinFilter,
+                node.getLeftHashSymbol(),
+                node.getRightHashSymbol(),
+                node.isSpillable());
+      }
+
+      if (!postJoinPredicate.equals(TRUE_LITERAL)) {
+        output =
+            new FilterNode(queryContext.getQueryId().genPlanNodeId(), output, postJoinPredicate);
+      }
+
+      // Restore the original output layout if the rebuilt join exposes extra symbols
+      if (!node.getOutputSymbols().equals(output.getOutputSymbols())) {
+        output =
+            new ProjectNode(
+                queryContext.getQueryId().genPlanNodeId(),
+                output,
+                Assignments.identity(node.getOutputSymbols()));
+      }
+
+      return output;
+    }
+
+    /**
+     * Narrows an outer join to a less "outer" join type when {@code canConvertOuterToInner}
+     * allows it.
+     *
+     * <p>INNER joins are returned untouched. A FULL join becomes INNER, LEFT or RIGHT depending on
+     * which of the two sides passes the check; a LEFT or RIGHT join becomes INNER when its inner
+     * side passes the check. If nothing can be narrowed, the given node itself is returned.
+     *
+     * @param node the join to normalize
+     * @param inheritedPredicate the predicate applied above the join
+     * @return {@code node} itself, or a copy of it that differs only in its join type
+     * @throws IllegalArgumentException if the join type is not INNER, LEFT, RIGHT or FULL
+     */
+    private JoinNode tryNormalizeToOuterToInnerJoin(JoinNode node, Expression inheritedPredicate) {
+      JoinNode.JoinType originalType = node.getJoinType();
+      checkArgument(
+          EnumSet.of(INNER, RIGHT, LEFT, FULL).contains(originalType),
+          "Unsupported join type: %s",
+          originalType);
+
+      if (originalType == INNER) {
+        return node;
+      }
+
+      JoinNode.JoinType targetType;
+      if (originalType == FULL) {
+        boolean convertibleToLeft =
+            canConvertOuterToInner(node.getLeftChild().getOutputSymbols(), inheritedPredicate);
+        boolean convertibleToRight =
+            canConvertOuterToInner(node.getRightChild().getOutputSymbols(), inheritedPredicate);
+        if (convertibleToLeft && convertibleToRight) {
+          targetType = INNER;
+        } else if (convertibleToLeft) {
+          targetType = LEFT;
+        } else if (convertibleToRight) {
+          targetType = RIGHT;
+        } else {
+          return node;
+        }
+      } else {
+        // Only LEFT and RIGHT remain: the inner side is the child opposite to the join direction.
+        List<Symbol> innerSideSymbols =
+            originalType == LEFT
+                ? node.getRightChild().getOutputSymbols()
+                : node.getLeftChild().getOutputSymbols();
+        if (!canConvertOuterToInner(innerSideSymbols, inheritedPredicate)) {
+          return node;
+        }
+        targetType = INNER;
+      }
+
+      // Same join, only the type differs
+      return new JoinNode(
+          node.getPlanNodeId(),
+          targetType,
+          node.getLeftChild(),
+          node.getRightChild(),
+          node.getCriteria(),
+          node.getLeftOutputSymbols(),
+          node.getRightOutputSymbols(),
+          node.isMaySkipOutputDuplicates(),
+          node.getFilter(),
+          node.getLeftHashSymbol(),
+          node.getRightHashSymbol(),
+          node.isSpillable());
+    }
+
+    /**
+     * Decides whether the inherited predicate cancels the effect of an outer join, so that the
+     * join can be treated as if the given side were joined with INNER semantics.
+     *
+     * <p>That is the case when at least one deterministic conjunct yields FALSE or NULL once all
+     * inner-side symbols are NULL. A general null-input evaluator is not available here, so only
+     * one pattern that is known to be null-rejecting is recognized (see {@code
+     * rejectsNullInnerSymbols}). Every other conjunct is conservatively treated as not
+     * null-rejecting; answering {@code true} without that proof would turn outer joins into inner
+     * joins and drop rows.
+     *
+     * @param innerSymbolsForOuterJoin output symbols of the side that the outer join pads with
+     *     NULLs
+     * @param inheritedPredicate the predicate applied above the join
+     * @return {@code true} only if a conjunct is proven to reject NULL inner-side symbols
+     */
+    private boolean canConvertOuterToInner(
+        List<Symbol> innerSymbolsForOuterJoin, Expression inheritedPredicate) {
+      Set<Symbol> innerSymbols = ImmutableSet.copyOf(innerSymbolsForOuterJoin);
+      for (Expression conjunct : extractConjuncts(inheritedPredicate)) {
+        // Ignore a conjunct for this test if we cannot deterministically get responses from it
+        if (isDeterministic(conjunct) && rejectsNullInnerSymbols(conjunct, innerSymbols)) {
+          // If there is a single conjunct that returns FALSE or NULL given all NULL inputs for the
+          // inner side symbols of an outer join then this conjunct removes all effects of the
+          // outer join, and effectively turns this into an equivalent of an inner join.
+          return true;
+        }
+      }
+      return false;
+    }
+
+    /**
+     * Recognizes an {@code =} comparison in which one operand is directly a reference to an
+     * inner-side symbol. Under SQL three-valued logic such a comparison evaluates to NULL when
+     * that symbol is NULL, whatever the other operand is.
+     */
+    private boolean rejectsNullInnerSymbols(Expression conjunct, Set<Symbol> innerSymbols) {
+      if (!(conjunct instanceof ComparisonExpression)) {
+        return false;
+      }
+      ComparisonExpression comparison = (ComparisonExpression) conjunct;
+      if (comparison.getOperator() != ComparisonExpression.Operator.EQUAL) {
+        return false;
+      }
+      return isInnerSymbolReference(comparison.getLeft(), innerSymbols)
+          || isInnerSymbolReference(comparison.getRight(), innerSymbols);
+    }
+
+    /** Returns whether {@code operand} is a bare reference to one of {@code innerSymbols}. */
+    private boolean isInnerSymbolReference(Expression operand, Set<Symbol> innerSymbols) {
+      return operand instanceof SymbolReference && innerSymbols.contains(Symbol.from(operand));
+    }
+
+    /**
+     * Splits the predicates around an inner join into conjuncts for the left child, conjuncts for
+     * the right child and conjuncts that stay with the join.
+     *
+     * <p>Non-deterministic conjuncts of the inherited and join predicates always stay with the
+     * join. The deterministic remainder feeds {@link EqualityInference}, whose derived equalities
+     * and rewrites decide where each conjunct goes. The post-join predicate of the result is
+     * always {@code TRUE_LITERAL}.
+     *
+     * @param inheritedPredicate predicate applied above the join
+     * @param leftEffectivePredicate predicate known to hold on the left child
+     * @param rightEffectivePredicate predicate known to hold on the right child
+     * @param joinPredicate the join's own condition
+     * @param leftSymbols output symbols of the left child
+     * @param rightSymbols output symbols of the right child
+     * @return the left, right, join and post-join predicates
+     * @throws IllegalArgumentException if an effective predicate references symbols outside its
+     *     own side
+     */
+    private InnerJoinPushDownResult processInnerJoin(
+        Expression inheritedPredicate,
+        Expression leftEffectivePredicate,
+        Expression rightEffectivePredicate,
+        Expression joinPredicate,
+        Collection<Symbol> leftSymbols,
+        Collection<Symbol> rightSymbols) {
+      checkArgument(
+          leftSymbols.containsAll(extractUnique(leftEffectivePredicate)),
+          "leftEffectivePredicate must only contain symbols from leftSymbols");
+      checkArgument(
+          rightSymbols.containsAll(extractUnique(rightEffectivePredicate)),
+          "rightEffectivePredicate must only contain symbols from rightSymbols");
+
+      ImmutableList.Builder<Expression> leftSink = ImmutableList.builder();
+      ImmutableList.Builder<Expression> rightSink = ImmutableList.builder();
+      ImmutableList.Builder<Expression> joinSink = ImmutableList.builder();
+
+      // Non-deterministic conjuncts must not be moved: keep them with the join, then continue
+      // with the deterministic remainder only.
+      for (Expression conjunct : extractConjuncts(inheritedPredicate)) {
+        if (!isDeterministic(conjunct)) {
+          joinSink.add(conjunct);
+        }
+      }
+      inheritedPredicate = filterDeterministicConjuncts(inheritedPredicate);
+
+      for (Expression conjunct : extractConjuncts(joinPredicate)) {
+        if (!isDeterministic(conjunct)) {
+          joinSink.add(conjunct);
+        }
+      }
+      joinPredicate = filterDeterministicConjuncts(joinPredicate);
+
+      leftEffectivePredicate = filterDeterministicConjuncts(leftEffectivePredicate);
+      rightEffectivePredicate = filterDeterministicConjuncts(rightEffectivePredicate);
+
+      ImmutableSet<Symbol> leftScope = ImmutableSet.copyOf(leftSymbols);
+      ImmutableSet<Symbol> rightScope = ImmutableSet.copyOf(rightSymbols);
+
+      // Equality inference over all predicates, and over all predicates except the one already
+      // known to hold on the side being pushed to.
+      EqualityInference fullInference =
+          new EqualityInference(
+              metadata,
+              inheritedPredicate,
+              leftEffectivePredicate,
+              rightEffectivePredicate,
+              joinPredicate);
+      EqualityInference inferenceForLeft =
+          new EqualityInference(
+              metadata, inheritedPredicate, rightEffectivePredicate, joinPredicate);
+      EqualityInference inferenceForRight =
+          new EqualityInference(
+              metadata, inheritedPredicate, leftEffectivePredicate, joinPredicate);
+
+      // Feed the inferred equalities back: per-side equalities go to that side, equalities that
+      // straddle both sides become part of the join predicate.
+      leftSink.addAll(
+          inferenceForLeft.generateEqualitiesPartitionedBy(leftScope).getScopeEqualities());
+      rightSink.addAll(
+          inferenceForRight.generateEqualitiesPartitionedBy(rightScope).getScopeEqualities());
+      joinSink.addAll(
+          fullInference.generateEqualitiesPartitionedBy(leftScope).getScopeStraddlingEqualities());
+
+      // Inherited conjuncts that were not used for inference: push to whichever side accepts a
+      // rewrite, keep with the join only if neither does.
+      EqualityInference.nonInferrableConjuncts(metadata, inheritedPredicate)
+          .forEach(
+              conjunct ->
+                  routeConjunct(
+                      fullInference, conjunct, leftScope, rightScope, leftSink, rightSink,
+                      joinSink));
+
+      // See if we can push the right effective predicate to the left side
+      EqualityInference.nonInferrableConjuncts(metadata, rightEffectivePredicate)
+          .forEach(
+              conjunct -> {
+                Expression onLeft = fullInference.rewrite(conjunct, leftScope);
+                if (onLeft != null) {
+                  leftSink.add(onLeft);
+                }
+              });
+
+      // See if we can push the left effective predicate to the right side
+      EqualityInference.nonInferrableConjuncts(metadata, leftEffectivePredicate)
+          .forEach(
+              conjunct -> {
+                Expression onRight = fullInference.rewrite(conjunct, rightScope);
+                if (onRight != null) {
+                  rightSink.add(onRight);
+                }
+              });
+
+      // See if we can push any parts of the join predicates to either side
+      EqualityInference.nonInferrableConjuncts(metadata, joinPredicate)
+          .forEach(
+              conjunct ->
+                  routeConjunct(
+                      fullInference, conjunct, leftScope, rightScope, leftSink, rightSink,
+                      joinSink));
+
+      Expression leftResult = combineConjuncts(leftSink.build());
+      Expression rightResult = combineConjuncts(rightSink.build());
+      Expression joinResult = combineConjuncts(joinSink.build());
+      return new InnerJoinPushDownResult(leftResult, rightResult, joinResult, TRUE_LITERAL);
+    }
+
+    /**
+     * Adds the rewrite of {@code conjunct} to every side whose scope accepts it; if neither side
+     * does, the original conjunct is kept with the join.
+     */
+    private void routeConjunct(
+        EqualityInference inference,
+        Expression conjunct,
+        ImmutableSet<Symbol> leftScope,
+        ImmutableSet<Symbol> rightScope,
+        ImmutableList.Builder<Expression> leftSink,
+        ImmutableList.Builder<Expression> rightSink,
+        ImmutableList.Builder<Expression> joinSink) {
+      Expression onLeft = inference.rewrite(conjunct, leftScope);
+      if (onLeft != null) {
+        leftSink.add(onLeft);
+      }
+
+      Expression onRight = inference.rewrite(conjunct, rightScope);
+      if (onRight != null) {
+        rightSink.add(onRight);
+      }
+
+      if (onLeft == null && onRight == null) {
+        joinSink.add(conjunct);
+      }
+    }
+
+    /**
+     * Builds the complete join condition of {@code joinNode}: the conjunction of every equi-join
+     * clause followed by the join's filter, if it has one.
+     */
+    private Expression extractJoinPredicate(JoinNode joinNode) {
+      ImmutableList.Builder<Expression> conjuncts = ImmutableList.builder();
+      joinNode.getCriteria().forEach(clause -> conjuncts.add(clause.toExpression()));
+      if (joinNode.getFilter().isPresent()) {
+        conjuncts.add(joinNode.getFilter().get());
+      }
+      return combineConjuncts(conjuncts.build());
+    }
+
+    /**
+     * Returns whether {@code expression} is accepted by {@code joinComparisonExpression} when
+     * {@code EQUAL} is the only allowed operator.
+     */
+    private boolean joinEqualityExpression(
+        Expression expression, Collection<Symbol> leftSymbols, Collection<Symbol> rightSymbols) {
+      Set<ComparisonExpression.Operator> equalityOnly =
+          ImmutableSet.of(ComparisonExpression.Operator.EQUAL);
+      return joinComparisonExpression(expression, leftSymbols, rightSymbols, equalityOnly);
+    }
+
+    /**
+     * Returns whether {@code expression} is a deterministic comparison, using one of {@code
+     * operators}, whose two operands each reference at least one symbol and come from opposite
+     * sides of the join (in either order).
+     */
+    private boolean joinComparisonExpression(
+        Expression expression,
+        Collection<Symbol> leftSymbols,
+        Collection<Symbol> rightSymbols,
+        Set<ComparisonExpression.Operator> operators) {
+      // At this point in time, our join predicates need to be deterministic
+      if (!(expression instanceof ComparisonExpression) || !isDeterministic(expression)) {
+        return false;
+      }
+
+      ComparisonExpression comparison = (ComparisonExpression) expression;
+      if (!operators.contains(comparison.getOperator())) {
+        return false;
+      }
+
+      Set<Symbol> firstOperandSymbols = extractUnique(comparison.getLeft());
+      Set<Symbol> secondOperandSymbols = extractUnique(comparison.getRight());
+      // An operand without symbols (e.g. a constant) does not relate the two sides
+      if (firstOperandSymbols.isEmpty() || secondOperandSymbols.isEmpty()) {
+        return false;
+      }
+
+      boolean leftThenRight =
+          leftSymbols.containsAll(firstOperandSymbols)
+              && rightSymbols.containsAll(secondOperandSymbols);
+      return leftThenRight
+          || (rightSymbols.containsAll(firstOperandSymbols)
+              && leftSymbols.containsAll(secondOperandSymbols));
+    }
+
+    /**
+     * Returns the symbol that stands for {@code expression}: the referenced symbol when the
+     * expression is a plain {@link SymbolReference}, otherwise a newly allocated symbol.
+     */
+    private Symbol symbolForExpression(Expression expression) {
+      // TODO(beyyes) verify the rightness of type
+      return expression instanceof SymbolReference
+          ? Symbol.from(expression)
+          : symbolAllocator.newSymbol(expression, analysis.getType(expression));
+    }
+
     @Override
     public PlanNode visitTableScan(TableScanNode node, Void context) {
       tableMetadataIndexScan(node, Collections.emptyList());
@@ -391,4 +859,38 @@ public class PushPredicateIntoTableScan implements 
PlanOptimizer {
       return this.expressionsCannotPushDown;
     }
   }
+
+  /**
+   * Immutable holder for the four predicates produced when pushing predicates through an inner
+   * join: one per child, one for the join itself and one to apply after the join.
+   */
+  private static class InnerJoinPushDownResult {
+    private final Expression leftPredicate;
+    private final Expression rightPredicate;
+    private final Expression joinPredicate;
+    private final Expression postJoinPredicate;
+
+    private InnerJoinPushDownResult(
+        Expression left, Expression right, Expression join, Expression postJoin) {
+      leftPredicate = left;
+      rightPredicate = right;
+      joinPredicate = join;
+      postJoinPredicate = postJoin;
+    }
+
+    /** Predicate for the left child. */
+    private Expression getLeftPredicate() {
+      return this.leftPredicate;
+    }
+
+    /** Predicate for the right child. */
+    private Expression getRightPredicate() {
+      return this.rightPredicate;
+    }
+
+    /** Predicate that stays with the join. */
+    private Expression getJoinPredicate() {
+      return this.joinPredicate;
+    }
+
+    /** Predicate to apply on top of the join. */
+    private Expression getPostJoinPredicate() {
+      return this.postJoinPredicate;
+    }
+  }
 }
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/utils/DisjointSet.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/utils/DisjointSet.java
new file mode 100644
index 00000000000..eb13d09d043
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/utils/DisjointSet.java
@@ -0,0 +1,114 @@
+package org.apache.iotdb.db.queryengine.plan.relational.utils;
+
+import java.util.Collection;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.Map;
+import java.util.Set;
+
+import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.base.Verify.verify;
+
+/**
+ * Union-find structure over elements of type {@code T}, using union by rank and path compression.
+ * Elements are registered lazily the first time they are passed to {@link #find}.
+ */
+public class DisjointSet<T> {
+  /** Per-element bookkeeping: a root has no parent and carries the rank of its tree. */
+  private static class Entry<T> {
+    private T parent;
+    // Without path compression, this would be equal to depth. Depth of 1-node tree is
+    // considered 0.
+    private int rank;
+
+    Entry() {
+      this(null, 0);
+    }
+
+    private Entry(T parent, int rank) {
+      this.parent = parent;
+      this.rank = rank;
+    }
+
+    public T getParent() {
+      return parent;
+    }
+
+    /** Attaches this entry below {@code parent}; its rank is reset to -1 at the same time. */
+    public void setParent(T parent) {
+      this.rank = -1;
+      this.parent = parent;
+    }
+
+    /** Rank of this entry; may only be asked of an entry without parent. */
+    public int getRank() {
+      checkState(parent == null);
+      return rank;
+    }
+
+    /** Raises the rank by one; may only be applied to an entry without parent. */
+    public void incrementRank() {
+      checkState(parent == null);
+      rank += 1;
+    }
+  }
+
+  // Insertion-ordered so that iteration order is stable
+  private final Map<T, Entry<T>> map = new LinkedHashMap<>();
+
+  public DisjointSet() {}
+
+  /**
+   * Merges the sets containing the two given nodes, registering them first if necessary.
+   *
+   * @return <tt>true</tt> if the specified equivalence is new
+   */
+  public boolean findAndUnion(T node1, T node2) {
+    T firstRoot = find(node1);
+    T secondRoot = find(node2);
+    return union(firstRoot, secondRoot);
+  }
+
+  /**
+   * Returns the representative of the set containing {@code element}. An element seen for the
+   * first time is registered as a set of its own and is its own representative.
+   */
+  public T find(T element) {
+    if (map.containsKey(element)) {
+      return findInternal(element);
+    }
+    map.put(element, new Entry<>());
+    return element;
+  }
+
+  /** Links two roots by rank; returns {@code false} when both are already the same root. */
+  private boolean union(T root1, T root2) {
+    if (root1.equals(root2)) {
+      return false;
+    }
+    Entry<T> first = map.get(root1);
+    Entry<T> second = map.get(root2);
+    int firstRank = first.getRank();
+    int secondRank = second.getRank();
+    verify(firstRank >= 0);
+    verify(secondRank >= 0);
+    if (firstRank < secondRank) {
+      // make root1 child of root2
+      first.setParent(root2);
+      return true;
+    }
+    if (firstRank == secondRank) {
+      // increment rank of root1 when both side were equally deep
+      first.incrementRank();
+    }
+    // make root2 child of root1
+    second.setParent(root1);
+    return true;
+  }
+
+  /** Finds the root of an already registered element and compresses the path walked. */
+  private T findInternal(T element) {
+    // First pass: climb to the root
+    T root = element;
+    for (T up = map.get(root).getParent(); up != null; up = map.get(root).getParent()) {
+      root = up;
+    }
+    // Second pass: point every node on the path directly at the root
+    Entry<T> current = map.get(element);
+    while (current.getParent() != null) {
+      T next = current.getParent();
+      current.setParent(root);
+      current = map.get(next);
+    }
+    return root;
+  }
+
+  /** Returns the current sets, each as the collection of its registered elements. */
+  public Collection<Set<T>> getEquivalentClasses() {
+    // map from root element to all element in the tree
+    Map<T, Set<T>> membersByRoot = new LinkedHashMap<>();
+    for (T node : map.keySet()) {
+      membersByRoot.computeIfAbsent(findInternal(node), unused -> new LinkedHashSet<>()).add(node);
+    }
+    return membersByRoot.values();
+  }
+}

Reply via email to