asfgit closed pull request #23315: [SPARK-26366][SQL] ReplaceExceptWithFilter
should consider NULL as False
URL: https://github.com/apache/spark/pull/23315
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
index efd3944eba7f5..4996d24dfd298 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
@@ -36,7 +36,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
* Note:
* Before flipping the filter condition of the right node, we should:
* 1. Combine all it's [[Filter]].
- * 2. Apply InferFiltersFromConstraints rule (to take into account of NULL
values in the condition).
+ * 2. Update the attribute references to the left node;
+ * 3. Add a Coalesce(condition, False) (to take into account of NULL values in
the condition).
*/
object ReplaceExceptWithFilter extends Rule[LogicalPlan] {
@@ -47,23 +48,28 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] {
plan.transform {
case e @ Except(left, right, false) if isEligible(left, right) =>
- val newCondition = transformCondition(left, skipProject(right))
- newCondition.map { c =>
- Distinct(Filter(Not(c), left))
- }.getOrElse {
+ val filterCondition =
combineFilters(skipProject(right)).asInstanceOf[Filter].condition
+ if (filterCondition.deterministic) {
+ transformCondition(left, filterCondition).map { c =>
+ Distinct(Filter(Not(c), left))
+ }.getOrElse {
+ e
+ }
+ } else {
e
}
}
}
- private def transformCondition(left: LogicalPlan, right: LogicalPlan):
Option[Expression] = {
- val filterCondition =
-
InferFiltersFromConstraints(combineFilters(right)).asInstanceOf[Filter].condition
-
- val attributeNameMap: Map[String, Attribute] = left.output.map(x =>
(x.name, x)).toMap
-
- if (filterCondition.references.forall(r =>
attributeNameMap.contains(r.name))) {
- Some(filterCondition.transform { case a: AttributeReference =>
attributeNameMap(a.name) })
+ private def transformCondition(plan: LogicalPlan, condition: Expression):
Option[Expression] = {
+ val attributeNameMap: Map[String, Attribute] = plan.output.map(x =>
(x.name, x)).toMap
+ if (condition.references.forall(r => attributeNameMap.contains(r.name))) {
+ val rewrittenCondition = condition.transform {
+ case a: AttributeReference => attributeNameMap(a.name)
+ }
+ // We need to consider as False when the condition is NULL, otherwise we
do not return those
+ // rows containing NULL which are instead filtered in the Except right
plan
+ Some(Coalesce(Seq(rewrittenCondition, Literal.FalseLiteral)))
} else {
None
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index 3b1b2d588ef67..c8e15c7da763e 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, Not}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Coalesce, If,
Literal, Not}
import org.apache.spark.sql.catalyst.expressions.aggregate.First
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.BooleanType
class ReplaceOperatorSuite extends PlanTest {
@@ -65,8 +66,7 @@ class ReplaceOperatorSuite extends PlanTest {
val correctAnswer =
Aggregate(table1.output, table1.output,
- Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
- (attributeA >= 2 && attributeB < 1)),
+ Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1,
Literal.FalseLiteral))),
Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze
comparePlans(optimized, correctAnswer)
@@ -84,8 +84,8 @@ class ReplaceOperatorSuite extends PlanTest {
val correctAnswer =
Aggregate(table1.output, table1.output,
- Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
- (attributeA >= 2 && attributeB < 1)), table1)).analyze
+ Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1,
Literal.FalseLiteral))),
+ table1)).analyze
comparePlans(optimized, correctAnswer)
}
@@ -104,8 +104,7 @@ class ReplaceOperatorSuite extends PlanTest {
val correctAnswer =
Aggregate(table1.output, table1.output,
- Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
- (attributeA >= 2 && attributeB < 1)),
+ Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1,
Literal.FalseLiteral))),
Project(Seq(attributeA, attributeB), table1))).analyze
comparePlans(optimized, correctAnswer)
@@ -125,8 +124,7 @@ class ReplaceOperatorSuite extends PlanTest {
val correctAnswer =
Aggregate(table1.output, table1.output,
- Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
- (attributeA >= 2 && attributeB < 1)),
+ Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1,
Literal.FalseLiteral))),
Filter(attributeB === 2, Filter(attributeA === 1,
table1)))).analyze
comparePlans(optimized, correctAnswer)
@@ -146,8 +144,7 @@ class ReplaceOperatorSuite extends PlanTest {
val correctAnswer =
Aggregate(table1.output, table1.output,
- Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
- (attributeA === 1 && attributeB === 2)),
+ Filter(Not(Coalesce(Seq(attributeA === 1 && attributeB === 2,
Literal.FalseLiteral))),
Project(Seq(attributeA, attributeB),
Filter(attributeB < 1, Filter(attributeA >= 2, table1))))).analyze
@@ -229,4 +226,29 @@ class ReplaceOperatorSuite extends PlanTest {
comparePlans(optimized, query)
}
+
+ test("SPARK-26366: ReplaceExceptWithFilter should handle properly NULL") {
+ val basePlan = LocalRelation(Seq('a.int, 'b.int))
+ val otherPlan = basePlan.where('a.in(1, 2) || 'b.in())
+ val except = Except(basePlan, otherPlan, false)
+ val result = OptimizeIn(Optimize.execute(except.analyze))
+ val correctAnswer = Aggregate(basePlan.output, basePlan.output,
+ Filter(!Coalesce(Seq(
+ 'a.in(1, 2) || If('b.isNotNull, Literal.FalseLiteral, Literal(null,
BooleanType)),
+ Literal.FalseLiteral)),
+ basePlan)).analyze
+ comparePlans(result, correctAnswer)
+ }
+
+ test("SPARK-26366: ReplaceExceptWithFilter should not transform
non-detrministic") {
+ val basePlan = LocalRelation(Seq('a.int, 'b.int))
+ val otherPlan = basePlan.where('a > rand(1L))
+ val except = Except(basePlan, otherPlan, false)
+ val result = Optimize.execute(except.analyze)
+ val condition = basePlan.output.zip(otherPlan.output).map { case (a1, a2)
=>
+ a1 <=> a2 }.reduce( _ && _)
+ val correctAnswer = Aggregate(basePlan.output, otherPlan.output,
+ Join(basePlan, otherPlan, LeftAnti, Option(condition))).analyze
+ comparePlans(result, correctAnswer)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 525c7cef39563..c90b15814a534 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1656,6 +1656,17 @@ class DatasetSuite extends QueryTest with
SharedSQLContext {
checkAnswer(df.groupBy(col("a")).agg(first(col("b"))),
Seq(Row("0", BigDecimal.valueOf(0.1111)), Row("1",
BigDecimal.valueOf(1.1111))))
}
+
+ test("SPARK-26366: return nulls which are not filtered in except") {
+ val inputDF = sqlContext.createDataFrame(
+ sparkContext.parallelize(Seq(Row("0", "a"), Row("1", null))),
+ StructType(Seq(
+ StructField("a", StringType, nullable = true),
+ StructField("b", StringType, nullable = true))))
+
+ val exceptDF = inputDF.filter(col("a").isin("0") or col("b") > "c")
+ checkAnswer(inputDF.except(exceptDF), Seq(Row("1", null)))
+ }
}
case class TestDataUnion(x: Int, y: Int, z: Int)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 4cc8a45391996..37a8815350a53 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2899,6 +2899,44 @@ class SQLQuerySuite extends QueryTest with
SharedSQLContext {
}
}
}
+
+ test("SPARK-26366: verify ReplaceExceptWithFilter") {
+ Seq(true, false).foreach { enabled =>
+ withSQLConf(SQLConf.REPLACE_EXCEPT_WITH_FILTER.key -> enabled.toString) {
+ val df = spark.createDataFrame(
+ sparkContext.parallelize(Seq(Row(0, 3, 5),
+ Row(0, 3, null),
+ Row(null, 3, 5),
+ Row(0, null, 5),
+ Row(0, null, null),
+ Row(null, null, 5),
+ Row(null, 3, null),
+ Row(null, null, null))),
+ StructType(Seq(StructField("c1", IntegerType),
+ StructField("c2", IntegerType),
+ StructField("c3", IntegerType))))
+ val where = "c2 >= 3 OR c1 >= 0"
+ val whereNullSafe =
+ """
+ |(c2 IS NOT NULL AND c2 >= 3)
+ |OR (c1 IS NOT NULL AND c1 >= 0)
+ """.stripMargin
+
+ val df_a = df.filter(where)
+ val df_b = df.filter(whereNullSafe)
+ checkAnswer(df.except(df_a), df.except(df_b))
+
+ val whereWithIn = "c2 >= 3 OR c1 in (2)"
+ val whereWithInNullSafe =
+ """
+ |(c2 IS NOT NULL AND c2 >= 3)
+ """.stripMargin
+ val dfIn_a = df.filter(whereWithIn)
+ val dfIn_b = df.filter(whereWithInNullSafe)
+ checkAnswer(df.except(dfIn_a), df.except(dfIn_b))
+ }
+ }
+ }
}
case class Foo(bar: Option[String])
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org