This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 40f30dd036c [SPARK-43966][SQL][PYTHON] Support non-deterministic table-valued functions
40f30dd036c is described below

commit 40f30dd036c7df949ce11c59a009bd8ebafe1f0d
Author: allisonwang-db <[email protected]>
AuthorDate: Fri Jul 21 12:21:14 2023 +0800

    [SPARK-43966][SQL][PYTHON] Support non-deterministic table-valued functions
    
    ### What changes were proposed in this pull request?
    
    This PR supports non-deterministic table-valued functions. More specifically, it supports running non-deterministic Python UDTFs and built-in table-valued generator functions with non-deterministic input values.
    
    ### Why are the changes needed?
    
    To make table-valued functions more versatile.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Before this PR, Spark throws an exception when running a non-deterministic Python UDTF:
    ```
    select * from random_udtf(1)
    AnalysisException: [INVALID_NON_DETERMINISTIC_EXPRESSIONS] The operator expects a deterministic expression,
    ```
    
    After this PR, it is supported.
    
    ### How was this patch tested?
    
    Existing and new unit tests.
    
    Closes #42075 from allisonwang-db/spark-43966-non-det-udtf.
    
    Authored-by: allisonwang-db <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit 1fb3e16a48d826aed1ca9688a661281f750bbf5a)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 python/pyspark/sql/tests/test_udtf.py                  | 18 ++++++------------
 .../spark/sql/catalyst/analysis/CheckAnalysis.scala    |  1 +
 .../analyzer-results/table-valued-functions.sql.out    |  6 ++++++
 .../sql-tests/inputs/table-valued-functions.sql        |  3 +++
 .../sql-tests/results/table-valued-functions.sql.out   |  8 ++++++++
 5 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index ec3379accca..2c76d2f7e15 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -355,14 +355,12 @@ class BaseUDTFTestsMixin:
 
         class RandomUDTF:
             def eval(self, a: int):
-                yield a * int(random.random() * 100),
+                yield a + int(random.random()),
 
         random_udtf = udtf(RandomUDTF, returnType="x: int").asNondeterministic()
-        # TODO(SPARK-43966): support non-deterministic UDTFs
-        with self.assertRaisesRegex(
-            AnalysisException, "The operator expects a deterministic expression"
-        ):
-            random_udtf(lit(1)).collect()
+        assertDataFrameEqual(random_udtf(lit(1)), [Row(x=1)])
+        self.spark.udtf.register("random_udtf", random_udtf)
+        assertDataFrameEqual(self.spark.sql("select * from random_udtf(1)"), [Row(x=1)])
 
     def test_udtf_with_nondeterministic_input(self):
         from pyspark.sql.functions import rand
@@ -370,13 +368,9 @@ class BaseUDTFTestsMixin:
         @udtf(returnType="x: int")
         class TestUDTF:
             def eval(self, a: int):
-                yield a + 1,
+                yield 1 if a > 100 else 0,
 
-        # TODO(SPARK-43966): support non-deterministic UDTFs
-        with self.assertRaisesRegex(
-            AnalysisException, " The operator expects a deterministic expression"
-        ):
-            TestUDTF(rand(0) * 100).collect()
+        assertDataFrameEqual(TestUDTF(rand(0) * 100), [Row(x=0)])
 
     def test_udtf_with_invalid_return_type(self):
         @udtf(returnType="int")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index d933ea26d5d..e198fd58953 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -752,6 +752,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
             !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] &&
             !o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] &&
             !o.isInstanceOf[Expand] &&
+            !o.isInstanceOf[Generate] &&
             // Lateral join is checked in checkSubqueryExpression.
             !o.isInstanceOf[LateralJoin] =>
             // The rule above is used to check Aggregate operator.
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/table-valued-functions.sql.out
index 49ad4bf19f7..6c29a0ec1db 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/table-valued-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/table-valued-functions.sql.out
@@ -205,6 +205,12 @@ Project [k#x, v#x]
          +- OneRowRelation
 
 
+-- !query
+select * from explode(array(rand(0)))
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
 -- !query
 select * from explode(null)
 -- !query analysis
diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
index 2b809f9a7c8..79d427bc209 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
@@ -43,6 +43,9 @@ select * from explode(map());
 select * from explode(array(1, 2)) t(c1);
 select * from explode(map('a', 1, 'b', 2)) t(k, v);
 
+-- explode with non-deterministic values
+select * from explode(array(rand(0)));
+
 -- explode with erroneous input
 select * from explode(null);
 select * from explode(null) t(c1);
diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
index 578461d164a..1348110a83a 100644
--- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
@@ -242,6 +242,14 @@ a  1
 b      2
 
 
+-- !query
+select * from explode(array(rand(0)))
+-- !query schema
+struct<col:double>
+-- !query output
+0.7604953758285915
+
+
 -- !query
 select * from explode(null)
 -- !query schema


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to