This is an automated email from the ASF dual-hosted git repository. xinrong pushed a commit to branch branch-3.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push: new 79e8df84309 [SPARK-42126][PYTHON][CONNECT] Accept return type in DDL strings for Python Scalar UDFs in Spark Connect 79e8df84309 is described below commit 79e8df84309ed54d0c3fc7face414e6c440daa81 Author: Xinrong Meng <xinr...@apache.org> AuthorDate: Thu Jan 26 19:15:13 2023 +0800 [SPARK-42126][PYTHON][CONNECT] Accept return type in DDL strings for Python Scalar UDFs in Spark Connect ### What changes were proposed in this pull request? Accept return type in DDL strings for Python Scalar UDFs in Spark Connect. The approach proposed in this PR is a workaround to parse DataType from DDL strings. We should think of a more elegant alternative to replace that later. ### Why are the changes needed? To reach parity with vanilla PySpark. ### Does this PR introduce _any_ user-facing change? Yes. Return types in DDL strings are accepted now. ### How was this patch tested? Unit tests. Closes #39739 from xinrong-meng/datatype_ddl. 
Authored-by: Xinrong Meng <xinr...@apache.org> Signed-off-by: Xinrong Meng <xinr...@apache.org> (cherry picked from commit dbd667e7bc5fee443b8a39ca56d4cf3dd1bb2bae) Signed-off-by: Xinrong Meng <xinr...@apache.org> --- python/pyspark/sql/connect/udf.py | 20 +++++++++++++++++++- .../sql/tests/connect/test_connect_function.py | 8 ++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index 4a465084838..d0eb2fdfe6c 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -28,6 +28,7 @@ from pyspark.sql.connect.expressions import ( ) from pyspark.sql.connect.column import Column from pyspark.sql.types import DataType, StringType +from pyspark.sql.utils import is_remote if TYPE_CHECKING: @@ -90,7 +91,24 @@ class UserDefinedFunction: ) self.func = func - self._returnType = returnType + + if isinstance(returnType, str): + # Currently we don't have a way to have a current Spark session in Spark Connect, and + # pyspark.sql.SparkSession has a centralized logic to control the session creation. + # So uses pyspark.sql.SparkSession for now. Should replace this to using the current + # Spark session for Spark Connect in the future. 
+ from pyspark.sql import SparkSession as PySparkSession + + assert is_remote() + return_type_schema = ( # a workaround to parse the DataType from DDL strings + PySparkSession.builder.getOrCreate() + .createDataFrame(data=[], schema=returnType) + .schema + ) + assert len(return_type_schema.fields) == 1, "returnType should be singular" + self._returnType = return_type_schema.fields[0].dataType + else: + self._returnType = returnType self._name = name or ( func.__name__ if hasattr(func, "__name__") else func.__class__.__name__ ) diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py index 7042a7e8e6f..50fadb49ed4 100644 --- a/python/pyspark/sql/tests/connect/test_connect_function.py +++ b/python/pyspark/sql/tests/connect/test_connect_function.py @@ -2299,6 +2299,14 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, S cdf.withColumn("A", CF.udf(lambda x: x + 1)(cdf.a)).toPandas(), sdf.withColumn("A", SF.udf(lambda x: x + 1)(sdf.a)).toPandas(), ) + self.assert_eq( # returnType as DDL strings + cdf.withColumn("C", CF.udf(lambda x: len(x), "int")(cdf.c)).toPandas(), + sdf.withColumn("C", SF.udf(lambda x: len(x), "int")(sdf.c)).toPandas(), + ) + self.assert_eq( # returnType as DataType + cdf.withColumn("C", CF.udf(lambda x: len(x), IntegerType())(cdf.c)).toPandas(), + sdf.withColumn("C", SF.udf(lambda x: len(x), IntegerType())(sdf.c)).toPandas(), + ) # as a decorator @CF.udf(StringType()) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org