This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4c39d6fa648a [SPARK-50699][PYTHON] Parse and generate DDL string with a specified session
4c39d6fa648a is described below
commit 4c39d6fa648a754d0b6585839e2803bc1e2c8cc1
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Dec 31 09:11:42 2024 +0900
[SPARK-50699][PYTHON] Parse and generate DDL string with a specified session
### What changes were proposed in this pull request?
Parse and generate DDL string with a specified session
### Why are the changes needed?
In `_parse_datatype_string` and `toDDL`, a `SparkSession` or `SparkContext` is always needed.
In most cases, the session is already present, so we can avoid creating or fetching the active session.
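As a rough illustration (not part of the patch itself), a caller that already holds a session can go through the new private helpers instead of the module-level parser that resolves the active session; `spark` below is assumed to be an existing `SparkSession`:

```python
from pyspark.sql.types import StructType, _parse_datatype_string

ddl = "a INT, b STRING"

# Before: the module-level helper looks up the active session/context itself.
schema_old = _parse_datatype_string(ddl)

# After: the session at hand parses/generates DDL directly via the new
# private helpers added in this patch.
schema_new = spark._parse_ddl(ddl)      # DDL string -> DataType
roundtrip = spark._to_ddl(schema_new)   # StructType -> DDL string

assert isinstance(schema_new, StructType)
```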
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49331 from zhengruifeng/py_session_ddl_json.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/core/context.py | 11 +++++++++++
python/pyspark/pandas/frame.py | 6 ++----
python/pyspark/sql/connect/dataframe.py | 4 ++--
python/pyspark/sql/connect/group.py | 10 +++++-----
python/pyspark/sql/connect/session.py | 10 ++++++++++
python/pyspark/sql/pandas/group_ops.py | 10 +++++-----
python/pyspark/sql/session.py | 9 +++++++--
python/pyspark/sql/types.py | 24 ++++--------------------
8 files changed, 46 insertions(+), 38 deletions(-)
diff --git a/python/pyspark/core/context.py b/python/pyspark/core/context.py
index 9ed4699c4b5b..42a368555ae9 100644
--- a/python/pyspark/core/context.py
+++ b/python/pyspark/core/context.py
@@ -75,6 +75,7 @@ from py4j.java_gateway import is_instance_of, JavaGateway, JavaObject, JVMView
if TYPE_CHECKING:
from pyspark.accumulators import AccumulatorParam
+ from pyspark.sql.types import DataType, StructType
__all__ = ["SparkContext"]
@@ -2623,6 +2624,16 @@ class SparkContext:
messageParameters={},
)
+ def _to_ddl(self, struct: "StructType") -> str:
+ assert self._jvm is not None
+ return self._jvm.PythonSQLUtils.jsonToDDL(struct.json())
+
+ def _parse_ddl(self, ddl: str) -> "DataType":
+ from pyspark.sql.types import _parse_datatype_json_string
+
+ assert self._jvm is not None
+ return _parse_datatype_json_string(self._jvm.PythonSQLUtils.ddlToJson(ddl))
+
def _test() -> None:
import doctest
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 35b96543b9eb..86820573344e 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -7292,8 +7292,6 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
4 1 True 1.0
5 2 False 2.0
"""
- from pyspark.sql.types import _parse_datatype_string
-
include_list: List[str]
if not is_list_like(include):
include_list = [cast(str, include)] if include is not None else []
@@ -7320,14 +7318,14 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
include_spark_type = []
for inc in include_list:
try:
- include_spark_type.append(_parse_datatype_string(inc))
+ include_spark_type.append(self._internal.spark_frame._session._parse_ddl(inc))
except BaseException:
pass
exclude_spark_type = []
for exc in exclude_list:
try:
- exclude_spark_type.append(_parse_datatype_string(exc))
+ exclude_spark_type.append(self._internal.spark_frame._session._parse_ddl(exc))
except BaseException:
pass
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 33956c867669..3d8f0eced34b 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -54,7 +54,7 @@ import functools
from pyspark import _NoValue
from pyspark._globals import _NoValueType
from pyspark.util import is_remote_only
-from pyspark.sql.types import Row, StructType, _create_row, _parse_datatype_string
+from pyspark.sql.types import Row, StructType, _create_row
from pyspark.sql.dataframe import (
DataFrame as ParentDataFrame,
DataFrameNaFunctions as ParentDataFrameNaFunctions,
@@ -2037,7 +2037,7 @@ class DataFrame(ParentDataFrame):
_validate_pandas_udf(func, evalType)
if isinstance(schema, str):
- schema = cast(StructType, _parse_datatype_string(schema))
+ schema = cast(StructType, self._session._parse_ddl(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 006af8756e63..11adc8850fec 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -35,7 +35,7 @@ from pyspark.util import PythonEvalType
from pyspark.sql.group import GroupedData as PySparkGroupedData
from pyspark.sql.pandas.group_ops import PandasCogroupedOps as PySparkPandasCogroupedOps
from pyspark.sql.pandas.functions import _validate_pandas_udf  # type: ignore[attr-defined]
-from pyspark.sql.types import NumericType, StructType, _parse_datatype_string
+from pyspark.sql.types import NumericType, StructType
import pyspark.sql.connect.plan as plan
from pyspark.sql.column import Column
@@ -295,7 +295,7 @@ class GroupedData:
_validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
if isinstance(schema, str):
- schema = cast(StructType, _parse_datatype_string(schema))
+ schema = cast(StructType, self._df._session._parse_ddl(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
@@ -369,7 +369,7 @@ class GroupedData:
_validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF)
if isinstance(schema, str):
- schema = cast(StructType, _parse_datatype_string(schema))
+ schema = cast(StructType, self._df._session._parse_ddl(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
@@ -414,7 +414,7 @@ class PandasCogroupedOps:
_validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF)
if isinstance(schema, str):
- schema = cast(StructType, _parse_datatype_string(schema))
+ schema = cast(StructType, self._gd1._df._session._parse_ddl(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
@@ -445,7 +445,7 @@ class PandasCogroupedOps:
_validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF)
if isinstance(schema, str):
- schema = cast(StructType, _parse_datatype_string(schema))
+ schema = cast(StructType, self._gd1._df._session._parse_ddl(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 925eaaeabf60..3f1663d06850 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -1111,6 +1111,16 @@ class SparkSession:
return creator, (self._session_id,)
+ def _to_ddl(self, struct: StructType) -> str:
+ ddl = self._client._analyze(method="json_to_ddl", json_string=struct.json()).ddl_string
+ assert ddl is not None
+ return ddl
+
+ def _parse_ddl(self, ddl: str) -> DataType:
+ dt = self._client._analyze(method="ddl_parse", ddl_string=ddl).parsed
+ assert dt is not None
+ return dt
+
SparkSession.__doc__ = PySparkSession.__doc__
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index bd12b41b3436..343a68bf010b 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -36,7 +36,7 @@ from pyspark.sql.streaming.stateful_processor import (
)
from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
-from pyspark.sql.types import StructType, _parse_datatype_string
+from pyspark.sql.types import StructType
if TYPE_CHECKING:
from pyspark.sql.pandas._typing import (
@@ -348,9 +348,9 @@ class PandasGroupedOpsMixin:
]
if isinstance(outputStructType, str):
- outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
+ outputStructType = cast(StructType, self._df._session._parse_ddl(outputStructType))
if isinstance(stateStructType, str):
- stateStructType = cast(StructType, _parse_datatype_string(stateStructType))
+ stateStructType = cast(StructType, self._df._session._parse_ddl(stateStructType))
udf = pandas_udf(
func, # type: ignore[call-overload]
@@ -502,7 +502,7 @@ class PandasGroupedOpsMixin:
if initialState is not None:
assert isinstance(initialState, GroupedData)
if isinstance(outputStructType, str):
- outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
+ outputStructType = cast(StructType, self._df._session._parse_ddl(outputStructType))
def handle_pre_init(
statefulProcessorApiClient: StatefulProcessorApiClient,
@@ -681,7 +681,7 @@ class PandasGroupedOpsMixin:
return result
if isinstance(outputStructType, str):
- outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
+ outputStructType = cast(StructType, self._df._session._parse_ddl(outputStructType))
df = self._df
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 00fa60442b41..f3a1639fddaf 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -58,7 +58,6 @@ from pyspark.sql.types import (
_has_nulltype,
_merge_type,
_create_converter,
- _parse_datatype_string,
_from_numpy_type,
)
from pyspark.errors.exceptions.captured import install_exception_handler
@@ -1501,7 +1500,7 @@ class SparkSession(SparkConversionMixin):
)
if isinstance(schema, str):
- schema = cast(Union[AtomicType, StructType, str], _parse_datatype_string(schema))
+ schema = cast(Union[AtomicType, StructType, str], self._parse_ddl(schema))
elif isinstance(schema, (list, tuple)):
# Must re-encode any unicode strings to be consistent with StructField names
schema = [x.encode("utf-8") if not isinstance(x, str) else x for x in schema]
@@ -2338,6 +2337,12 @@ class SparkSession(SparkConversionMixin):
"""
self._jsparkSession.clearTags()
+ def _to_ddl(self, struct: StructType) -> str:
+ return self._sc._to_ddl(struct)
+
+ def _parse_ddl(self, ddl: str) -> DataType:
+ return self._sc._parse_ddl(ddl)
+
def _test() -> None:
import os
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 93ac6655b886..f40a8bf62b29 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1563,16 +1563,9 @@ class StructType(DataType):
session = SparkSession.getActiveSession()
assert session is not None
- return session._client._analyze( # type: ignore[return-value]
- method="json_to_ddl", json_string=self.json()
- ).ddl_string
-
+ return session._to_ddl(self)
else:
- from py4j.java_gateway import JVMView
-
- sc = get_active_spark_context()
- assert sc._jvm is not None
- return cast(JVMView, sc._jvm).PythonSQLUtils.jsonToDDL(self.json())
+ return get_active_spark_context()._to_ddl(self)
class VariantType(AtomicType):
@@ -1907,18 +1900,9 @@ def _parse_datatype_string(s: str) -> DataType:
if is_remote():
from pyspark.sql.connect.session import SparkSession
- return cast(
- DataType,
- SparkSession.active()._client._analyze(method="ddl_parse", ddl_string=s).parsed,
- )
-
+ return SparkSession.active()._parse_ddl(s)
else:
- from py4j.java_gateway import JVMView
-
- sc = get_active_spark_context()
- return _parse_datatype_json_string(
- cast(JVMView, sc._jvm).org.apache.spark.sql.api.python.PythonSQLUtils.ddlToJson(s)
- )
+ return get_active_spark_context()._parse_ddl(s)
def _parse_datatype_json_string(json_string: str) -> DataType:
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]