This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 79e8df84309 [SPARK-42126][PYTHON][CONNECT] Accept return type in DDL 
strings for Python Scalar UDFs in Spark Connect
79e8df84309 is described below

commit 79e8df84309ed54d0c3fc7face414e6c440daa81
Author: Xinrong Meng <xinr...@apache.org>
AuthorDate: Thu Jan 26 19:15:13 2023 +0800

    [SPARK-42126][PYTHON][CONNECT] Accept return type in DDL strings for Python 
Scalar UDFs in Spark Connect
    
    ### What changes were proposed in this pull request?
    Accept return types specified as DDL strings for Python scalar UDFs in Spark Connect.
    
    The approach proposed in this PR is a workaround to parse a DataType from DDL strings. We should find a more elegant alternative to replace it later.
    
    ### Why are the changes needed?
    To reach parity with vanilla PySpark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. Return types specified as DDL strings are now accepted.
    
    ### How was this patch tested?
    Unit tests.
    
    Closes #39739 from xinrong-meng/datatype_ddl.
    
    Authored-by: Xinrong Meng <xinr...@apache.org>
    Signed-off-by: Xinrong Meng <xinr...@apache.org>
    (cherry picked from commit dbd667e7bc5fee443b8a39ca56d4cf3dd1bb2bae)
    Signed-off-by: Xinrong Meng <xinr...@apache.org>
---
 python/pyspark/sql/connect/udf.py                    | 20 +++++++++++++++++++-
 .../sql/tests/connect/test_connect_function.py       |  8 ++++++++
 2 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/connect/udf.py 
b/python/pyspark/sql/connect/udf.py
index 4a465084838..d0eb2fdfe6c 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -28,6 +28,7 @@ from pyspark.sql.connect.expressions import (
 )
 from pyspark.sql.connect.column import Column
 from pyspark.sql.types import DataType, StringType
+from pyspark.sql.utils import is_remote
 
 
 if TYPE_CHECKING:
@@ -90,7 +91,24 @@ class UserDefinedFunction:
             )
 
         self.func = func
-        self._returnType = returnType
+
+        if isinstance(returnType, str):
+            # Currently we don't have a way to have a current Spark session in 
Spark Connect, and
+            # pyspark.sql.SparkSession has a centralized logic to control the 
session creation.
+            # So uses pyspark.sql.SparkSession for now. Should replace this to 
using the current
+            # Spark session for Spark Connect in the future.
+            from pyspark.sql import SparkSession as PySparkSession
+
+            assert is_remote()
+            return_type_schema = (  # a workaround to parse the DataType from 
DDL strings
+                PySparkSession.builder.getOrCreate()
+                .createDataFrame(data=[], schema=returnType)
+                .schema
+            )
+            assert len(return_type_schema.fields) == 1, "returnType should be 
singular"
+            self._returnType = return_type_schema.fields[0].dataType
+        else:
+            self._returnType = returnType
         self._name = name or (
             func.__name__ if hasattr(func, "__name__") else 
func.__class__.__name__
         )
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py 
b/python/pyspark/sql/tests/connect/test_connect_function.py
index 7042a7e8e6f..50fadb49ed4 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -2299,6 +2299,14 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, 
PandasOnSparkTestUtils, S
             cdf.withColumn("A", CF.udf(lambda x: x + 1)(cdf.a)).toPandas(),
             sdf.withColumn("A", SF.udf(lambda x: x + 1)(sdf.a)).toPandas(),
         )
+        self.assert_eq(  # returnType as DDL strings
+            cdf.withColumn("C", CF.udf(lambda x: len(x), 
"int")(cdf.c)).toPandas(),
+            sdf.withColumn("C", SF.udf(lambda x: len(x), 
"int")(sdf.c)).toPandas(),
+        )
+        self.assert_eq(  # returnType as DataType
+            cdf.withColumn("C", CF.udf(lambda x: len(x), 
IntegerType())(cdf.c)).toPandas(),
+            sdf.withColumn("C", SF.udf(lambda x: len(x), 
IntegerType())(sdf.c)).toPandas(),
+        )
 
         # as a decorator
         @CF.udf(StringType())


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to