This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new c74f584481d9 [SPARK-48039][PYTHON][CONNECT] Update the error class for `group.apply` c74f584481d9 is described below commit c74f584481d9bcefda7e8ac2a37feb2d61891fe4 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Mon Apr 29 20:06:22 2024 +0900 [SPARK-48039][PYTHON][CONNECT] Update the error class for `group.apply` ### What changes were proposed in this pull request? Update the error class for `group.apply` ### Why are the changes needed? https://github.com/apache/spark/commit/eae91ee3c96b6887581e59821d905b8ea94f6bc0 introduced a dedicated error class `INVALID_UDF_EVAL_TYPE` for `group.apply`, but only used it in Spark Connect. This PR uses the same error class in Spark Classic as well, for consistency. It also enables the parity test `GroupedApplyInPandasTests.test_wrong_args`. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? CI ### Was this patch authored or co-authored using generative AI tooling? no Closes #46277 from zhengruifeng/fix_test_wrong_args. 
Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/sql/pandas/group_ops.py | 10 ++++------ .../tests/connect/test_parity_pandas_grouped_map.py | 4 ---- .../sql/tests/pandas/test_pandas_grouped_map.py | 21 +++++++++++---------- 3 files changed, 15 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index d5b214e2f7d5..3d1c50d94902 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -18,7 +18,7 @@ import sys from typing import List, Union, TYPE_CHECKING, cast import warnings -from pyspark.errors import PySparkValueError +from pyspark.errors import PySparkTypeError from pyspark.util import PythonEvalType from pyspark.sql.column import Column from pyspark.sql.dataframe import DataFrame @@ -100,11 +100,9 @@ class PandasGroupedOpsMixin: != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF ) ): - raise PySparkValueError( - error_class="INVALID_PANDAS_UDF", - message_parameters={ - "detail": "the udf argument must be a pandas_udf of type GROUPED_MAP." 
- }, + raise PySparkTypeError( + error_class="INVALID_UDF_EVAL_TYPE", + message_parameters={"eval_type": "SQL_GROUPED_MAP_PANDAS_UDF"}, ) warnings.warn( diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py b/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py index f0e7eeb606ca..1cc4ce012623 100644 --- a/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py +++ b/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py @@ -30,10 +30,6 @@ class GroupedApplyInPandasTests(GroupedApplyInPandasTestsMixin, ReusedConnectTes def test_wrong_return_type(self): super().test_wrong_return_type() - @unittest.skip("Fails in Spark Connect, should enable.") - def test_wrong_args(self): - super().test_wrong_args() - @unittest.skip("Fails in Spark Connect, should enable.") def test_unsupported_types(self): super().test_unsupported_types() diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py index 0396006e2b36..f43dafc0a4a1 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py @@ -52,7 +52,7 @@ from pyspark.sql.types import ( MapType, YearMonthIntervalType, ) -from pyspark.errors import PythonException, PySparkTypeError +from pyspark.errors import PythonException, PySparkTypeError, PySparkValueError from pyspark.testing.sqlutils import ( ReusedSQLTestCase, have_pandas, @@ -421,22 +421,23 @@ class GroupedApplyInPandasTestsMixin: def check_wrong_args(self): df = self.data - with self.assertRaisesRegex(ValueError, "Invalid function"): + with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"): df.groupby("id").apply(lambda x: x) - with self.assertRaisesRegex(ValueError, "Invalid function"): + with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"): df.groupby("id").apply(udf(lambda x: x, DoubleType())) - with 
self.assertRaisesRegex(ValueError, "Invalid function"): + with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"): df.groupby("id").apply(sum(df.v)) - with self.assertRaisesRegex(ValueError, "Invalid function"): + with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"): df.groupby("id").apply(df.v + 1) - with self.assertRaisesRegex(ValueError, "Invalid function"): + with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"): + df.groupby("id").apply(pandas_udf(lambda x, y: x, DoubleType())) + with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"): + df.groupby("id").apply(pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR)) + + with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"): df.groupby("id").apply( pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())])) ) - with self.assertRaisesRegex(ValueError, "Invalid function"): - df.groupby("id").apply(pandas_udf(lambda x, y: x, DoubleType())) - with self.assertRaisesRegex(ValueError, "Invalid function.*GROUPED_MAP"): - df.groupby("id").apply(pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR)) def test_unsupported_types(self): with self.quiet(): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org