viirya commented on a change in pull request #25204: [SPARK-28441][SQL][Python] 
Fix error when non-foldable expression is used in correlated scalar subquery
URL: https://github.com/apache/spark/pull/25204#discussion_r307333116
 
 

 ##########
 File path: sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
 ##########
 @@ -1384,4 +1384,222 @@ class SubquerySuite extends QueryTest with 
SharedSQLContext {
     assert(subqueryExecs.forall(_.name.startsWith("scalar-subquery#")),
           "SubqueryExec name should start with scalar-subquery#")
   }
+
+  test("SPARK-28441: COUNT bug in WHERE clause (Filter) with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    // Case 1: Canonical example of the COUNT bug
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT udf(count(*)) FROM r WHERE l.a = 
r.c) < l.a"),
+      Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil)
+    // Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently 
uses
+    // a rewrite that is vulnerable to the COUNT bug
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT udf(count(*)) FROM r WHERE l.a = 
r.c) = 0"),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
+    // Case 3: COUNT bug without a COUNT aggregate
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT udf(sum(r.d)) is null FROM r WHERE 
l.a = r.c)"),
+      Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug in SELECT clause (Project) with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    checkAnswer(
+      sql("SELECT a, (SELECT udf(count(*)) FROM r WHERE l.a = r.c) AS cnt FROM 
l"),
+      Row(1, 0) :: Row(1, 0) :: Row(2, 2) :: Row(2, 2) :: Row(3, 1) :: 
Row(null, 0)
+        :: Row(null, 0) :: Row(6, 1) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug in HAVING clause (Filter) with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    checkAnswer(
+      sql("""SELECT
 
 Review comment:
   fixed.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org

Reply via email to