This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f72ff1b096ed [SPARK-50681][PYTHON][CONNECT] Cache the parsed schema for MapInXXX and ApplyInXXX
f72ff1b096ed is described below

commit f72ff1b096ed07b717d203847d67f33208a46e4c
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Dec 27 13:09:43 2024 +0800

    [SPARK-50681][PYTHON][CONNECT] Cache the parsed schema for MapInXXX and ApplyInXXX
    
    ### What changes were proposed in this pull request?
    Cache the parsed schema for MapInXXX and ApplyInXXX when the schema is specified as a string.
    
    ### Why are the changes needed?
    The specified schema for MapInXXX and ApplyInXXX is already cached when the
    schema is a `StructType`.
    
    When the schema is a `str`, it is parsed in `UserDefinedFunction.returnType`,
    but the parsed result is not cached.
    This change moves the parsing ahead of the `UserDefinedFunction` construction,
    so the parsed `StructType` can be cached in MapInXXX and ApplyInXXX.
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    Added tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49305 from zhengruifeng/py_cache_ddl.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/connect/dataframe.py                      |  4 +++-
 python/pyspark/sql/connect/group.py                          | 11 +++++++++--
 .../sql/tests/connect/test_connect_dataframe_property.py     | 12 ++++++++++++
 python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py |  2 +-
 4 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py 
b/python/pyspark/sql/connect/dataframe.py
index 185ddc88cd08..33956c867669 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -54,7 +54,7 @@ import functools
 from pyspark import _NoValue
 from pyspark._globals import _NoValueType
 from pyspark.util import is_remote_only
-from pyspark.sql.types import Row, StructType, _create_row
+from pyspark.sql.types import Row, StructType, _create_row, 
_parse_datatype_string
 from pyspark.sql.dataframe import (
     DataFrame as ParentDataFrame,
     DataFrameNaFunctions as ParentDataFrameNaFunctions,
@@ -2036,6 +2036,8 @@ class DataFrame(ParentDataFrame):
         from pyspark.sql.connect.udf import UserDefinedFunction
 
         _validate_pandas_udf(func, evalType)
+        if isinstance(schema, str):
+            schema = cast(StructType, _parse_datatype_string(schema))
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
diff --git a/python/pyspark/sql/connect/group.py 
b/python/pyspark/sql/connect/group.py
index 863461da10ec..006af8756e63 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -35,8 +35,7 @@ from pyspark.util import PythonEvalType
 from pyspark.sql.group import GroupedData as PySparkGroupedData
 from pyspark.sql.pandas.group_ops import PandasCogroupedOps as 
PySparkPandasCogroupedOps
 from pyspark.sql.pandas.functions import _validate_pandas_udf  # type: 
ignore[attr-defined]
-from pyspark.sql.types import NumericType
-from pyspark.sql.types import StructType
+from pyspark.sql.types import NumericType, StructType, _parse_datatype_string
 
 import pyspark.sql.connect.plan as plan
 from pyspark.sql.column import Column
@@ -295,6 +294,8 @@ class GroupedData:
         from pyspark.sql.connect.dataframe import DataFrame
 
         _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
+        if isinstance(schema, str):
+            schema = cast(StructType, _parse_datatype_string(schema))
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
@@ -367,6 +368,8 @@ class GroupedData:
         from pyspark.sql.connect.dataframe import DataFrame
 
         _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF)
+        if isinstance(schema, str):
+            schema = cast(StructType, _parse_datatype_string(schema))
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
@@ -410,6 +413,8 @@ class PandasCogroupedOps:
         from pyspark.sql.connect.dataframe import DataFrame
 
         _validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF)
+        if isinstance(schema, str):
+            schema = cast(StructType, _parse_datatype_string(schema))
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
@@ -439,6 +444,8 @@ class PandasCogroupedOps:
         from pyspark.sql.connect.dataframe import DataFrame
 
         _validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF)
+        if isinstance(schema, str):
+            schema = cast(StructType, _parse_datatype_string(schema))
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
diff --git 
a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py 
b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
index 1a8c7190e31a..c4c10c963a48 100644
--- a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
+++ b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
@@ -110,6 +110,12 @@ class 
SparkConnectDataFramePropertyTests(SparkConnectSQLTestCase):
             cdf1 = cdf.mapInPandas(func, schema)
             self.assertEqual(cdf1._cached_schema, schema)
 
+        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
+            self.assertTrue(is_remote())
+            cdf1 = cdf.mapInPandas(func, "a int, b string")
+            # Properly cache the parsed schema
+            self.assertEqual(cdf1._cached_schema, schema)
+
         with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
             # 'mapInPandas' depends on the method 'pandas_udf', which is 
dispatched
             # based on 'is_remote'. However, in SparkConnectSQLTestCase, the 
remote
@@ -180,6 +186,12 @@ class 
SparkConnectDataFramePropertyTests(SparkConnectSQLTestCase):
             cdf1 = cdf.groupby("id").applyInPandas(normalize, schema)
             self.assertEqual(cdf1._cached_schema, schema)
 
+        with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
+            self.assertTrue(is_remote())
+            cdf1 = cdf.groupby("id").applyInPandas(normalize, "id long, v 
double")
+            # Properly cache the parsed schema
+            self.assertEqual(cdf1._cached_schema, schema)
+
         with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
             self.assertFalse(is_remote())
             sdf1 = sdf.groupby("id").applyInPandas(normalize, schema)
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py 
b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index f85a7b03edda..1f9532352679 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -154,7 +154,7 @@ class CogroupedApplyInPandasTestsMixin:
         ):
             (left.groupby("id", 
"k").cogroup(right.groupby("id"))).applyInPandas(
                 merge_pandas, "id long, k int, v int"
-            ).schema
+            ).count()
 
     def test_apply_in_pandas_not_returning_pandas_dataframe(self):
         with self.quiet():


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to