This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f72ff1b096ed [SPARK-50681][PYTHON][CONNECT] Cache the parsed schema for MapInXXX and ApplyInXXX
f72ff1b096ed is described below
commit f72ff1b096ed07b717d203847d67f33208a46e4c
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Dec 27 13:09:43 2024 +0800
[SPARK-50681][PYTHON][CONNECT] Cache the parsed schema for MapInXXX and ApplyInXXX
### What changes were proposed in this pull request?
Cache the parsed schema for MapInXXX and ApplyInXXX
### Why are the changes needed?
The schema specified for MapInXXX and ApplyInXXX is already cached when it is a `StructType`.
A `str` schema, however, is only parsed in `UserDefinedFunction.returnType` and is not cached.
We can move this parsing ahead so that the parsed schema is also cached in MapInXXX and ApplyInXXX.
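At a high level, the fix parses a DDL string into a `StructType` before the `UserDefinedFunction` is constructed, so the parsed schema is what gets cached. A minimal sketch of the pattern, assuming an active Spark session for parsing (`_resolve_schema` is just an illustrative helper name, not part of the actual change):

```python
from typing import Union, cast

from pyspark.sql.types import StructType, _parse_datatype_string


def _resolve_schema(schema: Union[StructType, str]) -> StructType:
    # Parse a DDL-formatted string such as "a int, b string" once, up front,
    # so callers can cache the resulting StructType instead of re-parsing
    # the string later in UserDefinedFunction.returnType.
    if isinstance(schema, str):
        schema = cast(StructType, _parse_datatype_string(schema))
    return schema
```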
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
added test
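For context, the behavior the added test pins down can also be seen from user code, along these lines (illustrative only; assumes `spark` is a Spark Connect session and peeks at the private `_cached_schema` attribute that the test uses):

```python
import pandas as pd

from pyspark.sql.types import StructType

# Assumes `spark` was created with SparkSession.builder.remote(...).getOrCreate()
df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))


def normalize(pdf: pd.DataFrame) -> pd.DataFrame:
    v = pdf.v
    return pdf.assign(v=(v - v.mean()) / v.std())


res = df.groupby("id").applyInPandas(normalize, "id long, v double")
# The DDL string is now parsed eagerly, so the parsed schema is cached:
assert isinstance(res._cached_schema, StructType)
```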
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #49305 from zhengruifeng/py_cache_ddl.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/connect/dataframe.py | 4 +++-
python/pyspark/sql/connect/group.py | 11 +++++++++--
.../sql/tests/connect/test_connect_dataframe_property.py | 12 ++++++++++++
python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py | 2 +-
4 files changed, 25 insertions(+), 4 deletions(-)
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 185ddc88cd08..33956c867669 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -54,7 +54,7 @@ import functools
from pyspark import _NoValue
from pyspark._globals import _NoValueType
from pyspark.util import is_remote_only
-from pyspark.sql.types import Row, StructType, _create_row
+from pyspark.sql.types import Row, StructType, _create_row, _parse_datatype_string
from pyspark.sql.dataframe import (
DataFrame as ParentDataFrame,
DataFrameNaFunctions as ParentDataFrameNaFunctions,
@@ -2036,6 +2036,8 @@ class DataFrame(ParentDataFrame):
from pyspark.sql.connect.udf import UserDefinedFunction
_validate_pandas_udf(func, evalType)
+ if isinstance(schema, str):
+ schema = cast(StructType, _parse_datatype_string(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 863461da10ec..006af8756e63 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -35,8 +35,7 @@ from pyspark.util import PythonEvalType
from pyspark.sql.group import GroupedData as PySparkGroupedData
from pyspark.sql.pandas.group_ops import PandasCogroupedOps as PySparkPandasCogroupedOps
from pyspark.sql.pandas.functions import _validate_pandas_udf  # type: ignore[attr-defined]
-from pyspark.sql.types import NumericType
-from pyspark.sql.types import StructType
+from pyspark.sql.types import NumericType, StructType, _parse_datatype_string
import pyspark.sql.connect.plan as plan
from pyspark.sql.column import Column
@@ -295,6 +294,8 @@ class GroupedData:
from pyspark.sql.connect.dataframe import DataFrame
_validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
+ if isinstance(schema, str):
+ schema = cast(StructType, _parse_datatype_string(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
@@ -367,6 +368,8 @@ class GroupedData:
from pyspark.sql.connect.dataframe import DataFrame
_validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF)
+ if isinstance(schema, str):
+ schema = cast(StructType, _parse_datatype_string(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
@@ -410,6 +413,8 @@ class PandasCogroupedOps:
from pyspark.sql.connect.dataframe import DataFrame
_validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF)
+ if isinstance(schema, str):
+ schema = cast(StructType, _parse_datatype_string(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
@@ -439,6 +444,8 @@ class PandasCogroupedOps:
from pyspark.sql.connect.dataframe import DataFrame
_validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF)
+ if isinstance(schema, str):
+ schema = cast(StructType, _parse_datatype_string(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
diff --git a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
index 1a8c7190e31a..c4c10c963a48 100644
--- a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
+++ b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
@@ -110,6 +110,12 @@ class SparkConnectDataFramePropertyTests(SparkConnectSQLTestCase):
cdf1 = cdf.mapInPandas(func, schema)
self.assertEqual(cdf1._cached_schema, schema)
+ with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
+ self.assertTrue(is_remote())
+ cdf1 = cdf.mapInPandas(func, "a int, b string")
+ # Properly cache the parsed schema
+ self.assertEqual(cdf1._cached_schema, schema)
+
with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
# 'mapInPandas' depends on the method 'pandas_udf', which is dispatched
# based on 'is_remote'. However, in SparkConnectSQLTestCase, the remote
@@ -180,6 +186,12 @@ class SparkConnectDataFramePropertyTests(SparkConnectSQLTestCase):
cdf1 = cdf.groupby("id").applyInPandas(normalize, schema)
self.assertEqual(cdf1._cached_schema, schema)
+ with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
+ self.assertTrue(is_remote())
+ cdf1 = cdf.groupby("id").applyInPandas(normalize, "id long, v double")
+ # Properly cache the parsed schema
+ self.assertEqual(cdf1._cached_schema, schema)
+
with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
self.assertFalse(is_remote())
sdf1 = sdf.groupby("id").applyInPandas(normalize, schema)
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index f85a7b03edda..1f9532352679 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -154,7 +154,7 @@ class CogroupedApplyInPandasTestsMixin:
):
(left.groupby("id", "k").cogroup(right.groupby("id"))).applyInPandas(
merge_pandas, "id long, k int, v int"
- ).schema
+ ).count()
def test_apply_in_pandas_not_returning_pandas_dataframe(self):
with self.quiet():