This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b9b7669fce5c [SPARK-45190][SPARK-48897][PYTHON][CONNECT] Make `from_xml` support StructType schema
b9b7669fce5c is described below

commit b9b7669fce5cf8797627e568b1f85fe6e4733d31
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Jul 16 14:21:15 2024 +0900

    [SPARK-45190][SPARK-48897][PYTHON][CONNECT] Make `from_xml` support StructType schema
    
    ### What changes were proposed in this pull request?
    Make `from_xml` in Spark Connect accept a `StructType` schema.
    
    ### Why are the changes needed?
    A `StructType` schema was already supported in Spark Classic, but not in Spark Connect.
    
    This addresses https://github.com/apache/spark/pull/43680#discussion_r1385332357
    
    ### Does this PR introduce _any_ user-facing change?
    
    before:
    ```
    from pyspark.sql.types import StructType, LongType
    import pyspark.sql.functions as sf
    data = [(1, '''<p><a>1</a></p>''')]
    df = spark.createDataFrame(data, ("key", "value"))
    
    schema = StructType().add("a", LongType())
    df.select(sf.from_xml(df.value, schema)).show()
    
    ---------------------------------------------------------------------------
    AnalysisException                         Traceback (most recent call last)
    Cell In[1], line 7
    ...
    AnalysisException: [PARSE_SYNTAX_ERROR] Syntax error at or near '{'. SQLSTATE: 42601
    
    JVM stacktrace:
    org.apache.spark.sql.AnalysisException
            at org.apache.spark.sql.catalyst.parser.ParseException.withCommand(parsers.scala:278)
            at org.apache.spark.sql.catalyst.parser.AbstractParser.parse(parsers.scala:98)
            at org.apache.spark.sql.catalyst.parser.AbstractParser.parseDataType(parsers.scala:40)
            at org.apache.spark.sql.types.DataType$.$anonfun$fromDDL$1(DataType.scala:126)
            at org.apache.spark.sql.types.DataType$.parseTypeWithFallback(DataType.scala:145)
            at org.apache.spark.sql.types.DataType$.fromDDL(DataType.scala:127)
    ```
    
    after:
    ```
    +---------------+
    |from_xml(value)|
    +---------------+
    |            {1}|
    +---------------+
    
    ```
    
    ### How was this patch tested?
    Added a doctest and enabled the previously disabled `StructType` unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #47355 from zhengruifeng/from_xml_struct.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  | 36 +++++++++++++++++++++-
 python/pyspark/sql/functions/builtin.py            | 18 +++++++++--
 .../sql/tests/connect/test_connect_function.py     |  7 ++---
 3 files changed, 54 insertions(+), 7 deletions(-)

diff --git 
a/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 4702f09a14c2..9eeec306cda8 100644
--- 
a/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -65,7 +65,7 @@ import 
org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_
 import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, 
SparkConnectService}
 import org.apache.spark.sql.connect.utils.MetricGenerator
-import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.errors.{DataTypeErrors, QueryCompilationErrors}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.execution.arrow.ArrowConverters
@@ -1973,6 +1973,40 @@ class SparkConnectPlanner(
           None
         }
 
+      case "from_xml" if Seq(2, 3).contains(fun.getArgumentsCount) =>
+        // XmlToStructs constructor doesn't accept JSON-formatted schema.
+        val children = fun.getArgumentsList.asScala.map(transformExpression)
+
+        var schema: DataType = null
+        children(1) match {
+          case Literal(s, StringType) if s != null =>
+            try {
+              schema = DataType.fromJson(s.toString)
+            } catch {
+              case _: Exception =>
+            }
+          case _ =>
+        }
+
+        if (schema != null) {
+          schema match {
+            case t: StructType => t
+            case _ => throw 
DataTypeErrors.failedParsingStructTypeError(schema.sql)
+          }
+
+          var options = Map.empty[String, String]
+          if (children.length == 3) {
+            options = extractMapData(children(2), "Options")
+          }
+          Some(
+            XmlToStructs(
+              schema = 
CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType],
+              options = options,
+              child = children.head))
+        } else {
+          None
+        }
+
       // Avro-specific functions
       case "from_avro" if Seq(2, 3).contains(fun.getArgumentsCount) =>
         val children = fun.getArgumentsList.asScala.map(transformExpression)
diff --git a/python/pyspark/sql/functions/builtin.py 
b/python/pyspark/sql/functions/builtin.py
index 9e0c0700ae04..3193c3c4b574 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -16303,7 +16303,21 @@ def from_xml(
     >>> df.select(sf.from_xml(df.value, schema).alias("xml")).collect()
     [Row(xml=Row(a=1))]
 
-    Example 2: Parsing XML with :class:`ArrayType` in schema
+    Example 2: Parsing XML with a :class:`StructType` schema
+
+    >>> import pyspark.sql.functions as sf
+    >>> from pyspark.sql.types import StructType, LongType
+    >>> data = [(1, '''<p><a>1</a></p>''')]
+    >>> df = spark.createDataFrame(data, ("key", "value"))
+    >>> schema = StructType().add("a", LongType())
+    >>> df.select(sf.from_xml(df.value, schema)).show()
+    +---------------+
+    |from_xml(value)|
+    +---------------+
+    |            {1}|
+    +---------------+
+
+    Example 3: Parsing XML with :class:`ArrayType` in schema
 
     >>> import pyspark.sql.functions as sf
     >>> data = [(1, '<p><a>1</a><a>2</a></p>')]
@@ -16314,7 +16328,7 @@ def from_xml(
     >>> df.select(sf.from_xml(df.value, schema).alias("xml")).collect()
     [Row(xml=Row(a=[1, 2]))]
 
-    Example 3: Parsing XML using :meth:`pyspark.sql.functions.schema_of_xml`
+    Example 4: Parsing XML using :meth:`pyspark.sql.functions.schema_of_xml`
 
     >>> import pyspark.sql.functions as sf
     >>> # Sample data with an XML column
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py 
b/python/pyspark/sql/tests/connect/test_connect_function.py
index 0f0abfd4b856..a4dcf1ee0e31 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -1908,11 +1908,10 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, 
PandasOnSparkTestUtils, S
         sdf = self.spark.sql(query)
 
         # test from_xml
-        # TODO(SPARK-45190): Address StructType schema parse error
         for schema in [
             "a INT",
-            # StructType([StructField("a", IntegerType())]),
-            # StructType([StructField("a", ArrayType(IntegerType()))]),
+            StructType([StructField("a", IntegerType())]),
+            StructType([StructField("a", ArrayType(IntegerType()))]),
         ]:
             self.compare_by_show(
                 cdf.select(CF.from_xml(cdf.a, schema)),
@@ -1933,7 +1932,7 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, 
PandasOnSparkTestUtils, S
 
         for schema in [
             "STRUCT<a: ARRAY<INT>>",
-            # StructType([StructField("a", ArrayType(IntegerType()))]),
+            StructType([StructField("a", ArrayType(IntegerType()))]),
         ]:
             self.compare_by_show(
                 cdf.select(CF.from_xml(cdf.b, schema)),


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to