This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5d47aae836f4 [SPARK-46044][PYTHON][TESTS] Improve test coverage of udf.py
5d47aae836f4 is described below

commit 5d47aae836f45f1c28b95ddf23b2ddd8d4cef9ae
Author: Xinrong Meng <[email protected]>
AuthorDate: Wed Dec 6 08:58:34 2023 +0900

    [SPARK-46044][PYTHON][TESTS] Improve test coverage of udf.py
    
    ### What changes were proposed in this pull request?
    Improve test coverage of udf.py
    
    ### Why are the changes needed?
    This is a subtask of [SPARK-46041](https://issues.apache.org/jira/browse/SPARK-46041), which aims to improve test coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Test changes only.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #43947 from xinrong-meng/test_udf.
    
    Authored-by: Xinrong Meng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../pyspark/sql/tests/connect/test_parity_udf.py   |  6 +++
 python/pyspark/sql/tests/test_udf.py               | 53 +++++++++++++++++++++-
 2 files changed, 58 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests/connect/test_parity_udf.py b/python/pyspark/sql/tests/connect/test_parity_udf.py
index 1be7d69b8c32..fc9942af6ea5 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udf.py
@@ -86,6 +86,12 @@ class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase):
     def test_udf_registration_returns_udf_on_sql_context(self):
         super().test_udf_registration_returns_udf_on_sql_context()
 
+    def test_err_udf_registration(self):
+        self.check_err_udf_registration()
+
+    def test_err_udf_init(self):
+        self.check_err_udf_init()
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index b070124cb458..33c4001867c4 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -377,12 +377,17 @@ class BaseUDFTestsMixin(object):
     def test_udf_registration_returns_udf(self):
         df = self.spark.range(10)
         add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())
-
         self.assertListEqual(
             df.selectExpr("add_three(id) AS plus_three").collect(),
             df.select(add_three("id").alias("plus_three")).collect(),
         )
 
+        add_three_str = self.spark.udf.register("add_three_str", lambda x: x + 3)
+        self.assertListEqual(
+            df.selectExpr("add_three_str(id) AS plus_three").collect(),
+            df.select(add_three_str("id").alias("plus_three")).collect(),
+        )
+
     def test_udf_registration_returns_udf_on_sql_context(self):
         df = self.spark.range(10)
 
@@ -425,6 +430,20 @@ class BaseUDFTestsMixin(object):
         ).first()
         self.assertEqual(row.asDict(), Row(name="b", avg=102.0).asDict())
 
+    def test_err_udf_registration(self):
+        with QuietTest(self.sc):
+            self.check_err_udf_registration()
+
+    def check_err_udf_registration(self):
+        with self.assertRaises(PySparkTypeError) as pe:
+            self.spark.udf.register("f", UserDefinedFunction("x", StringType()), "int")
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_CALLABLE",
+            message_parameters={"arg_name": "func", "arg_type": "str"},
+        )
+
     def test_non_existed_udf(self):
         spark = self.spark
         self.assertRaisesRegex(
@@ -1027,6 +1046,38 @@ class BaseUDFTestsMixin(object):
 
                 self.spark.range(1).select(udf(lambda x: ctypes.string_at(0))("id")).collect()
 
+    def test_err_udf_init(self):
+        with QuietTest(self.sc):
+            self.check_err_udf_init()
+
+    def check_err_udf_init(self):
+        with self.assertRaises(PySparkTypeError) as pe:
+            UserDefinedFunction("x", StringType())
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_CALLABLE",
+            message_parameters={"arg_name": "func", "arg_type": "str"},
+        )
+
+        with self.assertRaises(PySparkTypeError) as pe:
+            UserDefinedFunction(lambda x: x, 1)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_DATATYPE_OR_STR",
+            message_parameters={"arg_name": "returnType", "arg_type": "int"},
+        )
+
+        with self.assertRaises(PySparkTypeError) as pe:
+            UserDefinedFunction(lambda x: x, StringType(), evalType="SQL_BATCHED_UDF")
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_INT",
+            message_parameters={"arg_name": "evalType", "arg_type": "str"},
+        )
+
 
 class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
     @classmethod


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to