HyukjinKwon commented on a change in pull request #25204: 
[SPARK-28441][SQL][Python] Fix error when non-foldable expression is used in 
correlated scalar subquery
URL: https://github.com/apache/spark/pull/25204#discussion_r307951390
 
 

 ##########
 File path: sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
 ##########
 @@ -1384,4 +1384,231 @@ class SubquerySuite extends QueryTest with 
SharedSQLContext {
     assert(subqueryExecs.forall(_.name.startsWith("scalar-subquery#")),
           "SubqueryExec name should start with scalar-subquery#")
   }
+
+  test("SPARK-28441: COUNT bug in WHERE clause (Filter) with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    // Case 1: Canonical example of the COUNT bug
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT udf(count(*)) FROM r WHERE l.a = 
r.c) < l.a"),
+      Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil)
+    // Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently 
uses
+    // a rewrite that is vulnerable to the COUNT bug
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT udf(count(*)) FROM r WHERE l.a = 
r.c) = 0"),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
+    // Case 3: COUNT bug without a COUNT aggregate
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT udf(sum(r.d)) is null FROM r WHERE 
l.a = r.c)"),
+      Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug in SELECT clause (Project) with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    checkAnswer(
+      sql("SELECT a, (SELECT udf(count(*)) FROM r WHERE l.a = r.c) AS cnt FROM 
l"),
+      Row(1, 0) :: Row(1, 0) :: Row(2, 2) :: Row(2, 2) :: Row(3, 1) :: 
Row(null, 0)
+        :: Row(null, 0) :: Row(6, 1) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug in HAVING clause (Filter) with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    checkAnswer(
+      sql("""
+            |SELECT
+            |  l.a AS grp_a
+            |FROM l GROUP BY l.a
+            |HAVING
+            |  (
+            |    SELECT udf(count(*)) FROM r WHERE grp_a = r.c
+            |  ) = 0
+            |ORDER BY grp_a""".stripMargin),
+      Row(null) :: Row(1) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug in Aggregate with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    checkAnswer(
+      sql("""
+            |SELECT
+            |  l.a AS aval,
+            |  sum(
+            |    (
+            |      SELECT udf(count(*)) FROM r WHERE l.a = r.c
+            |    )
+            |  ) AS cnt
+            |FROM l GROUP BY l.a ORDER BY aval""".stripMargin),
+      Row(null, 0) :: Row(1, 0) :: Row(2, 4) :: Row(3, 1) :: Row(6, 1)  :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug negative examples with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    // Case 1: Potential COUNT bug case that was working correctly prior to 
the fix
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT udf(sum(r.d)) FROM r WHERE l.a = 
r.c) is null"),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Row(6) :: Nil)
+    // Case 2: COUNT aggregate but no COUNT bug due to > 0 test.
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT udf(count(*)) FROM r WHERE l.a = 
r.c) > 0"),
+      Row(2) :: Row(2) :: Row(3) :: Row(6) :: Nil)
+    // Case 3: COUNT inside aggregate expression but no COUNT bug.
+    checkAnswer(
+      sql("""
+            |SELECT
+            |  l.a
+            |FROM l
+            |WHERE
+            |  (
+            |    SELECT udf(count(*)) + udf(sum(r.d))
+            |    FROM r WHERE l.a = r.c
+            |  ) = 0""".stripMargin),
+      Nil)
+  }
+
+  test("SPARK-28441: COUNT bug in nested subquery with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    checkAnswer(
+      sql("""
+            |SELECT l.a FROM l
+            |WHERE (
+            |    SELECT cntPlusOne + 1 AS cntPlusTwo FROM (
+            |        SELECT cnt + 1 AS cntPlusOne FROM (
+            |            SELECT udf(sum(r.c)) s, udf(count(*)) cnt FROM r 
WHERE l.a = r.c
+            |                   HAVING cnt = 0
+            |        )
+            |    )
+            |) = 2""".stripMargin),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug with nasty predicate expr with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    checkAnswer(
+      sql("""
+            |SELECT
+            |  l.a
+            |FROM l WHERE
+            |  (
+            |    SELECT CASE WHEN udf(count(*)) = 1 THEN null ELSE 
udf(count(*)) END AS cnt
+            |    FROM r WHERE l.a = r.c
+            |  ) = 0""".stripMargin),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug with attribute ref in subquery input and output 
with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
 
 Review comment:
   BTW, we should add `assume(shouldTestPythonUDFs)`. Maybe it's not a big deal 
in general, but it can matter in other vendors' test environments. For instance, 
if somebody runs the tests in a minimal Docker image, the tests might suddenly 
start failing.
   
   This kind of conditional skipping isn't completely new in our test code base. 
See `TestUtils.testCommandAvailable` for instance.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to