This is an automated email from the ASF dual-hosted git repository.
xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new dbd667e7bc5 [SPARK-42126][PYTHON][CONNECT] Accept return type in DDL strings for Python Scalar UDFs in Spark Connect
dbd667e7bc5 is described below
commit dbd667e7bc5fee443b8a39ca56d4cf3dd1bb2bae
Author: Xinrong Meng <[email protected]>
AuthorDate: Thu Jan 26 19:15:13 2023 +0800
[SPARK-42126][PYTHON][CONNECT] Accept return type in DDL strings for Python Scalar UDFs in Spark Connect
### What changes were proposed in this pull request?
Accept return type in DDL strings for Python Scalar UDFs in Spark Connect.
The approach proposed in this PR is a workaround to parse DataType from DDL
strings. We should think of a more elegant alternative to replace that later.
### Why are the changes needed?
To reach parity with vanilla PySpark.
### Does this PR introduce _any_ user-facing change?
Yes. Return types in DDL strings are accepted now.
### How was this patch tested?
Unit tests.
Closes #39739 from xinrong-meng/datatype_ddl.
Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Xinrong Meng <[email protected]>
---
python/pyspark/sql/connect/udf.py | 20 +++++++++++++++++++-
.../sql/tests/connect/test_connect_function.py | 8 ++++++++
2 files changed, 27 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index 4a465084838..d0eb2fdfe6c 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -28,6 +28,7 @@ from pyspark.sql.connect.expressions import (
)
from pyspark.sql.connect.column import Column
from pyspark.sql.types import DataType, StringType
+from pyspark.sql.utils import is_remote
if TYPE_CHECKING:
@@ -90,7 +91,24 @@ class UserDefinedFunction:
)
self.func = func
-        self._returnType = returnType
+
+        if isinstance(returnType, str):
+            # Currently we don't have a way to have a current Spark session in Spark Connect, and
+            # pyspark.sql.SparkSession has a centralized logic to control the session creation.
+            # So uses pyspark.sql.SparkSession for now. Should replace this to using the current
+            # Spark session for Spark Connect in the future.
+            from pyspark.sql import SparkSession as PySparkSession
+
+            assert is_remote()
+            return_type_schema = (  # a workaround to parse the DataType from DDL strings
+                PySparkSession.builder.getOrCreate()
+                .createDataFrame(data=[], schema=returnType)
+                .schema
+            )
+            assert len(return_type_schema.fields) == 1, "returnType should be singular"
+            self._returnType = return_type_schema.fields[0].dataType
+        else:
+            self._returnType = returnType
         self._name = name or (
             func.__name__ if hasattr(func, "__name__") else func.__class__.__name__
         )
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index 7042a7e8e6f..50fadb49ed4 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -2299,6 +2299,14 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, S
cdf.withColumn("A", CF.udf(lambda x: x + 1)(cdf.a)).toPandas(),
sdf.withColumn("A", SF.udf(lambda x: x + 1)(sdf.a)).toPandas(),
)
+        self.assert_eq(  # returnType as DDL strings
+            cdf.withColumn("C", CF.udf(lambda x: len(x), "int")(cdf.c)).toPandas(),
+            sdf.withColumn("C", SF.udf(lambda x: len(x), "int")(sdf.c)).toPandas(),
+        )
+        self.assert_eq(  # returnType as DataType
+            cdf.withColumn("C", CF.udf(lambda x: len(x), IntegerType())(cdf.c)).toPandas(),
+            sdf.withColumn("C", SF.udf(lambda x: len(x), IntegerType())(sdf.c)).toPandas(),
+        )
# as a decorator
@CF.udf(StringType())
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]