This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5d47aae836f4 [SPARK-46044][PYTHON][TESTS] Improve test coverage of udf.py
5d47aae836f4 is described below
commit 5d47aae836f45f1c28b95ddf23b2ddd8d4cef9ae
Author: Xinrong Meng <[email protected]>
AuthorDate: Wed Dec 6 08:58:34 2023 +0900
[SPARK-46044][PYTHON][TESTS] Improve test coverage of udf.py
### What changes were proposed in this pull request?
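Improve test coverage of udf.py by adding tests for registering a UDF without an explicit return type and for the error paths of UDF registration and `UserDefinedFunction` initialization (non-callable `func`, invalid `returnType`, non-integer `evalType`).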
### Why are the changes needed?
This is a subtask of [SPARK-46041](https://issues.apache.org/jira/browse/SPARK-46041), which tracks improving test coverage.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Test changes only.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43947 from xinrong-meng/test_udf.
Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../pyspark/sql/tests/connect/test_parity_udf.py | 6 +++
python/pyspark/sql/tests/test_udf.py | 53 +++++++++++++++++++++-
2 files changed, 58 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/sql/tests/connect/test_parity_udf.py
b/python/pyspark/sql/tests/connect/test_parity_udf.py
index 1be7d69b8c32..fc9942af6ea5 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udf.py
@@ -86,6 +86,12 @@ class UDFParityTests(BaseUDFTestsMixin,
ReusedConnectTestCase):
def test_udf_registration_returns_udf_on_sql_context(self):
super().test_udf_registration_returns_udf_on_sql_context()
+ def test_err_udf_registration(self):
+ self.check_err_udf_registration()
+
+ def test_err_udf_init(self):
+ self.check_err_udf_init()
+
if __name__ == "__main__":
import unittest
diff --git a/python/pyspark/sql/tests/test_udf.py
b/python/pyspark/sql/tests/test_udf.py
index b070124cb458..33c4001867c4 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -377,12 +377,17 @@ class BaseUDFTestsMixin(object):
def test_udf_registration_returns_udf(self):
df = self.spark.range(10)
add_three = self.spark.udf.register("add_three", lambda x: x + 3,
IntegerType())
-
self.assertListEqual(
df.selectExpr("add_three(id) AS plus_three").collect(),
df.select(add_three("id").alias("plus_three")).collect(),
)
+ add_three_str = self.spark.udf.register("add_three_str", lambda x: x +
3)
+ self.assertListEqual(
+ df.selectExpr("add_three_str(id) AS plus_three").collect(),
+ df.select(add_three_str("id").alias("plus_three")).collect(),
+ )
+
def test_udf_registration_returns_udf_on_sql_context(self):
df = self.spark.range(10)
@@ -425,6 +430,20 @@ class BaseUDFTestsMixin(object):
).first()
self.assertEqual(row.asDict(), Row(name="b", avg=102.0).asDict())
+ def test_err_udf_registration(self):
+ with QuietTest(self.sc):
+ self.check_err_udf_registration()
+
+ def check_err_udf_registration(self):
+ with self.assertRaises(PySparkTypeError) as pe:
+ self.spark.udf.register("f", UserDefinedFunction("x",
StringType()), "int")
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_CALLABLE",
+ message_parameters={"arg_name": "func", "arg_type": "str"},
+ )
+
def test_non_existed_udf(self):
spark = self.spark
self.assertRaisesRegex(
@@ -1027,6 +1046,38 @@ class BaseUDFTestsMixin(object):
self.spark.range(1).select(udf(lambda x:
ctypes.string_at(0))("id")).collect()
+ def test_err_udf_init(self):
+ with QuietTest(self.sc):
+ self.check_err_udf_init()
+
+ def check_err_udf_init(self):
+ with self.assertRaises(PySparkTypeError) as pe:
+ UserDefinedFunction("x", StringType())
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_CALLABLE",
+ message_parameters={"arg_name": "func", "arg_type": "str"},
+ )
+
+ with self.assertRaises(PySparkTypeError) as pe:
+ UserDefinedFunction(lambda x: x, 1)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_DATATYPE_OR_STR",
+ message_parameters={"arg_name": "returnType", "arg_type": "int"},
+ )
+
+ with self.assertRaises(PySparkTypeError) as pe:
+ UserDefinedFunction(lambda x: x, StringType(),
evalType="SQL_BATCHED_UDF")
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_INT",
+ message_parameters={"arg_name": "evalType", "arg_type": "str"},
+ )
+
class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
@classmethod
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]