This is an automated email from the ASF dual-hosted git repository.
caogaofei pushed a commit to branch beyyes/join
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/beyyes/join by this push:
new 459d5717aef add more join pushdown predicate optimize
459d5717aef is described below
commit 459d5717aefe2b11b4d87b29d92e15e38638e10f
Author: Beyyes <[email protected]>
AuthorDate: Mon Aug 19 18:55:49 2024 +0800
add more join pushdown predicate optimize
---
.../plan/relational/planner/EqualityInference.java | 422 +++++++++++++++++
.../plan/relational/planner/ir/AstUtils.java | 102 +++++
.../planner/ir/ExpressionNodeInliner.java | 24 +
.../plan/relational/planner/ir/IrUtils.java | 4 +
.../planner/ir/SubExpressionExtractor.java | 19 +
.../planner/iterative/IterativeOptimizer.java | 2 +-
.../plan/relational/planner/node/JoinNode.java | 1 +
.../planner/optimizations/PlanOptimizer.java | 2 +-
.../optimizations/PushPredicateIntoTableScan.java | 506 ++++++++++++++++++++-
.../plan/relational/utils/DisjointSet.java | 114 +++++
10 files changed, 1192 insertions(+), 4 deletions(-)
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/EqualityInference.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/EqualityInference.java
new file mode 100644
index 00000000000..1ca72118c5c
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/EqualityInference.java
@@ -0,0 +1,422 @@
+package org.apache.iotdb.db.queryengine.plan.relational.planner;
+
+import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata;
+import
org.apache.iotdb.db.queryengine.plan.relational.planner.ir.SubExpressionExtractor;
+import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.DisjointSet;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.ImmutableSetMultimap;
+import com.google.common.collect.Multimap;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.function.Predicate;
+import java.util.function.ToIntFunction;
+import java.util.stream.Stream;
+
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static java.util.Objects.requireNonNull;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.PredicateUtils.extractConjuncts;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.ir.DeterminismEvaluator.isDeterministic;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.ir.ExpressionNodeInliner.replaceExpression;
+
+/**
+ * Makes equality-based inferences to rewrite Expressions and to generate equality sets in terms
+ * of specified symbol scopes.
+ */
+public class EqualityInference {
+ // Comparator used to determine Expression preference when determining
canonicals
+ private final Comparator<Expression> canonicalComparator;
+ private final Multimap<Expression, Expression> equalitySets; // Indexed by
canonical expression
+ private final Map<Expression, Expression>
+ canonicalMap; // Map each known expression to canonical expression
+ // Expressions produced in the constructor by substituting an equivalent sub-expression;
+ // generateEqualitiesPartitionedBy filters these out of each equality set.
+ private final Set<Expression> derivedExpressions;
+ // Memoizes extractSubExpressions(expression).
+ private final Map<Expression, List<Expression>> expressionCache = new
HashMap<>();
+ // Memoizes extractAllSymbols(expression).
+ private final Map<Expression, List<Symbol>> symbolsCache = new HashMap<>();
+ // Memoizes extractUniqueSymbols(expression).
+ private final Map<Expression, Set<Symbol>> uniqueSymbolsCache = new
HashMap<>();
+
+ /**
+ * Convenience constructor: wraps the varargs array in a list and delegates to the
+ * collection-based constructor.
+ */
+ public EqualityInference(Metadata metadata, Expression... expressions) {
+ this(metadata, Arrays.asList(expressions));
+ }
+
+ /**
+ * Builds the inference from the given expressions.
+ *
+ * <p>Steps performed, in order:
+ *
+ * <ol>
+ * <li>Every input is split into conjuncts; only conjuncts accepted by {@link
+ * #isInferenceCandidate} are kept, and the two sides of each are unioned in a disjoint set.
+ * <li>For every expression of the resulting equivalence classes that has not been marked as
+ * derived, each proper sub-expression with known equivalents is substituted by each
+ * equivalent; the rewritten expression is unioned with the original and recorded as
+ * derived.
+ * <li>A canonical comparator is created (fewer symbols, then fewer sub-expressions, then
+ * {@code toString()} order) and used to pick the canonical member of every equality set.
+ * <li>The equality sets, the expression-to-canonical map, the derived expressions and the
+ * comparator are stored in the instance fields.
+ * </ol>
+ *
+ * @param metadata forwarded to {@link #isInferenceCandidate}
+ * @param expressions predicates from which equalities are extracted
+ */
+ public EqualityInference(Metadata metadata, Collection<Expression>
expressions) {
+ DisjointSet<Expression> equalities = new DisjointSet<>();
+ expressions.stream()
+ .flatMap(expression -> extractConjuncts(expression).stream())
+ .filter(expression -> isInferenceCandidate(metadata, expression))
+ .forEach(
+ expression -> {
+ ComparisonExpression comparison = (ComparisonExpression)
expression;
+ Expression expression1 = comparison.getLeft();
+ Expression expression2 = comparison.getRight();
+
+ equalities.findAndUnion(expression1, expression2);
+ });
+
+ Collection<Set<Expression>> equivalentClasses =
equalities.getEquivalentClasses();
+
+ // Map every expression to the set of equivalent expressions
+ Map<Expression, Set<Expression>> byExpression = new LinkedHashMap<>();
+ for (Set<Expression> equivalence : equivalentClasses) {
+ equivalence.forEach(expression -> byExpression.put(expression,
equivalence));
+ }
+
+ // For every non-derived expression, extract the sub-expressions and see
if they can be
+ // rewritten as other expressions. If so,
+ // use this new information to update the known equalities.
+ Set<Expression> derivedExpressions = new LinkedHashSet<>();
+ for (Expression expression : byExpression.keySet()) {
+ if (derivedExpressions.contains(expression)) {
+ continue;
+ }
+
+ extractSubExpressions(expression).stream()
+ .filter(e -> !e.equals(expression))
+ .forEach(
+ subExpression ->
+ byExpression.getOrDefault(subExpression,
ImmutableSet.of()).stream()
+ .filter(e -> !e.equals(subExpression))
+ .forEach(
+ equivalentSubExpression -> {
+ Expression rewritten =
+ replaceExpression(
+ expression,
+ ImmutableMap.of(subExpression,
equivalentSubExpression));
+ equalities.findAndUnion(expression, rewritten);
+ derivedExpressions.add(rewritten);
+ }));
+ }
+
+ Comparator<Expression> canonicalComparator =
+ Comparator
+ // Current cost heuristic:
+ // 1) Prefer fewer input symbols
+ // 2) Prefer smaller expression trees
+ // 3) Sort the expressions alphabetically - creates a stable
consistent ordering
+ // (extremely useful for unit testing)
+ // TODO: be more precise in determining the cost of an expression
+ .comparingInt(
+ (ToIntFunction<Expression>) expression ->
extractAllSymbols(expression).size())
+ .thenComparingLong(expression ->
extractSubExpressions(expression).size())
+ .thenComparing(Expression::toString);
+
+ Multimap<Expression, Expression> equalitySets =
+ makeEqualitySets(equalities, canonicalComparator);
+
+ ImmutableMap.Builder<Expression, Expression> canonicalMappings =
ImmutableMap.builder();
+ for (Map.Entry<Expression, Expression> entry : equalitySets.entries()) {
+ Expression canonical = entry.getKey();
+ Expression expression = entry.getValue();
+ canonicalMappings.put(expression, canonical);
+ }
+
+ this.equalitySets = equalitySets;
+ this.canonicalMap = canonicalMappings.buildOrThrow();
+ this.derivedExpressions = derivedExpressions;
+ this.canonicalComparator = canonicalComparator;
+ }
+
+  /**
+   * Tries to express {@code expression} using only symbols contained in {@code scope}, relying on
+   * the equalities known to this inference. The expression itself may be replaced as a whole.
+   *
+   * @return the rewritten expression, or {@code null} when no in-scope form could be produced
+   */
+  public Expression rewrite(Expression expression, Set<Symbol> scope) {
+    Predicate<Symbol> inScope = scope::contains;
+    return rewrite(expression, inScope, true);
+  }
+
+ /**
+ * Dumps the inference equalities as equality expressions that are
partitioned by the symbolScope.
+ * All stored equalities are returned in a compact set and will be
classified into three groups as
+ * determined by the symbol scope:
+ *
+ * <ol>
+ * <li>equalities that fit entirely within the symbol scope
+ * <li>equalities that fit entirely outside of the symbol scope
+ * <li>equalities that straddle the symbol scope
+ * </ol>
+ *
+ * <pre>
+ * Example:
+ * Stored Equalities:
+ * a = b = c
+ * d = e = f = g
+ *
+ * Symbol Scope:
+ * a, b, d, e
+ *
+ * Output EqualityPartition:
+ * Scope Equalities:
+ * a = b
+ * d = e
+ * Complement Scope Equalities
+ * f = g
+ * Scope Straddling Equalities
+ * a = c
+ * d = f
+ * </pre>
+ *
+ * <p>Expressions recorded as derived (created by sub-expression substitution in the
+ * constructor) are not used as candidates when classifying an equality set.
+ *
+ * @param scope symbols that form the "matching" side of the partition; every other symbol
+ * belongs to the complement side
+ * @return the scope, complement-scope and scope-straddling equalities
+ */
+ public EqualityPartition generateEqualitiesPartitionedBy(Set<Symbol> scope) {
+ ImmutableSet.Builder<Expression> scopeEqualities = ImmutableSet.builder();
+ ImmutableSet.Builder<Expression> scopeComplementEqualities =
ImmutableSet.builder();
+ ImmutableSet.Builder<Expression> scopeStraddlingEqualities =
ImmutableSet.builder();
+
+ for (Collection<Expression> equalitySet : equalitySets.asMap().values()) {
+ Set<Expression> scopeExpressions = new LinkedHashSet<>();
+ Set<Expression> scopeComplementExpressions = new LinkedHashSet<>();
+ Set<Expression> scopeStraddlingExpressions = new LinkedHashSet<>();
+
+ // Try to push each non-derived expression into one side of the scope
+ equalitySet.stream()
+ .filter(candidate -> !derivedExpressions.contains(candidate))
+ .forEach(
+ candidate -> {
+ Expression scopeRewritten = rewrite(candidate,
scope::contains, false);
+ if (scopeRewritten != null) {
+ scopeExpressions.add(scopeRewritten);
+ }
+ Expression scopeComplementRewritten =
+ rewrite(candidate, symbol -> !scope.contains(symbol),
false);
+ if (scopeComplementRewritten != null) {
+ scopeComplementExpressions.add(scopeComplementRewritten);
+ }
+ if (scopeRewritten == null && scopeComplementRewritten ==
null) {
+ scopeStraddlingExpressions.add(candidate);
+ }
+ });
+ // Compile the equality expressions on each side of the scope
+ Expression matchingCanonical = getCanonical(scopeExpressions.stream());
+ if (scopeExpressions.size() >= 2) {
+ scopeExpressions.stream()
+ .filter(expression -> !expression.equals(matchingCanonical))
+ .map(
+ expression ->
+ new ComparisonExpression(
+ ComparisonExpression.Operator.EQUAL,
matchingCanonical, expression))
+ .forEach(scopeEqualities::add);
+ }
+ Expression complementCanonical =
getCanonical(scopeComplementExpressions.stream());
+ if (scopeComplementExpressions.size() >= 2) {
+ scopeComplementExpressions.stream()
+ .filter(expression -> !expression.equals(complementCanonical))
+ .map(
+ expression ->
+ new ComparisonExpression(
+ ComparisonExpression.Operator.EQUAL,
complementCanonical, expression))
+ .forEach(scopeComplementEqualities::add);
+ }
+
+ // Compile single equality between matching and complement scope.
+ // Only consider expressions that don't have derived expression in other
scope.
+ // Otherwise, redundant equality would be generated.
+ Optional<Expression> matchingConnecting =
+ scopeExpressions.stream()
+ .filter(
+ expression ->
+ SymbolsExtractor.extractAll(expression).isEmpty()
+ || rewrite(expression, symbol ->
!scope.contains(symbol), false) == null)
+ .min(canonicalComparator);
+ Optional<Expression> complementConnecting =
+ scopeComplementExpressions.stream()
+ .filter(
+ expression ->
+ SymbolsExtractor.extractAll(expression).isEmpty()
+ || rewrite(expression, scope::contains, false) ==
null)
+ .min(canonicalComparator);
+ if (matchingConnecting.isPresent()
+ && complementConnecting.isPresent()
+ && !matchingConnecting.equals(complementConnecting)) {
+ scopeStraddlingEqualities.add(
+ new ComparisonExpression(
+ ComparisonExpression.Operator.EQUAL,
+ matchingConnecting.get(),
+ complementConnecting.get()));
+ }
+
+ // Compile the scope straddling equality expressions.
+ // scopeStraddlingExpressions couldn't be pushed to either side,
+ // therefore there needs to be an equality generated with
+ // one of the scopes (either matching or complement).
+ List<Expression> straddlingExpressions = new ArrayList<>();
+ if (matchingCanonical != null) {
+ straddlingExpressions.add(matchingCanonical);
+ } else if (complementCanonical != null) {
+ straddlingExpressions.add(complementCanonical);
+ }
+ straddlingExpressions.addAll(scopeStraddlingExpressions);
+ Expression connectingCanonical =
getCanonical(straddlingExpressions.stream());
+ if (connectingCanonical != null) {
+ straddlingExpressions.stream()
+ .filter(expression -> !expression.equals(connectingCanonical))
+ .map(
+ expression ->
+ new ComparisonExpression(
+ ComparisonExpression.Operator.EQUAL,
connectingCanonical, expression))
+ .forEach(scopeStraddlingEqualities::add);
+ }
+ }
+
+ return new EqualityPartition(
+ scopeEqualities.build(),
+ scopeComplementEqualities.build(),
+ scopeStraddlingEqualities.build());
+ }
+
+  /**
+   * Tells whether {@code expression} can feed the equality inference: it must be a deterministic
+   * {@code =} comparison whose left and right operands are not equal to each other.
+   *
+   * <p>The {@code metadata} argument is not consulted by this implementation.
+   */
+  public static boolean isInferenceCandidate(Metadata metadata, Expression expression) {
+    if (!(expression instanceof ComparisonExpression) || !isDeterministic(expression)) {
+      return false;
+    }
+    // The additional "!mayReturnNullOnNonNullInput(expression)" guard is currently disabled.
+    ComparisonExpression equality = (ComparisonExpression) expression;
+    if (equality.getOperator() != ComparisonExpression.Operator.EQUAL) {
+      return false;
+    }
+    // We should only consider equalities that have distinct left and right components
+    return !equality.getLeft().equals(equality.getRight());
+  }
+
+  /**
+   * Streams the conjuncts of {@code expression} that are rejected by {@link
+   * #isInferenceCandidate}, i.e. the ones that are not added to the inference.
+   */
+  public static Stream<Expression> nonInferrableConjuncts(
+      Metadata metadata, Expression expression) {
+    Predicate<Expression> inferrable = conjunct -> isInferenceCandidate(metadata, conjunct);
+    return extractConjuncts(expression).stream().filter(inferrable.negate());
+  }
+
+  /**
+   * Rewrites {@code expression} so that it only references symbols accepted by {@code
+   * symbolScope}.
+   *
+   * <p>Each sub-expression that has a scoped canonical equivalent is mapped to that canonical,
+   * then a single replacement pass is applied. When {@code allowFullReplacement} is false the
+   * expression itself is never a replacement candidate, only its proper sub-expressions are.
+   *
+   * @return the rewritten expression, or {@code null} if it still references out-of-scope symbols
+   */
+  private Expression rewrite(
+      Expression expression, Predicate<Symbol> symbolScope, boolean allowFullReplacement) {
+    Map<Expression, Expression> replacements = new HashMap<>();
+    for (Expression candidate : extractSubExpressions(expression)) {
+      if (!allowFullReplacement && candidate.equals(expression)) {
+        continue;
+      }
+      Expression scopedCanonical = getScopedCanonical(candidate, symbolScope);
+      if (scopedCanonical != null) {
+        replacements.putIfAbsent(candidate, scopedCanonical);
+      }
+    }
+
+    // Naive single-pass traversal that prefers replacing larger subtrees over smaller ones.
+    // TODO: this rewrite can probably be made more sophisticated
+    Expression result = replaceExpression(expression, replacements);
+    // Give up if the result is still not compliant with the symbol scope.
+    return isScoped(result, symbolScope) ? result : null;
+  }
+
+  /**
+   * Picks the most preferable expression, according to the canonical comparator, to serve as the
+   * canonical expression; yields {@code null} for an empty stream.
+   */
+  private Expression getCanonical(Stream<Expression> expressions) {
+    Optional<Expression> preferred = expressions.min(canonicalComparator);
+    return preferred.orElse(null);
+  }
+
+  /**
+   * Finds a canonical expression equivalent to {@code expression} whose symbols are all accepted
+   * by {@code symbolScope}.
+   *
+   * <p>If {@code expression} is a {@link SymbolReference}, at least one symbol reference of its
+   * equality set must be accepted by the scope, otherwise no canonical is returned.
+   *
+   * @return the scoped canonical, or {@code null} when the expression is unknown or no member of
+   *     its equality set fits the scope
+   */
+  @VisibleForTesting
+  Expression getScopedCanonical(Expression expression, Predicate<Symbol> symbolScope) {
+    Expression canonicalIndex = canonicalMap.get(expression);
+    if (canonicalIndex == null) {
+      return null;
+    }
+
+    Collection<Expression> equivalences = equalitySets.get(canonicalIndex);
+    if (expression instanceof SymbolReference) {
+      boolean symbolInScope = false;
+      for (Expression member : equivalences) {
+        if (member instanceof SymbolReference && symbolScope.test(Symbol.from(member))) {
+          symbolInScope = true;
+          break;
+        }
+      }
+      if (!symbolInScope) {
+        return null;
+      }
+    }
+
+    Stream<Expression> scopedMembers =
+        equivalences.stream().filter(member -> isScoped(member, symbolScope));
+    return getCanonical(scopedMembers);
+  }
+
+  /** Returns true when every distinct symbol of {@code expression} is accepted by the scope. */
+  private boolean isScoped(Expression expression, Predicate<Symbol> symbolScope) {
+    for (Symbol symbol : extractUniqueSymbols(expression)) {
+      if (!symbolScope.test(symbol)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /**
+   * Groups every equivalence class of {@code equalities} under its canonical member, i.e. the
+   * minimum of the class according to {@code canonicalComparator}. Empty classes are ignored.
+   */
+  private static Multimap<Expression, Expression> makeEqualitySets(
+      DisjointSet<Expression> equalities, Comparator<Expression> canonicalComparator) {
+    ImmutableSetMultimap.Builder<Expression, Expression> sets = ImmutableSetMultimap.builder();
+    for (Set<Expression> equalityGroup : equalities.getEquivalentClasses()) {
+      // min() is empty only for an empty group, in which case nothing is added
+      equalityGroup.stream()
+          .min(canonicalComparator)
+          .ifPresent(canonical -> sets.putAll(canonical, equalityGroup));
+    }
+    return sets.build();
+  }
+
+  /** Returns the sub-expressions of {@code expression}, memoized in {@code expressionCache}. */
+  private List<Expression> extractSubExpressions(Expression expression) {
+    List<Expression> subExpressions = expressionCache.get(expression);
+    if (subExpressions == null) {
+      subExpressions = SubExpressionExtractor.extract(expression).collect(toImmutableList());
+      expressionCache.put(expression, subExpressions);
+    }
+    return subExpressions;
+  }
+
+  /** Returns the distinct symbols of {@code expression}, memoized in {@code uniqueSymbolsCache}. */
+  private Set<Symbol> extractUniqueSymbols(Expression expression) {
+    Set<Symbol> uniqueSymbols = uniqueSymbolsCache.get(expression);
+    if (uniqueSymbols == null) {
+      uniqueSymbols = ImmutableSet.copyOf(extractAllSymbols(expression));
+      uniqueSymbolsCache.put(expression, uniqueSymbols);
+    }
+    return uniqueSymbols;
+  }
+
+  /** Returns all symbols of {@code expression}, memoized in {@code symbolsCache}. */
+  private List<Symbol> extractAllSymbols(Expression expression) {
+    return symbolsCache.computeIfAbsent(expression, key -> SymbolsExtractor.extractAll(key));
+  }
+
+  /**
+   * Immutable holder for the three groups of equalities produced when partitioning by a symbol
+   * scope: inside the scope, inside its complement, and straddling both.
+   */
+  public static class EqualityPartition {
+    private final List<Expression> scopeEqualities;
+    private final List<Expression> scopeComplementEqualities;
+    private final List<Expression> scopeStraddlingEqualities;
+
+    /** Copies each group into an immutable list; a {@code null} group is rejected. */
+    public EqualityPartition(
+        Iterable<Expression> scopeEqualities,
+        Iterable<Expression> scopeComplementEqualities,
+        Iterable<Expression> scopeStraddlingEqualities) {
+      this.scopeEqualities = immutableCopy(scopeEqualities, "scopeEqualities is null");
+      this.scopeComplementEqualities =
+          immutableCopy(scopeComplementEqualities, "scopeComplementEqualities is null");
+      this.scopeStraddlingEqualities =
+          immutableCopy(scopeStraddlingEqualities, "scopeStraddlingEqualities is null");
+    }
+
+    /** Null-checks {@code group} with the given message, then snapshots it. */
+    private static List<Expression> immutableCopy(Iterable<Expression> group, String message) {
+      return ImmutableList.copyOf(requireNonNull(group, message));
+    }
+
+    public List<Expression> getScopeEqualities() {
+      return scopeEqualities;
+    }
+
+    public List<Expression> getScopeComplementEqualities() {
+      return scopeComplementEqualities;
+    }
+
+    public List<Expression> getScopeStraddlingEqualities() {
+      return scopeStraddlingEqualities;
+    }
+  }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/AstUtils.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/AstUtils.java
new file mode 100644
index 00000000000..7e47072947c
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/AstUtils.java
@@ -0,0 +1,102 @@
+package org.apache.iotdb.db.queryengine.plan.relational.planner.ir;
+
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
+import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LogicalExpression;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Node;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.graph.SuccessorsFunction;
+import com.google.common.graph.Traverser;
+
+import java.util.List;
+import java.util.OptionalInt;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+import java.util.stream.Stream;
+
+import static com.google.common.collect.Streams.stream;
+import static java.util.Objects.requireNonNull;
+
+/** Static helpers for traversing, comparing and hashing AST trees. */
+public final class AstUtils {
+  private AstUtils() {}
+
+  /** Streams {@code node} followed by all of its descendants in depth-first pre-order. */
+  public static Stream<Node> preOrder(Node node) {
+    Traverser<Node> traverser = Traverser.forTree((SuccessorsFunction<Node>) Node::getChildren);
+    return stream(traverser.depthFirstPreOrder(requireNonNull(node, "node is null")));
+  }
+
+  /**
+   * Recursively compares two AST trees using {@code subtreeComparator} on each pair of nodes.
+   *
+   * <p>A non-null comparator result is final for that pair and its subtrees. A {@code null}
+   * result means the two nodes are shallowly equal, so their children are compared pairwise;
+   * differing child counts make the trees different.
+   */
+  public static boolean treeEqual(
+      Node left, Node right, BiFunction<Node, Node, Boolean> subtreeComparator) {
+    Boolean verdict = subtreeComparator.apply(left, right);
+    if (verdict != null) {
+      return verdict;
+    }
+
+    List<? extends Node> leftChildren = left.getChildren();
+    List<? extends Node> rightChildren = right.getChildren();
+    int childCount = leftChildren.size();
+    if (childCount != rightChildren.size()) {
+      return false;
+    }
+
+    int index = 0;
+    while (index < childCount) {
+      if (!treeEqual(leftChildren.get(index), rightChildren.get(index), subtreeComparator)) {
+        return false;
+      }
+      index++;
+    }
+    return true;
+  }
+
+  /**
+   * Hashes an AST using {@code subtreeHasher} at each level.
+   *
+   * <p>When the hasher yields a value for a node, that value is the hash of the whole subtree.
+   * Otherwise the hash starts from the node's class hash and folds in each child's hash with a
+   * factor of 31.
+   */
+  public static int treeHash(Node node, Function<Node, OptionalInt> subtreeHasher) {
+    OptionalInt shortcut = subtreeHasher.apply(node);
+    if (shortcut.isPresent()) {
+      return shortcut.getAsInt();
+    }
+
+    List<? extends Node> children = node.getChildren();
+    int hash = node.getClass().hashCode();
+    for (Node child : children) {
+      hash = 31 * hash + treeHash(child, subtreeHasher);
+    }
+    return hash;
+  }
+
+  /** Flattens nested {@code AND} expressions into the list of their non-AND terms. */
+  public static List<Expression> extractConjuncts(Expression expression) {
+    ImmutableList.Builder<Expression> conjuncts = ImmutableList.builder();
+    extractPredicates(LogicalExpression.Operator.AND, expression, conjuncts);
+    return conjuncts.build();
+  }
+
+  /**
+   * Appends {@code expression} to {@code resultBuilder}, unless it is a logical expression with
+   * the given operator, in which case its terms are processed recursively instead.
+   */
+  private static void extractPredicates(
+      LogicalExpression.Operator operator,
+      Expression expression,
+      ImmutableList.Builder<Expression> resultBuilder) {
+    if (!(expression instanceof LogicalExpression)
+        || ((LogicalExpression) expression).getOperator() != operator) {
+      resultBuilder.add(expression);
+      return;
+    }
+    for (Expression term : ((LogicalExpression) expression).getTerms()) {
+      extractPredicates(operator, term, resultBuilder);
+    }
+  }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/ExpressionNodeInliner.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/ExpressionNodeInliner.java
new file mode 100644
index 00000000000..813d22cfdba
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/ExpressionNodeInliner.java
@@ -0,0 +1,24 @@
+package org.apache.iotdb.db.queryengine.plan.relational.planner.ir;
+
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
+
+import java.util.Map;
+
+/** Expression rewriter that substitutes nodes according to a replacement map. */
+public class ExpressionNodeInliner extends ExpressionRewriter<Void> {
+  private final Map<? extends Expression, ? extends Expression> mappings;
+
+  public ExpressionNodeInliner(Map<? extends Expression, ? extends Expression> mappings) {
+    this.mappings = mappings;
+  }
+
+  /** Rewrites {@code expression} with a new inliner built from {@code mappings}. */
+  public static Expression replaceExpression(
+      Expression expression, Map<? extends Expression, ? extends Expression> mappings) {
+    ExpressionNodeInliner inliner = new ExpressionNodeInliner(mappings);
+    return ExpressionTreeRewriter.rewriteWith(inliner, expression);
+  }
+
+  /**
+   * Returns the replacement mapped to {@code node}, or {@code null} when the map has none.
+   * NOTE(review): a null result presumably tells the tree rewriter to fall back to its default
+   * handling of the node - confirm against ExpressionTreeRewriter.
+   */
+  @Override
+  protected Expression rewriteExpression(
+      Expression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
+    Expression replacement = mappings.get(node);
+    return replacement;
+  }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/IrUtils.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/IrUtils.java
index e0012a2db8b..ec0ad3fbb48 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/IrUtils.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/IrUtils.java
@@ -152,6 +152,10 @@ public final class IrUtils {
return combineConjuncts(Arrays.asList(expressions));
}
+  /** Applies {@code filterConjuncts} with the determinism check as the predicate. */
+  public static Expression filterDeterministicConjuncts(Expression expression) {
+    return filterConjuncts(
+        expression, conjunct -> DeterminismEvaluator.isDeterministic(conjunct));
+  }
+
public static Expression combineConjuncts(Collection<Expression>
expressions) {
requireNonNull(expressions, "expressions is null");
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/SubExpressionExtractor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/SubExpressionExtractor.java
new file mode 100644
index 00000000000..13842e5d883
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/SubExpressionExtractor.java
@@ -0,0 +1,19 @@
+package org.apache.iotdb.db.queryengine.plan.relational.planner.ir;
+
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
+
+import java.util.stream.Stream;
+
+/**
+ * Utility that streams every expression subtree contained in an Expression, the Expression
+ * itself included.
+ */
+public final class SubExpressionExtractor {
+  private SubExpressionExtractor() {}
+
+  /** Walks {@code expression} in pre-order and keeps only the nodes that are Expressions. */
+  public static Stream<Expression> extract(Expression expression) {
+    return AstUtils.preOrder(expression)
+        .filter(node -> node instanceof Expression)
+        .map(node -> (Expression) node);
+  }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/IterativeOptimizer.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/IterativeOptimizer.java
index 89de7c169c1..04327b6ef0b 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/IterativeOptimizer.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/IterativeOptimizer.java
@@ -106,7 +106,7 @@ public class IterativeOptimizer implements
AdaptivePlanOptimizer {
memo,
lookup,
context.idAllocator(),
- context.symbolAllocator(),
+ context.getSymbolAllocator(),
nanoTime(),
timeout.toMillis(),
context.sessionInfo(),
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/JoinNode.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/JoinNode.java
index 25eafc4c626..45358d7d9e9 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/JoinNode.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/JoinNode.java
@@ -33,6 +33,7 @@ public class JoinNode extends TwoChildProcessNode {
private final List<Symbol> leftOutputSymbols;
private final List<Symbol> rightOutputSymbols;
private final boolean maySkipOutputDuplicates;
+ // Optional extra join filter expression, e.g. 'a.xx_column < b.yy_column'
private final Optional<Expression> filter;
private final Optional<Symbol> leftHashSymbol;
private final Optional<Symbol> rightHashSymbol;
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PlanOptimizer.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PlanOptimizer.java
index 003e7d12277..31641432b3a 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PlanOptimizer.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PlanOptimizer.java
@@ -82,7 +82,7 @@ public interface PlanOptimizer {
return typeProvider;
}
- public SymbolAllocator symbolAllocator() {
+ public SymbolAllocator getSymbolAllocator() {
return symbolAllocator;
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PushPredicateIntoTableScan.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PushPredicateIntoTableScan.java
index 5a0f9e3bc95..9fc49530418 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PushPredicateIntoTableScan.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PushPredicateIntoTableScan.java
@@ -34,31 +34,56 @@ import
org.apache.iotdb.db.queryengine.plan.relational.analyzer.predicate.Predic
import org.apache.iotdb.db.queryengine.plan.relational.metadata.ColumnSchema;
import org.apache.iotdb.db.queryengine.plan.relational.metadata.DeviceEntry;
import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments;
+import
org.apache.iotdb.db.queryengine.plan.relational.planner.EqualityInference;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.SymbolAllocator;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.FilterNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode;
+import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.TableScanNode;
+import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.FunctionCall;
import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LogicalExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Node;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
import org.apache.tsfile.read.filter.basic.Filter;
import org.apache.tsfile.utils.Pair;
import java.util.ArrayList;
+import java.util.Collection;
import java.util.Collections;
+import java.util.EnumSet;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.util.Objects.requireNonNull;
import static
org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory.ATTRIBUTE;
import static
org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory.MEASUREMENT;
import static
org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory.TIME;
import static
org.apache.iotdb.db.queryengine.plan.analyze.AnalyzeVisitor.getTimePartitionSlotList;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.PredicateUtils.combineConjuncts;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.PredicateUtils.extractConjuncts;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.SymbolsExtractor.extractUnique;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.ir.DeterminismEvaluator.isDeterministic;
import static
org.apache.iotdb.db.queryengine.plan.relational.planner.ir.GlobalTimePredicateExtractVisitor.extractGlobalTimeFilter;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.ir.IrUtils.filterDeterministicConjuncts;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.FULL;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.INNER;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.LEFT;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode.JoinType.RIGHT;
+import static
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral.TRUE_LITERAL;
/**
* <b>Optimization phase:</b> Logical plan planning.
@@ -88,7 +113,11 @@ public class PushPredicateIntoTableScan implements
PlanOptimizer {
@Override
public PlanNode optimize(PlanNode plan, Context context) {
return plan.accept(
- new Rewriter(context.getQueryContext(), context.getAnalysis(),
context.getMetadata()),
+ new Rewriter(
+ context.getQueryContext(),
+ context.getAnalysis(),
+ context.getMetadata(),
+ context.getSymbolAllocator()),
null);
}
@@ -97,11 +126,17 @@ public class PushPredicateIntoTableScan implements
PlanOptimizer {
private final Analysis analysis;
private final Metadata metadata;
private Expression predicate;
+ private SymbolAllocator symbolAllocator;
- Rewriter(MPPQueryContext queryContext, Analysis analysis, Metadata
metadata) {
+ Rewriter(
+ MPPQueryContext queryContext,
+ Analysis analysis,
+ Metadata metadata,
+ SymbolAllocator symbolAllocator) {
this.queryContext = queryContext;
this.analysis = analysis;
this.metadata = metadata;
+ this.symbolAllocator = symbolAllocator;
}
@Override
@@ -150,6 +185,8 @@ public class PushPredicateIntoTableScan implements
PlanOptimizer {
if (node.getChild() instanceof TableScanNode) {
// child of FilterNode is TableScanNode, means FilterNode must get
from where clause
return combineFilterAndScan((TableScanNode) node.getChild());
+ } else if (node.getChild() instanceof JoinNode) {
+ return visitJoin((JoinNode) node.getChild(), context);
} else {
// FilterNode may get from having or subquery
node.setChild(node.getChild().accept(this, context));
@@ -256,6 +293,437 @@ public class PushPredicateIntoTableScan implements
PlanOptimizer {
metadataExpressions, expressionsCanPushDown,
expressionsCannotPushDown);
}
+ @Override
+ public PlanNode visitJoin(JoinNode node, Void context) {
+ Expression inheritedPredicate = predicate;
+
+ // See if we can rewrite outer joins in terms of a plain inner join
+ node = tryNormalizeToOuterToInnerJoin(node, inheritedPredicate);
+
+ Expression leftEffectivePredicate = TRUE_LITERAL;
+ // effectivePredicateExtractor.extract(session, node.getLeftChild(),
types, typeAnalyzer);
+ Expression rightEffectivePredicate = TRUE_LITERAL;
+ // effectivePredicateExtractor.extract(session, node.getRightChild(),
types, typeAnalyzer);
+ Expression joinPredicate = extractJoinPredicate(node);
+
+ Expression leftPredicate;
+ Expression rightPredicate;
+ Expression postJoinPredicate;
+ Expression newJoinPredicate;
+
+ switch (node.getJoinType()) {
+ case INNER:
+ InnerJoinPushDownResult innerJoinPushDownResult =
+ processInnerJoin(
+ inheritedPredicate,
+ leftEffectivePredicate,
+ rightEffectivePredicate,
+ joinPredicate,
+ node.getLeftChild().getOutputSymbols(),
+ node.getRightChild().getOutputSymbols());
+ leftPredicate = innerJoinPushDownResult.getLeftPredicate();
+ rightPredicate = innerJoinPushDownResult.getRightPredicate();
+ postJoinPredicate = innerJoinPushDownResult.getPostJoinPredicate();
+ newJoinPredicate = innerJoinPushDownResult.getJoinPredicate();
+ break;
+ default:
+ throw new IllegalStateException("Only support INNER JOIN in current
version");
+ }
+
+ // newJoinPredicate = simplifyExpression(newJoinPredicate);
+
+ // Create identity projections for all existing symbols
+ Assignments.Builder leftProjections = Assignments.builder();
+ leftProjections.putAll(
+ node.getLeftChild().getOutputSymbols().stream()
+ .collect(toImmutableMap(key -> key, Symbol::toSymbolReference)));
+
+ Assignments.Builder rightProjections = Assignments.builder();
+ rightProjections.putAll(
+ node.getRightChild().getOutputSymbols().stream()
+ .collect(toImmutableMap(key -> key, Symbol::toSymbolReference)));
+
+ // Create new projections for the new join clauses
+ List<JoinNode.EquiJoinClause> equiJoinClauses = new ArrayList<>();
+ ImmutableList.Builder<Expression> joinFilterBuilder =
ImmutableList.builder();
+ for (Expression conjunct : extractConjuncts(newJoinPredicate)) {
+ if (joinEqualityExpression(
+ conjunct,
+ node.getLeftChild().getOutputSymbols(),
+ node.getRightChild().getOutputSymbols())) {
+ ComparisonExpression equality = (ComparisonExpression) conjunct;
+
+ boolean alignedComparison =
+
node.getLeftChild().getOutputSymbols().containsAll(extractUnique(equality.getLeft()));
+ Expression leftExpression = alignedComparison ? equality.getLeft() :
equality.getRight();
+ Expression rightExpression = alignedComparison ? equality.getRight()
: equality.getLeft();
+
+ Symbol leftSymbol = symbolForExpression(leftExpression);
+ if (!node.getLeftChild().getOutputSymbols().contains(leftSymbol)) {
+ leftProjections.put(leftSymbol, leftExpression);
+ }
+
+ Symbol rightSymbol = symbolForExpression(rightExpression);
+ if (!node.getRightChild().getOutputSymbols().contains(rightSymbol)) {
+ rightProjections.put(rightSymbol, rightExpression);
+ }
+
+ equiJoinClauses.add(new JoinNode.EquiJoinClause(leftSymbol,
rightSymbol));
+ } else {
+ joinFilterBuilder.add(conjunct);
+ }
+ }
+
+ List<Expression> joinFilter = joinFilterBuilder.build();
+ // DynamicFiltersResult dynamicFiltersResult =
createDynamicFilters(node,
+ // equiJoinClauses, joinFilter, session, idAllocator);
+ // Map<DynamicFilterId, Symbol> dynamicFilters =
+ // dynamicFiltersResult.getDynamicFilters();
+ // leftPredicate = combineConjuncts(metadata, leftPredicate,
combineConjuncts(metadata,
+ // dynamicFiltersResult.getPredicates()));
+
+ PlanNode leftSource = node.getLeftChild();
+ PlanNode rightSource = node.getRightChild();
+ boolean equiJoinClausesUnmodified =
+
ImmutableSet.copyOf(equiJoinClauses).equals(ImmutableSet.copyOf(node.getCriteria()));
+ if (!equiJoinClausesUnmodified) {
+ // leftSource = context.rewrite(new
ProjectNode(queryContext.getQueryId().genPlanNodeId(),
+ // node.getLeftChild(), leftProjections.build()), leftPredicate);
+ // rightSource = context.rewrite(new
ProjectNode(queryContext.getQueryId().genPlanNodeId(),
+ // node.getRightChild(), rightProjections.build()), rightPredicate);
+ } else {
+ // leftSource = context.rewrite(node.getLeftChild(), leftPredicate);
+ // rightSource = context.rewrite(node.getRightChild(), rightPredicate);
+ // TODO rewrite
+ }
+
+ Optional<Expression> newJoinFilter =
Optional.of(combineConjuncts(joinFilter));
+ if (newJoinFilter.get().equals(TRUE_LITERAL)) {
+ newJoinFilter = Optional.empty();
+ }
+
+ if (node.getJoinType() == INNER && newJoinFilter.isPresent() &&
equiJoinClauses.isEmpty()) {
+ // if we do not have any equi conjunct we do not pushdown non-equality
condition into
+ // inner join, so we plan execution as nested-loops-join followed by
filter instead
+ // hash join.
+ // todo: remove the code when we have support for filter function in
nested loop join
+ // postJoinPredicate = combineConjuncts(postJoinPredicate,
newJoinFilter.get());
+ newJoinFilter = Optional.empty();
+ }
+
+ boolean filtersEquivalent =
+ newJoinFilter.isPresent() == node.getFilter().isPresent() &&
(!newJoinFilter.isPresent());
+ // areExpressionsEquivalent(newJoinFilter.get(), node.getFilter().get());
+
+ PlanNode output = node;
+ if (leftSource != node.getLeftChild()
+ || rightSource != node.getRightChild()
+ || !filtersEquivalent
+ ||
+ // !dynamicFilters.equals(node.getDynamicFilters()) ||
+ !equiJoinClausesUnmodified) {
+ leftSource =
+ new ProjectNode(
+ queryContext.getQueryId().genPlanNodeId(), leftSource,
leftProjections.build());
+ rightSource =
+ new ProjectNode(
+ queryContext.getQueryId().genPlanNodeId(), rightSource,
rightProjections.build());
+
+ output =
+ new JoinNode(
+ node.getPlanNodeId(),
+ node.getJoinType(),
+ leftSource,
+ rightSource,
+ equiJoinClauses,
+ leftSource.getOutputSymbols(),
+ rightSource.getOutputSymbols(),
+ node.isMaySkipOutputDuplicates(),
+ newJoinFilter,
+ node.getLeftHashSymbol(),
+ node.getRightHashSymbol(),
+ node.isSpillable());
+ }
+
+ if (!postJoinPredicate.equals(TRUE_LITERAL)) {
+ output =
+ new FilterNode(queryContext.getQueryId().genPlanNodeId(), output,
postJoinPredicate);
+ }
+
+ if (!node.getOutputSymbols().equals(output.getOutputSymbols())) {
+ output =
+ new ProjectNode(
+ queryContext.getQueryId().genPlanNodeId(),
+ output,
+ Assignments.identity(node.getOutputSymbols()));
+ }
+
+ return output;
+ }
+
+    /**
+     * Tries to weaken an outer join to LEFT, RIGHT or INNER, based on what {@code
+     * canConvertOuterToInner} reports for the inherited predicate.
+     *
+     * @param node the join to normalize
+     * @param inheritedPredicate the predicate applied above the join
+     * @return {@code node} itself when its type stays as it is, otherwise a copy of the join that
+     *     differs only in its join type
+     * @throws IllegalArgumentException if the join type is not INNER, LEFT, RIGHT or FULL
+     */
+    private JoinNode tryNormalizeToOuterToInnerJoin(JoinNode node, Expression inheritedPredicate) {
+      JoinNode.JoinType originalType = node.getJoinType();
+      checkArgument(
+          EnumSet.of(INNER, RIGHT, LEFT, FULL).contains(originalType),
+          "Unsupported join type: %s",
+          originalType);
+
+      if (originalType == INNER) {
+        return node;
+      }
+
+      JoinNode.JoinType normalizedType;
+      if (originalType == FULL) {
+        boolean canConvertToLeftJoin =
+            canConvertOuterToInner(node.getLeftChild().getOutputSymbols(), inheritedPredicate);
+        boolean canConvertToRightJoin =
+            canConvertOuterToInner(node.getRightChild().getOutputSymbols(), inheritedPredicate);
+        if (canConvertToLeftJoin && canConvertToRightJoin) {
+          normalizedType = INNER;
+        } else if (canConvertToLeftJoin) {
+          normalizedType = LEFT;
+        } else if (canConvertToRightJoin) {
+          normalizedType = RIGHT;
+        } else {
+          return node;
+        }
+      } else {
+        // Only LEFT and RIGHT are left here; the side tested is the right input for a LEFT join
+        // and the left input for a RIGHT join.
+        List<Symbol> innerSideSymbols =
+            originalType == LEFT
+                ? node.getRightChild().getOutputSymbols()
+                : node.getLeftChild().getOutputSymbols();
+        if (!canConvertOuterToInner(innerSideSymbols, inheritedPredicate)) {
+          return node;
+        }
+        normalizedType = INNER;
+      }
+
+      return new JoinNode(
+          node.getPlanNodeId(),
+          normalizedType,
+          node.getLeftChild(),
+          node.getRightChild(),
+          node.getCriteria(),
+          node.getLeftOutputSymbols(),
+          node.getRightOutputSymbols(),
+          node.isMaySkipOutputDuplicates(),
+          node.getFilter(),
+          node.getLeftHashSymbol(),
+          node.getRightHashSymbol(),
+          node.isSpillable());
+    }
+
+    /**
+     * Decides whether an outer join may be treated as an inner join because the inherited
+     * predicate cannot hold for a row whose inner-side symbols are all NULL.
+     *
+     * <p>There is no null-input evaluator available here, so the test is deliberately narrow: it
+     * only accepts a deterministic conjunct that is a comparison with an inner-side symbol as a
+     * direct operand. Anything else is answered with {@code false}, which keeps the outer join.
+     *
+     * @param innerSymbolsForOuterJoin output symbols of the side that the outer join may
+     *     null-extend
+     * @param inheritedPredicate the predicate applied above the join, may be {@code null}
+     * @return {@code true} if some conjunct provably rejects null-extended rows
+     */
+    private boolean canConvertOuterToInner(
+        List<Symbol> innerSymbolsForOuterJoin, Expression inheritedPredicate) {
+      if (inheritedPredicate == null) {
+        return false;
+      }
+      Set<Symbol> innerSymbols = ImmutableSet.copyOf(innerSymbolsForOuterJoin);
+      for (Expression conjunct : extractConjuncts(inheritedPredicate)) {
+        // Ignore a conjunct for this test if we cannot deterministically get responses from it
+        if (isDeterministic(conjunct) && rejectsNullInnerSymbols(conjunct, innerSymbols)) {
+          // A single conjunct that cannot be TRUE for NULL inner-side inputs removes all effects
+          // of the outer join, which makes the join equivalent to an INNER join.
+          return true;
+        }
+      }
+      return false;
+    }
+
+    /**
+     * Returns whether {@code conjunct} is a comparison that has an inner-side symbol as a direct
+     * operand. Such a comparison yields NULL when that symbol is NULL, except for IS DISTINCT
+     * FROM, which is null-safe and therefore excluded.
+     */
+    private boolean rejectsNullInnerSymbols(Expression conjunct, Set<Symbol> innerSymbols) {
+      if (!(conjunct instanceof ComparisonExpression)) {
+        return false;
+      }
+      ComparisonExpression comparison = (ComparisonExpression) conjunct;
+      if (comparison.getOperator() == ComparisonExpression.Operator.IS_DISTINCT_FROM) {
+        return false;
+      }
+      return isInnerSymbol(comparison.getLeft(), innerSymbols)
+          || isInnerSymbol(comparison.getRight(), innerSymbols);
+    }
+
+    /** Returns whether {@code operand} is a plain reference to one of {@code innerSymbols}. */
+    private boolean isInnerSymbol(Expression operand, Set<Symbol> innerSymbols) {
+      return operand instanceof SymbolReference && innerSymbols.contains(Symbol.from(operand));
+    }
+
+    /**
+     * Works out, for an INNER join, which conjuncts go to the left input, which go to the right
+     * input and which stay with the join.
+     *
+     * <p>Non-deterministic conjuncts of the inherited and join predicates are kept with the join
+     * unchanged. The deterministic remainder feeds {@link EqualityInference}, whose equalities are
+     * distributed by scope; every other conjunct is offered to both inputs in rewritten form and
+     * stays with the join only if neither input can take it.
+     *
+     * @param inheritedPredicate predicate applied above the join
+     * @param leftEffectivePredicate predicate known to hold on the left input
+     * @param rightEffectivePredicate predicate known to hold on the right input
+     * @param joinPredicate the join's criteria and filter as a single predicate
+     * @param leftSymbols symbols produced by the left input
+     * @param rightSymbols symbols produced by the right input
+     * @return the left, right and join predicates; the post-join predicate is always TRUE
+     * @throws IllegalArgumentException if an effective predicate uses symbols of the other side
+     */
+    private InnerJoinPushDownResult processInnerJoin(
+        Expression inheritedPredicate,
+        Expression leftEffectivePredicate,
+        Expression rightEffectivePredicate,
+        Expression joinPredicate,
+        Collection<Symbol> leftSymbols,
+        Collection<Symbol> rightSymbols) {
+      checkArgument(
+          leftSymbols.containsAll(extractUnique(leftEffectivePredicate)),
+          "leftEffectivePredicate must only contain symbols from leftSymbols");
+      checkArgument(
+          rightSymbols.containsAll(extractUnique(rightEffectivePredicate)),
+          "rightEffectivePredicate must only contain symbols from rightSymbols");
+
+      ImmutableList.Builder<Expression> leftSink = ImmutableList.builder();
+      ImmutableList.Builder<Expression> rightSink = ImmutableList.builder();
+      ImmutableList.Builder<Expression> joinSink = ImmutableList.builder();
+
+      // Strip out non-deterministic conjuncts: they stay with the join as they are
+      for (Expression conjunct : extractConjuncts(inheritedPredicate)) {
+        if (!isDeterministic(conjunct)) {
+          joinSink.add(conjunct);
+        }
+      }
+      Expression inherited = filterDeterministicConjuncts(inheritedPredicate);
+
+      for (Expression conjunct : extractConjuncts(joinPredicate)) {
+        if (!isDeterministic(conjunct)) {
+          joinSink.add(conjunct);
+        }
+      }
+      Expression joinCondition = filterDeterministicConjuncts(joinPredicate);
+
+      Expression leftEffective = filterDeterministicConjuncts(leftEffectivePredicate);
+      Expression rightEffective = filterDeterministicConjuncts(rightEffectivePredicate);
+
+      ImmutableSet<Symbol> leftScope = ImmutableSet.copyOf(leftSymbols);
+      ImmutableSet<Symbol> rightScope = ImmutableSet.copyOf(rightSymbols);
+
+      // Generate equality inferences
+      EqualityInference allInference =
+          new EqualityInference(metadata, inherited, leftEffective, rightEffective, joinCondition);
+      EqualityInference allInferenceWithoutLeftInferred =
+          new EqualityInference(metadata, inherited, rightEffective, joinCondition);
+      EqualityInference allInferenceWithoutRightInferred =
+          new EqualityInference(metadata, inherited, leftEffective, joinCondition);
+
+      // Add equalities from the inference back in
+      leftSink.addAll(
+          allInferenceWithoutLeftInferred
+              .generateEqualitiesPartitionedBy(leftScope)
+              .getScopeEqualities());
+      rightSink.addAll(
+          allInferenceWithoutRightInferred
+              .generateEqualitiesPartitionedBy(rightScope)
+              .getScopeEqualities());
+      // scope straddling equalities get dropped in as part of the join predicate
+      joinSink.addAll(
+          allInference.generateEqualitiesPartitionedBy(leftScope).getScopeStraddlingEqualities());
+
+      // Sort through conjuncts in inheritedPredicate that were not used for inference
+      EqualityInference.nonInferrableConjuncts(metadata, inherited)
+          .forEach(
+              conjunct ->
+                  pushConjunctToEitherSide(
+                      conjunct, allInference, leftScope, rightScope, leftSink, rightSink, joinSink));
+
+      // See if we can push the right effective predicate to the left side
+      EqualityInference.nonInferrableConjuncts(metadata, rightEffective)
+          .map(conjunct -> allInference.rewrite(conjunct, leftScope))
+          .filter(Objects::nonNull)
+          .forEach(leftSink::add);
+
+      // See if we can push the left effective predicate to the right side
+      EqualityInference.nonInferrableConjuncts(metadata, leftEffective)
+          .map(conjunct -> allInference.rewrite(conjunct, rightScope))
+          .filter(Objects::nonNull)
+          .forEach(rightSink::add);
+
+      // See if we can push any parts of the join predicates to either side
+      EqualityInference.nonInferrableConjuncts(metadata, joinCondition)
+          .forEach(
+              conjunct ->
+                  pushConjunctToEitherSide(
+                      conjunct, allInference, leftScope, rightScope, leftSink, rightSink, joinSink));
+
+      return new InnerJoinPushDownResult(
+          combineConjuncts(leftSink.build()),
+          combineConjuncts(rightSink.build()),
+          combineConjuncts(joinSink.build()),
+          TRUE_LITERAL);
+    }
+
+    /**
+     * Offers {@code conjunct} to both join inputs in rewritten form; it is added to the join
+     * conjuncts, unrewritten, only when neither input can take it.
+     */
+    private void pushConjunctToEitherSide(
+        Expression conjunct,
+        EqualityInference inference,
+        ImmutableSet<Symbol> leftScope,
+        ImmutableSet<Symbol> rightScope,
+        ImmutableList.Builder<Expression> leftSink,
+        ImmutableList.Builder<Expression> rightSink,
+        ImmutableList.Builder<Expression> joinSink) {
+      Expression forLeft = inference.rewrite(conjunct, leftScope);
+      if (forLeft != null) {
+        leftSink.add(forLeft);
+      }
+
+      Expression forRight = inference.rewrite(conjunct, rightScope);
+      if (forRight != null) {
+        rightSink.add(forRight);
+      }
+
+      // Drop predicate after join only if unable to push down to either side
+      if (forLeft == null && forRight == null) {
+        joinSink.add(conjunct);
+      }
+    }
+
+    /**
+     * Builds one predicate from a join: each equi-join clause as an expression, followed by the
+     * join filter when there is one.
+     */
+    private Expression extractJoinPredicate(JoinNode joinNode) {
+      ImmutableList.Builder<Expression> conjuncts = ImmutableList.builder();
+      joinNode.getCriteria().forEach(clause -> conjuncts.add(clause.toExpression()));
+      if (joinNode.getFilter().isPresent()) {
+        conjuncts.add(joinNode.getFilter().get());
+      }
+      return combineConjuncts(conjuncts.build());
+    }
+
+    /**
+     * Returns whether {@code expression} is an equality usable as an equi-join clause, i.e. a
+     * join comparison restricted to the EQUAL operator.
+     */
+    private boolean joinEqualityExpression(
+        Expression expression, Collection<Symbol> leftSymbols, Collection<Symbol> rightSymbols) {
+      Set<ComparisonExpression.Operator> equalityOnly =
+          ImmutableSet.of(ComparisonExpression.Operator.EQUAL);
+      return joinComparisonExpression(expression, leftSymbols, rightSymbols, equalityOnly);
+    }
+
+    /**
+     * Returns whether {@code expression} is a deterministic comparison with one of the given
+     * operators whose operands each reference at least one symbol, with one operand drawing only
+     * on the left symbols and the other only on the right symbols (in either order).
+     */
+    private boolean joinComparisonExpression(
+        Expression expression,
+        Collection<Symbol> leftSymbols,
+        Collection<Symbol> rightSymbols,
+        Set<ComparisonExpression.Operator> operators) {
+      // At this point in time, our join predicates need to be deterministic
+      if (!(expression instanceof ComparisonExpression) || !isDeterministic(expression)) {
+        return false;
+      }
+
+      ComparisonExpression comparison = (ComparisonExpression) expression;
+      if (!operators.contains(comparison.getOperator())) {
+        return false;
+      }
+
+      Set<Symbol> firstOperandSymbols = extractUnique(comparison.getLeft());
+      Set<Symbol> secondOperandSymbols = extractUnique(comparison.getRight());
+      if (firstOperandSymbols.isEmpty() || secondOperandSymbols.isEmpty()) {
+        // an operand without symbols does not relate the two sides
+        return false;
+      }
+
+      if (leftSymbols.containsAll(firstOperandSymbols)
+          && rightSymbols.containsAll(secondOperandSymbols)) {
+        return true;
+      }
+      return rightSymbols.containsAll(firstOperandSymbols)
+          && leftSymbols.containsAll(secondOperandSymbols);
+    }
+
+    /**
+     * Returns the symbol that stands for {@code expression}: the referenced symbol when the
+     * expression is a plain symbol reference, otherwise a newly allocated one.
+     */
+    private Symbol symbolForExpression(Expression expression) {
+      // TODO(beyyes) verify the rightness of type
+      return expression instanceof SymbolReference
+          ? Symbol.from(expression)
+          : symbolAllocator.newSymbol(expression, analysis.getType(expression));
+    }
+
@Override
public PlanNode visitTableScan(TableScanNode node, Void context) {
tableMetadataIndexScan(node, Collections.emptyList());
@@ -391,4 +859,38 @@ public class PushPredicateIntoTableScan implements
PlanOptimizer {
return this.expressionsCannotPushDown;
}
}
+
+  /**
+   * Immutable holder for the four predicates produced when the conjuncts around an INNER join
+   * are distributed: one per join input, one for the join, and one to apply after the join.
+   */
+  private static class InnerJoinPushDownResult {
+    /** Conjuncts assigned to the left input. */
+    private final Expression leftPredicate;
+
+    /** Conjuncts assigned to the right input. */
+    private final Expression rightPredicate;
+
+    /** Conjuncts that stay with the join. */
+    private final Expression joinPredicate;
+
+    /** Conjuncts to apply above the join. */
+    private final Expression postJoinPredicate;
+
+    private InnerJoinPushDownResult(
+        Expression leftPredicate,
+        Expression rightPredicate,
+        Expression joinPredicate,
+        Expression postJoinPredicate) {
+      this.leftPredicate = leftPredicate;
+      this.rightPredicate = rightPredicate;
+      this.joinPredicate = joinPredicate;
+      this.postJoinPredicate = postJoinPredicate;
+    }
+
+    private Expression getLeftPredicate() {
+      return this.leftPredicate;
+    }
+
+    private Expression getRightPredicate() {
+      return this.rightPredicate;
+    }
+
+    private Expression getJoinPredicate() {
+      return this.joinPredicate;
+    }
+
+    private Expression getPostJoinPredicate() {
+      return this.postJoinPredicate;
+    }
+  }
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/utils/DisjointSet.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/utils/DisjointSet.java
new file mode 100644
index 00000000000..eb13d09d043
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/utils/DisjointSet.java
@@ -0,0 +1,114 @@
+package org.apache.iotdb.db.queryengine.plan.relational.utils;
+
+import java.util.Collection;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.Map;
+import java.util.Set;
+
+import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.base.Verify.verify;
+
+/**
+ * Union-find structure over elements of type {@code T}, with union by rank and path compression.
+ * Elements are registered the first time they are passed to {@link #find}.
+ *
+ * @param <T> element type; elements are used as map keys and compared with {@code equals}
+ */
+public class DisjointSet<T> {
+  /** Bookkeeping for one element: its parent, or its rank while it is still a root. */
+  private static class Entry<T> {
+    // null while this element is the root of its tree
+    private T parent;
+    // Without path compression, this would be equal to depth. Depth of 1-node tree is
+    // considered 0.
+    private int rank;
+
+    Entry() {
+      this(null, 0);
+    }
+
+    private Entry(T parent, int rank) {
+      this.parent = parent;
+      this.rank = rank;
+    }
+
+    public T getParent() {
+      return parent;
+    }
+
+    public void setParent(T parent) {
+      this.parent = parent;
+      // rank is only meaningful for roots; mark it as unusable
+      this.rank = -1;
+    }
+
+    public int getRank() {
+      checkState(parent == null);
+      return rank;
+    }
+
+    public void incrementRank() {
+      checkState(parent == null);
+      rank++;
+    }
+  }
+
+  // insertion-ordered, so that getEquivalentClasses() iterates in a stable order
+  private final Map<T, Entry<T>> map = new LinkedHashMap<>();
+
+  public DisjointSet() {}
+
+  /**
+   * Records that the two nodes are equivalent.
+   *
+   * @return {@code true} if the specified equivalence is new
+   */
+  public boolean findAndUnion(T node1, T node2) {
+    T firstRoot = find(node1);
+    T secondRoot = find(node2);
+    return union(firstRoot, secondRoot);
+  }
+
+  /**
+   * Returns the representative of the set containing {@code element}, registering the element as
+   * a singleton set if it has not been seen before.
+   */
+  public T find(T element) {
+    if (map.get(element) != null) {
+      return findInternal(element);
+    }
+    map.put(element, new Entry<>());
+    return element;
+  }
+
+  /** Merges the trees of two roots; returns {@code false} if they are the same root. */
+  private boolean union(T root1, T root2) {
+    if (root1.equals(root2)) {
+      return false;
+    }
+    Entry<T> first = map.get(root1);
+    Entry<T> second = map.get(root2);
+    int firstRank = first.getRank();
+    int secondRank = second.getRank();
+    verify(firstRank >= 0);
+    verify(secondRank >= 0);
+    if (firstRank < secondRank) {
+      // the shallower tree (root1) is hung below root2
+      first.setParent(root2);
+      return true;
+    }
+    if (firstRank == secondRank) {
+      // both sides were equally deep, so the merged tree rooted at root1 gets deeper
+      first.incrementRank();
+    }
+    second.setParent(root1);
+    return true;
+  }
+
+  /** Finds the root of a registered element and points every node on the way directly at it. */
+  private T findInternal(T element) {
+    // first pass: walk up to the root
+    T root = element;
+    T above = map.get(root).getParent();
+    while (above != null) {
+      root = above;
+      above = map.get(root).getParent();
+    }
+
+    // second pass: re-parent every non-root node on the path
+    Entry<T> current = map.get(element);
+    while (current.getParent() != null) {
+      T next = current.getParent();
+      current.setParent(root);
+      current = map.get(next);
+    }
+    return root;
+  }
+
+  /** Returns the current sets, each as the collection of its registered elements. */
+  public Collection<Set<T>> getEquivalentClasses() {
+    // map from root element to all element in the tree
+    Map<T, Set<T>> membersByRoot = new LinkedHashMap<>();
+    for (T node : map.keySet()) {
+      membersByRoot.computeIfAbsent(findInternal(node), unused -> new LinkedHashSet<>()).add(node);
+    }
+    return membersByRoot.values();
+  }
+}