This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4c39d6fa648a [SPARK-50699][PYTHON] Parse and generate DDL string with a specified session
4c39d6fa648a is described below

commit 4c39d6fa648a754d0b6585839e2803bc1e2c8cc1
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Dec 31 09:11:42 2024 +0900

    [SPARK-50699][PYTHON] Parse and generate DDL string with a specified session
    
    ### What changes were proposed in this pull request?
    Parse and generate DDL string with a specified session
    
    ### Why are the changes needed?
    In `_parse_datatype_string` and `toDDL`, a `SparkSession` or `SparkContext` is always needed.
    In most cases, the session is already present, so we can avoid creating or fetching the active session.
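    
    For example (an illustrative sketch, not part of the patch; it only relies on the `_parse_ddl`/`_to_ddl` helpers introduced in this change):
    
        from pyspark.sql import SparkSession
    
        spark = SparkSession.builder.getOrCreate()
        struct = spark._parse_ddl("a INT, b STRING")  # parse a DDL string with this session
        ddl = spark._to_ddl(struct)                   # generate a DDL string from the parsed struct
    
    instead of going through `pyspark.sql.types._parse_datatype_string`, which has to look up the active session or context internally.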
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #49331 from zhengruifeng/py_session_ddl_json.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/core/context.py          | 11 +++++++++++
 python/pyspark/pandas/frame.py          |  6 ++----
 python/pyspark/sql/connect/dataframe.py |  4 ++--
 python/pyspark/sql/connect/group.py     | 10 +++++-----
 python/pyspark/sql/connect/session.py   | 10 ++++++++++
 python/pyspark/sql/pandas/group_ops.py  | 10 +++++-----
 python/pyspark/sql/session.py           |  9 +++++++--
 python/pyspark/sql/types.py             | 24 ++++--------------------
 8 files changed, 46 insertions(+), 38 deletions(-)

diff --git a/python/pyspark/core/context.py b/python/pyspark/core/context.py
index 9ed4699c4b5b..42a368555ae9 100644
--- a/python/pyspark/core/context.py
+++ b/python/pyspark/core/context.py
@@ -75,6 +75,7 @@ from py4j.java_gateway import is_instance_of, JavaGateway, JavaObject, JVMView
 
 if TYPE_CHECKING:
     from pyspark.accumulators import AccumulatorParam
+    from pyspark.sql.types import DataType, StructType
 
 __all__ = ["SparkContext"]
 
@@ -2623,6 +2624,16 @@ class SparkContext:
                 messageParameters={},
             )
 
+    def _to_ddl(self, struct: "StructType") -> str:
+        assert self._jvm is not None
+        return self._jvm.PythonSQLUtils.jsonToDDL(struct.json())
+
+    def _parse_ddl(self, ddl: str) -> "DataType":
+        from pyspark.sql.types import _parse_datatype_json_string
+
+        assert self._jvm is not None
+        return _parse_datatype_json_string(self._jvm.PythonSQLUtils.ddlToJson(ddl))
+
 
 def _test() -> None:
     import doctest
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 35b96543b9eb..86820573344e 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -7292,8 +7292,6 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         4  1   True  1.0
         5  2  False  2.0
         """
-        from pyspark.sql.types import _parse_datatype_string
-
         include_list: List[str]
         if not is_list_like(include):
             include_list = [cast(str, include)] if include is not None else []
@@ -7320,14 +7318,14 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         include_spark_type = []
         for inc in include_list:
             try:
-                include_spark_type.append(_parse_datatype_string(inc))
+                include_spark_type.append(self._internal.spark_frame._session._parse_ddl(inc))
             except BaseException:
                 pass
 
         exclude_spark_type = []
         for exc in exclude_list:
             try:
-                exclude_spark_type.append(_parse_datatype_string(exc))
+                exclude_spark_type.append(self._internal.spark_frame._session._parse_ddl(exc))
             except BaseException:
                 pass
 
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 33956c867669..3d8f0eced34b 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -54,7 +54,7 @@ import functools
 from pyspark import _NoValue
 from pyspark._globals import _NoValueType
 from pyspark.util import is_remote_only
-from pyspark.sql.types import Row, StructType, _create_row, _parse_datatype_string
+from pyspark.sql.types import Row, StructType, _create_row
 from pyspark.sql.dataframe import (
     DataFrame as ParentDataFrame,
     DataFrameNaFunctions as ParentDataFrameNaFunctions,
@@ -2037,7 +2037,7 @@ class DataFrame(ParentDataFrame):
 
         _validate_pandas_udf(func, evalType)
         if isinstance(schema, str):
-            schema = cast(StructType, _parse_datatype_string(schema))
+            schema = cast(StructType, self._session._parse_ddl(schema))
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 006af8756e63..11adc8850fec 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -35,7 +35,7 @@ from pyspark.util import PythonEvalType
 from pyspark.sql.group import GroupedData as PySparkGroupedData
 from pyspark.sql.pandas.group_ops import PandasCogroupedOps as PySparkPandasCogroupedOps
 from pyspark.sql.pandas.functions import _validate_pandas_udf  # type: ignore[attr-defined]
-from pyspark.sql.types import NumericType, StructType, _parse_datatype_string
+from pyspark.sql.types import NumericType, StructType
 
 import pyspark.sql.connect.plan as plan
 from pyspark.sql.column import Column
@@ -295,7 +295,7 @@ class GroupedData:
 
         _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
         if isinstance(schema, str):
-            schema = cast(StructType, _parse_datatype_string(schema))
+            schema = cast(StructType, self._df._session._parse_ddl(schema))
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
@@ -369,7 +369,7 @@ class GroupedData:
 
         _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF)
         if isinstance(schema, str):
-            schema = cast(StructType, _parse_datatype_string(schema))
+            schema = cast(StructType, self._df._session._parse_ddl(schema))
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
@@ -414,7 +414,7 @@ class PandasCogroupedOps:
 
         _validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF)
         if isinstance(schema, str):
-            schema = cast(StructType, _parse_datatype_string(schema))
+            schema = cast(StructType, self._gd1._df._session._parse_ddl(schema))
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
@@ -445,7 +445,7 @@ class PandasCogroupedOps:
 
         _validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF)
         if isinstance(schema, str):
-            schema = cast(StructType, _parse_datatype_string(schema))
+            schema = cast(StructType, self._gd1._df._session._parse_ddl(schema))
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 925eaaeabf60..3f1663d06850 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -1111,6 +1111,16 @@ class SparkSession:
 
         return creator, (self._session_id,)
 
+    def _to_ddl(self, struct: StructType) -> str:
+        ddl = self._client._analyze(method="json_to_ddl", json_string=struct.json()).ddl_string
+        assert ddl is not None
+        return ddl
+
+    def _parse_ddl(self, ddl: str) -> DataType:
+        dt = self._client._analyze(method="ddl_parse", ddl_string=ddl).parsed
+        assert dt is not None
+        return dt
+
 
 SparkSession.__doc__ = PySparkSession.__doc__
 
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index bd12b41b3436..343a68bf010b 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -36,7 +36,7 @@ from pyspark.sql.streaming.stateful_processor import (
 )
 from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
 from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
-from pyspark.sql.types import StructType, _parse_datatype_string
+from pyspark.sql.types import StructType
 
 if TYPE_CHECKING:
     from pyspark.sql.pandas._typing import (
@@ -348,9 +348,9 @@ class PandasGroupedOpsMixin:
         ]
 
         if isinstance(outputStructType, str):
-            outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
+            outputStructType = cast(StructType, self._df._session._parse_ddl(outputStructType))
         if isinstance(stateStructType, str):
-            stateStructType = cast(StructType, _parse_datatype_string(stateStructType))
+            stateStructType = cast(StructType, self._df._session._parse_ddl(stateStructType))
 
         udf = pandas_udf(
             func,  # type: ignore[call-overload]
@@ -502,7 +502,7 @@ class PandasGroupedOpsMixin:
         if initialState is not None:
             assert isinstance(initialState, GroupedData)
         if isinstance(outputStructType, str):
-            outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
+            outputStructType = cast(StructType, self._df._session._parse_ddl(outputStructType))
 
         def handle_pre_init(
             statefulProcessorApiClient: StatefulProcessorApiClient,
@@ -681,7 +681,7 @@ class PandasGroupedOpsMixin:
             return result
 
         if isinstance(outputStructType, str):
-            outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
+            outputStructType = cast(StructType, self._df._session._parse_ddl(outputStructType))
 
         df = self._df
 
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 00fa60442b41..f3a1639fddaf 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -58,7 +58,6 @@ from pyspark.sql.types import (
     _has_nulltype,
     _merge_type,
     _create_converter,
-    _parse_datatype_string,
     _from_numpy_type,
 )
 from pyspark.errors.exceptions.captured import install_exception_handler
@@ -1501,7 +1500,7 @@ class SparkSession(SparkConversionMixin):
             )
 
         if isinstance(schema, str):
-            schema = cast(Union[AtomicType, StructType, str], _parse_datatype_string(schema))
+            schema = cast(Union[AtomicType, StructType, str], self._parse_ddl(schema))
         elif isinstance(schema, (list, tuple)):
             # Must re-encode any unicode strings to be consistent with StructField names
             schema = [x.encode("utf-8") if not isinstance(x, str) else x for x in schema]
@@ -2338,6 +2337,12 @@ class SparkSession(SparkConversionMixin):
         """
         self._jsparkSession.clearTags()
 
+    def _to_ddl(self, struct: StructType) -> str:
+        return self._sc._to_ddl(struct)
+
+    def _parse_ddl(self, ddl: str) -> DataType:
+        return self._sc._parse_ddl(ddl)
+
 
 def _test() -> None:
     import os
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 93ac6655b886..f40a8bf62b29 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1563,16 +1563,9 @@ class StructType(DataType):
 
             session = SparkSession.getActiveSession()
             assert session is not None
-            return session._client._analyze(  # type: ignore[return-value]
-                method="json_to_ddl", json_string=self.json()
-            ).ddl_string
-
+            return session._to_ddl(self)
         else:
-            from py4j.java_gateway import JVMView
-
-            sc = get_active_spark_context()
-            assert sc._jvm is not None
-            return cast(JVMView, sc._jvm).PythonSQLUtils.jsonToDDL(self.json())
+            return get_active_spark_context()._to_ddl(self)
 
 
 class VariantType(AtomicType):
@@ -1907,18 +1900,9 @@ def _parse_datatype_string(s: str) -> DataType:
     if is_remote():
         from pyspark.sql.connect.session import SparkSession
 
-        return cast(
-            DataType,
-            SparkSession.active()._client._analyze(method="ddl_parse", ddl_string=s).parsed,
-        )
-
+        return SparkSession.active()._parse_ddl(s)
     else:
-        from py4j.java_gateway import JVMView
-
-        sc = get_active_spark_context()
-        return _parse_datatype_json_string(
-            cast(JVMView, sc._jvm).org.apache.spark.sql.api.python.PythonSQLUtils.ddlToJson(s)
-        )
+        return get_active_spark_context()._parse_ddl(s)
 
 
 def _parse_datatype_json_string(json_string: str) -> DataType:


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
