GitHub user aokolnychyi commented on a diff in the pull request:
https://github.com/apache/spark/pull/18692#discussion_r137343500
--- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala ---
@@ -152,3 +152,71 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
       if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType))
   }
 }
+
+/**
+ * A rule that uses propagated constraints to infer join conditions. The optimization is applicable
+ * only to CROSS joins.
+ *
+ * For instance, if there is a CROSS join, where the left relation has 'a = 1' and the right
+ * relation has 'b = 1', then the rule infers 'a = b' as a join predicate.
+ */
+object InferJoinConditionsFromConstraints extends Rule[LogicalPlan] with PredicateHelper {
+
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    if (SQLConf.get.constraintPropagationEnabled) {
+      inferJoinConditions(plan)
+    } else {
+      plan
+    }
+  }
+
+  private def inferJoinConditions(plan: LogicalPlan): LogicalPlan = plan transform {
+    case join @ Join(left, right, Cross, conditionOpt) =>
+      val leftConstraints = join.constraints.filter(_.references.subsetOf(left.outputSet))
+      val rightConstraints = join.constraints.filter(_.references.subsetOf(right.outputSet))
+      val inferredJoinPredicates = inferJoinPredicates(leftConstraints, rightConstraints)
+
+      val newConditionOpt = conditionOpt match {
+        case Some(condition) =>
+          val existingPredicates = splitConjunctivePredicates(condition)
+          val newPredicates = findNewPredicates(inferredJoinPredicates, existingPredicates)
+          if (newPredicates.nonEmpty) Some(And(newPredicates.reduce(And), condition)) else None
+        case None =>
+          inferredJoinPredicates.reduceOption(And)
+      }
+      if (newConditionOpt.isDefined) Join(left, right, Inner, newConditionOpt) else join
--- End diff ---
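The hunk calls two helpers that fall outside the quoted range, `inferJoinPredicates` and `findNewPredicates`. A minimal sketch of what they might look like, given the scaladoc's 'a = 1' / 'b = 1' => 'a = b' example; this is a hypothetical reading aid, not the PR's actual code (the pattern shapes and the semantic-equality checks are assumptions):
```
import org.apache.spark.sql.catalyst.expressions.{Attribute, EqualTo, Expression, Literal}

// Hypothetical sketch of the helpers, as they might sit inside
// InferJoinConditionsFromConstraints. Not taken from the PR.

// Derive 'a = b' join predicates from per-side equality-to-literal
// constraints such as 'a = 1' (left) and 'b = 1' (right).
private def inferJoinPredicates(
    leftConstraints: Set[Expression],
    rightConstraints: Set[Expression]): Set[Expression] = {
  // Accept both 'attr = lit' and 'lit = attr' shapes; ignore other constraints.
  def attrEqualsLiteral(e: Expression): Option[(Attribute, Literal)] = e match {
    case EqualTo(a: Attribute, l: Literal) => Some((a, l))
    case EqualTo(l: Literal, a: Attribute) => Some((a, l))
    case _ => None
  }
  for {
    (leftAttr, leftLit) <- leftConstraints.flatMap(attrEqualsLiteral)
    (rightAttr, rightLit) <- rightConstraints.flatMap(attrEqualsLiteral)
    if leftLit.semanticEquals(rightLit)
  } yield (EqualTo(leftAttr, rightAttr): Expression)
}

// Keep only inferred predicates that are not already part of the
// existing join condition (compared semantically, not syntactically).
private def findNewPredicates(
    inferred: Set[Expression],
    existing: Seq[Expression]): Set[Expression] =
  inferred.filterNot(p => existing.exists(_.semanticEquals(p)))
```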
And what about CROSS joins with join conditions? Not sure if they will benefit from the proposed rule, but it is better to ask.
```
Seq((1, 2)).toDF("col1", "col2").write.saveAsTable("t1")
Seq((1, 2)).toDF("col1", "col2").write.saveAsTable("t2")
val df = spark.sql("SELECT * FROM t1 CROSS JOIN t2 ON t1.col1 >= t2.col1 WHERE t1.col1 = 1 AND t2.col1 = 1")
df.explain(true)

== Optimized Logical Plan ==
Join Cross, (col1#40 >= col1#42)
:- Filter (isnotnull(col1#40) && (col1#40 = 1))
:  +- Relation[col1#40,col2#41] parquet
+- Filter (isnotnull(col1#42) && (col1#42 = 1))
   +- Relation[col1#42,col2#43] parquet
```
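For contrast, the condition-free case that the scaladoc targets, on the same tables, would be (assuming the rule is wired into the optimizer batch):
```
val df = spark.sql("SELECT * FROM t1 CROSS JOIN t2 WHERE t1.col1 = 1 AND t2.col1 = 1")
df.explain(true)
```
There one would expect `t1.col1 = t2.col1` to be inferred and the join to be rewritten to Inner via the `Join(left, right, Inner, newConditionOpt)` branch.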
---