[ https://issues.apache.org/jira/browse/SPARK-26366?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]

Reynold Xin updated SPARK-26366:
--------------------------------
    Comment: was deleted

(was: asfgit closed pull request #23315: [SPARK-26366][SQL] ReplaceExceptWithFilter should consider NULL as False
URL: https://github.com/apache/spark/pull/23315
 
 
   

This is a PR merged from a forked repository. As GitHub hides the original
diff on merge, it is reproduced below for the sake of provenance:

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
index efd3944eba7f5..4996d24dfd298 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
@@ -36,7 +36,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
  * Note:
  * Before flipping the filter condition of the right node, we should:
  * 1. Combine all it's [[Filter]].
- * 2. Apply InferFiltersFromConstraints rule (to take into account of NULL values in the condition).
+ * 2. Update the attribute references to the left node;
+ * 3. Add a Coalesce(condition, False) (to take into account of NULL values in the condition).
  */
 object ReplaceExceptWithFilter extends Rule[LogicalPlan] {
 
@@ -47,23 +48,28 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] {
 
     plan.transform {
       case e @ Except(left, right, false) if isEligible(left, right) =>
-        val newCondition = transformCondition(left, skipProject(right))
-        newCondition.map { c =>
-          Distinct(Filter(Not(c), left))
-        }.getOrElse {
+        val filterCondition = combineFilters(skipProject(right)).asInstanceOf[Filter].condition
+        if (filterCondition.deterministic) {
+          transformCondition(left, filterCondition).map { c =>
+            Distinct(Filter(Not(c), left))
+          }.getOrElse {
+            e
+          }
+        } else {
           e
         }
     }
   }
 
-  private def transformCondition(left: LogicalPlan, right: LogicalPlan): Option[Expression] = {
-    val filterCondition =
-      InferFiltersFromConstraints(combineFilters(right)).asInstanceOf[Filter].condition
-
-    val attributeNameMap: Map[String, Attribute] = left.output.map(x => (x.name, x)).toMap
-
-    if (filterCondition.references.forall(r => attributeNameMap.contains(r.name))) {
-      Some(filterCondition.transform { case a: AttributeReference => attributeNameMap(a.name) })
+  private def transformCondition(plan: LogicalPlan, condition: Expression): Option[Expression] = {
+    val attributeNameMap: Map[String, Attribute] = plan.output.map(x => (x.name, x)).toMap
+    if (condition.references.forall(r => attributeNameMap.contains(r.name))) {
+      val rewrittenCondition = condition.transform {
+        case a: AttributeReference => attributeNameMap(a.name)
+      }
+      // We need to consider as False when the condition is NULL, otherwise we do not return those
+      // rows containing NULL which are instead filtered in the Except right plan
+      Some(Coalesce(Seq(rewrittenCondition, Literal.FalseLiteral)))
     } else {
       None
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index 3b1b2d588ef67..c8e15c7da763e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, Not}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Coalesce, If, Literal, Not}
 import org.apache.spark.sql.catalyst.expressions.aggregate.First
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.BooleanType
 
 class ReplaceOperatorSuite extends PlanTest {
 
@@ -65,8 +66,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA >= 2 && attributeB < 1)),
+        Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
           Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze
 
     comparePlans(optimized, correctAnswer)
@@ -84,8 +84,8 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA >= 2 && attributeB < 1)), table1)).analyze
+        Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
+          table1)).analyze
 
     comparePlans(optimized, correctAnswer)
   }
@@ -104,8 +104,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA >= 2 && attributeB < 1)),
+        Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
           Project(Seq(attributeA, attributeB), table1))).analyze
 
     comparePlans(optimized, correctAnswer)
@@ -125,8 +124,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-          Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-            (attributeA >= 2 && attributeB < 1)),
+          Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
            Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze
 
     comparePlans(optimized, correctAnswer)
@@ -146,8 +144,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA === 1 && attributeB === 2)),
+        Filter(Not(Coalesce(Seq(attributeA === 1 && attributeB === 2, Literal.FalseLiteral))),
           Project(Seq(attributeA, attributeB),
             Filter(attributeB < 1, Filter(attributeA >= 2, table1))))).analyze
 
@@ -229,4 +226,29 @@ class ReplaceOperatorSuite extends PlanTest {
 
     comparePlans(optimized, query)
   }
+
+  test("SPARK-26366: ReplaceExceptWithFilter should handle properly NULL") {
+    val basePlan = LocalRelation(Seq('a.int, 'b.int))
+    val otherPlan = basePlan.where('a.in(1, 2) || 'b.in())
+    val except = Except(basePlan, otherPlan, false)
+    val result = OptimizeIn(Optimize.execute(except.analyze))
+    val correctAnswer = Aggregate(basePlan.output, basePlan.output,
+      Filter(!Coalesce(Seq(
+        'a.in(1, 2) || If('b.isNotNull, Literal.FalseLiteral, Literal(null, BooleanType)),
+        Literal.FalseLiteral)),
+        basePlan)).analyze
+    comparePlans(result, correctAnswer)
+  }
+
+  test("SPARK-26366: ReplaceExceptWithFilter should not transform 
non-detrministic") {
+    val basePlan = LocalRelation(Seq('a.int, 'b.int))
+    val otherPlan = basePlan.where('a > rand(1L))
+    val except = Except(basePlan, otherPlan, false)
+    val result = Optimize.execute(except.analyze)
+    val condition = basePlan.output.zip(otherPlan.output).map { case (a1, a2) =>
+      a1 <=> a2 }.reduce( _ && _)
+    val correctAnswer = Aggregate(basePlan.output, otherPlan.output,
+      Join(basePlan, otherPlan, LeftAnti, Option(condition))).analyze
+    comparePlans(result, correctAnswer)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 525c7cef39563..c90b15814a534 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1656,6 +1656,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkAnswer(df.groupBy(col("a")).agg(first(col("b"))),
       Seq(Row("0", BigDecimal.valueOf(0.1111)), Row("1", 
BigDecimal.valueOf(1.1111))))
   }
+
+  test("SPARK-26366: return nulls which are not filtered in except") {
+    val inputDF = sqlContext.createDataFrame(
+      sparkContext.parallelize(Seq(Row("0", "a"), Row("1", null))),
+      StructType(Seq(
+        StructField("a", StringType, nullable = true),
+        StructField("b", StringType, nullable = true))))
+
+    val exceptDF = inputDF.filter(col("a").isin("0") or col("b") > "c")
+    checkAnswer(inputDF.except(exceptDF), Seq(Row("1", null)))
+  }
 }
 
 case class TestDataUnion(x: Int, y: Int, z: Int)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 4cc8a45391996..37a8815350a53 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2899,6 +2899,44 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-26366: verify ReplaceExceptWithFilter") {
+    Seq(true, false).foreach { enabled =>
+      withSQLConf(SQLConf.REPLACE_EXCEPT_WITH_FILTER.key -> enabled.toString) {
+        val df = spark.createDataFrame(
+          sparkContext.parallelize(Seq(Row(0, 3, 5),
+            Row(0, 3, null),
+            Row(null, 3, 5),
+            Row(0, null, 5),
+            Row(0, null, null),
+            Row(null, null, 5),
+            Row(null, 3, null),
+            Row(null, null, null))),
+          StructType(Seq(StructField("c1", IntegerType),
+            StructField("c2", IntegerType),
+            StructField("c3", IntegerType))))
+        val where = "c2 >= 3 OR c1 >= 0"
+        val whereNullSafe =
+          """
+            |(c2 IS NOT NULL AND c2 >= 3)
+            |OR (c1 IS NOT NULL AND c1 >= 0)
+          """.stripMargin
+
+        val df_a = df.filter(where)
+        val df_b = df.filter(whereNullSafe)
+        checkAnswer(df.except(df_a), df.except(df_b))
+
+        val whereWithIn = "c2 >= 3 OR c1 in (2)"
+        val whereWithInNullSafe =
+          """
+            |(c2 IS NOT NULL AND c2 >= 3)
+          """.stripMargin
+        val dfIn_a = df.filter(whereWithIn)
+        val dfIn_b = df.filter(whereWithInNullSafe)
+        checkAnswer(df.except(dfIn_a), df.except(dfIn_b))
+      }
+    }
+  }
 }
 
 case class Foo(bar: Option[String])
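
In short, for an eligible Except the rule now takes the combined filter condition of the
right-hand plan, rewrites its attribute references to the left-hand plan's output, wraps it
in Coalesce(condition, false) so that a NULL condition is treated as false, and negates the
result. A simplified sketch of the rewrite using the Catalyst DSL (attribute names here are
illustrative; the attribute-reference rewriting step is omitted):

{code:java}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Coalesce, Literal, Not}
import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, LocalRelation}

// Hypothetical left-hand plan and filter condition, for illustration only.
val attrA = 'a.int
val attrB = 'b.int
val left = LocalRelation(Seq(attrA, attrB))
val cond = attrA >= 2 && attrB < 1

// Except(left, left.where(cond), isAll = false) is rewritten to roughly this plan:
// a row is kept unless cond is definitely true; a NULL condition is coalesced to
// false, so Not(...) evaluates to true and the row is returned, as EXCEPT requires.
val rewritten = Distinct(Filter(Not(Coalesce(Seq(cond, Literal.FalseLiteral))), left))
{code}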


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
)

> Except with transform regression
> --------------------------------
>
>                 Key: SPARK-26366
>                 URL: https://issues.apache.org/jira/browse/SPARK-26366
>             Project: Spark
>          Issue Type: Bug
>          Components: Spark Core, SQL
>    Affects Versions: 2.3.2
>            Reporter: Dan Osipov
>            Assignee: Marco Gaido
>            Priority: Major
>              Labels: correctness
>             Fix For: 2.3.3, 2.4.1, 3.0.0
>
>
> There appears to be a regression between Spark 2.2 and 2.3. Below is the code 
> to reproduce it:
>  
> {code:java}
> import org.apache.spark.sql.functions.col
> import org.apache.spark.sql.Row
> import org.apache.spark.sql.types._
> val inputDF = spark.sqlContext.createDataFrame(
>   spark.sparkContext.parallelize(Seq(
>     Row("0", "john", "smith", "j...@smith.com"),
>     Row("1", "jane", "doe", "j...@doe.com"),
>     Row("2", "apache", "spark", "sp...@apache.org"),
>     Row("3", "foo", "bar", null)
>   )),
>   StructType(List(
>     StructField("id", StringType, nullable=true),
>     StructField("first_name", StringType, nullable=true),
>     StructField("last_name", StringType, nullable=true),
>     StructField("email", StringType, nullable=true)
>   ))
> )
> val exceptDF = inputDF.transform( toProcessDF =>
>   toProcessDF.filter(
>       (
>         col("first_name").isin(Seq("john", "jane"): _*)
>           and col("last_name").isin(Seq("smith", "doe"): _*)
>       )
>       or col("email").isin(List(): _*)
>   )
> )
> inputDF.except(exceptDF).show()
> {code}
> Output with Spark 2.2:
> {noformat}
> +---+----------+---------+----------------+
> | id|first_name|last_name|           email|
> +---+----------+---------+----------------+
> |  2|    apache|    spark|sp...@apache.org|
> |  3|       foo|      bar|            null|
> +---+----------+---------+----------------+
> {noformat}
> Output with Spark 2.3:
> {noformat}
> +---+----------+---------+----------------+
> | id|first_name|last_name|           email|
> +---+----------+---------+----------------+
> |  2|    apache|    spark|sp...@apache.org|
> +---+----------+---------+----------------+
> {noformat}
> Note, changing the last line to 
> {code:java}
> inputDF.except(exceptDF.cache()).show()
> {code}
> produces identical output for both Spark 2.3 and 2.2.
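>
> The root cause is SQL's three-valued logic: for the last row the filter condition evaluates to
> NULL (the name predicates are false and email is NULL), and when the optimizer rewrites the
> except as a negated filter, NOT(NULL) is still NULL, so that row is silently dropped. A minimal
> sketch of the behaviour, reusing inputDF and the imports from the snippet above (the coalesce
> wrapping mirrors the idea behind the fix; it is illustrative only, not the optimizer's exact output):
> {code:java}
> import org.apache.spark.sql.functions.{coalesce, lit}
>
> val cond = (col("first_name").isin("john", "jane") and
>   col("last_name").isin("smith", "doe")) or col("email").isin(List(): _*)
>
> inputDF.filter(cond).count()   // 2 -- the john/jane rows, where cond is TRUE
> inputDF.filter(!cond).count()  // 1 -- cond is NULL for the null-email row, and
>                                //      NOT(NULL) is NULL, so that row is dropped
> inputDF.filter(!coalesce(cond, lit(false))).count()  // 2 -- NULL is treated as FALSE,
>                                                      //      so the null-email row is kept
> {code}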
>  



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org
For additional commands, e-mail: issues-h...@spark.apache.org
