This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 71e0fdbcbeb2 [SPARK-46382][SQL] XML: Refactor inference and parsing
71e0fdbcbeb2 is described below

commit 71e0fdbcbeb2efa388d67f398a28a3168b494e9c
Author: Shujing Yang <shujing.y...@databricks.com>
AuthorDate: Mon Jan 8 08:53:39 2024 +0900

    [SPARK-46382][SQL] XML: Refactor inference and parsing
    
    ### What changes were proposed in this pull request?
    
    This follow-up refactors the handling of value tags and of `EndElement` events.
    
    1. As value tags only exist in structured data, their handling is now
    confined to the inferObject method, eliminating the need to process them in
    inferField. This means that whenever inferField encounters a `Characters`
    event, it can delegate to inferObject. A resulting structure whose only
    field is a primitive-typed value tag is simplified to that primitive type
    during schema inference, and an empty structure is inferred as NullType.
    2. We make sure that the entire entry, including the start tag, the value,
    and the end tag, is consumed by the time parsing of that entry completes.
    
    ### Why are the changes needed?
    
    This follow-up simplifies the handling of value tags and makes the consumption of end tags explicit.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #44571 from shujingyang-db/cpature-values-follow-up.
    
    Lead-authored-by: Shujing Yang <shujing.y...@databricks.com>
    Co-authored-by: Shujing Yang <135740748+shujingyang...@users.noreply.github.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../spark/sql/catalyst/xml/StaxXmlParser.scala     | 138 ++++++++------------
 .../sql/catalyst/xml/StaxXmlParserUtils.scala      |  30 ++++-
 .../spark/sql/catalyst/xml/XmlInferSchema.scala    | 139 ++++++---------------
 .../sql/execution/datasources/xml/XmlSuite.scala   |  46 ++++---
 .../xml/parsers/StaxXmlParserUtilsSuite.scala      |   2 +-
 5 files changed, 146 insertions(+), 209 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
index 11edce8140f0..199f1abd7e20 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
@@ -25,7 +25,6 @@ import javax.xml.stream.events._
 import javax.xml.transform.stream.StreamSource
 import javax.xml.validation.Schema
 
-import scala.annotation.tailrec
 import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
 import scala.util.Try
@@ -151,12 +150,7 @@ class StaxXmlParser(
       }
       val parser = StaxXmlParserUtils.filteredReader(xml)
       val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser)
-      // A structure object is an attribute-only element
-      // if it only consists of attributes and valueTags.
-      val isRootAttributesOnly = schema.fields.forall { f =>
-        f.name == options.valueTag || 
f.name.startsWith(options.attributePrefix)
-      }
-      val result = Some(convertObject(parser, schema, rootAttributes, 
isRootAttributesOnly))
+      val result = Some(convertObject(parser, schema, rootAttributes))
       parser.close()
       result
     } catch {
@@ -195,69 +189,60 @@ class StaxXmlParser(
   private[xml] def convertField(
       parser: XMLEventReader,
       dataType: DataType,
+      startElementName: String,
       attributes: Array[Attribute] = Array.empty): Any = {
 
-    def convertComplicatedType(dt: DataType, attributes: Array[Attribute]): 
Any = dt match {
+    def convertComplicatedType(
+        dt: DataType,
+        startElementName: String,
+        attributes: Array[Attribute]): Any = dt match {
       case st: StructType => convertObject(parser, st)
       case MapType(StringType, vt, _) => convertMap(parser, vt, attributes)
-      case ArrayType(st, _) => convertField(parser, st)
+      case ArrayType(st, _) => convertField(parser, st, startElementName)
       case _: StringType =>
-        convertTo(StaxXmlParserUtils.currentStructureAsString(parser), 
StringType)
+        convertTo(
+          StaxXmlParserUtils.currentStructureAsString(
+            parser, startElementName, options),
+          StringType)
     }
 
     (parser.peek, dataType) match {
-      case (_: StartElement, dt: DataType) => convertComplicatedType(dt, 
attributes)
+      case (_: StartElement, dt: DataType) =>
+        convertComplicatedType(dt, startElementName, attributes)
       case (_: EndElement, _: StringType) =>
+        StaxXmlParserUtils.skipNextEndElement(parser, startElementName, 
options)
         // Empty. It's null if "" is the null value
         if (options.nullValue == "") {
           null
         } else {
           UTF8String.fromString("")
         }
-      case (_: EndElement, _: DataType) => null
+      case (_: EndElement, _: DataType) =>
+        StaxXmlParserUtils.skipNextEndElement(parser, startElementName, 
options)
+        null
       case (c: Characters, ArrayType(st, _)) =>
         // For `ArrayType`, it needs to return the type of element. The values 
are merged later.
         parser.next
-        convertTo(c.getData, st)
-      case (c: Characters, st: StructType) =>
-        parser.next
-        parser.peek match {
-          case _: EndElement =>
-            // It couldn't be an array of value tags
-            // as the opening tag is immediately followed by a closing tag.
-            if (c.isWhiteSpace) {
-              return null
-            }
-            val indexOpt = getFieldNameToIndex(st).get(options.valueTag)
-            indexOpt match {
-              case Some(index) =>
-                convertTo(c.getData, st.fields(index).dataType)
-              case None => null
-            }
-          case _ =>
-            val row = convertObject(parser, st)
-            if (!c.isWhiteSpace) {
-              addOrUpdate(row.toSeq(st).toArray, st, options.valueTag, 
c.getData, addToTail = false)
-            } else {
-              row
-            }
-        }
+        val value = convertTo(c.getData, st)
+        StaxXmlParserUtils.skipNextEndElement(parser, startElementName, 
options)
+        value
+      case (_: Characters, st: StructType) =>
+        convertObject(parser, st)
       case (_: Characters, _: StringType) =>
-        convertTo(StaxXmlParserUtils.currentStructureAsString(parser), 
StringType)
+        convertTo(
+          StaxXmlParserUtils.currentStructureAsString(
+            parser, startElementName, options),
+          StringType)
       case (c: Characters, _: DataType) if c.isWhiteSpace =>
         // When `Characters` is found, we need to look further to decide
         // if this is really data or space between other elements.
-        val data = c.getData
         parser.next
-        parser.peek match {
-          case _: StartElement => convertComplicatedType(dataType, attributes)
-          case _: EndElement if data.isEmpty => null
-          case _: EndElement => convertTo(data, dataType)
-          case _ => convertField(parser, dataType, attributes)
-        }
+        convertField(parser, dataType, startElementName, attributes)
       case (c: Characters, dt: DataType) =>
+        val value = convertTo(c.getData, dt)
         parser.next
-        convertTo(c.getData, dt)
+        StaxXmlParserUtils.skipNextEndElement(parser, startElementName, 
options)
+        value
       case (e: XMLEvent, dt: DataType) =>
         throw new IllegalArgumentException(
           s"Failed to parse a value for data type $dt with event 
${e.toString}")
@@ -280,16 +265,16 @@ class StaxXmlParser(
     while (!shouldStop) {
       parser.nextEvent match {
         case e: StartElement =>
+          val key = StaxXmlParserUtils.getName(e.asStartElement.getName, 
options)
           kvPairs +=
-            
(UTF8String.fromString(StaxXmlParserUtils.getName(e.asStartElement.getName, 
options)) ->
-            convertField(parser, valueType))
+          (UTF8String.fromString(key) -> convertField(parser, valueType, key))
         case c: Characters if !c.isWhiteSpace =>
           // Create a value tag field for it
           kvPairs +=
           // TODO: We don't support an array value tags in map yet.
           (UTF8String.fromString(options.valueTag) -> convertTo(c.getData, 
valueType))
-        case _: EndElement =>
-          shouldStop = StaxXmlParserUtils.checkEndElement(parser)
+        case _: EndElement | _: EndDocument =>
+          shouldStop = true
         case _ => // do nothing
       }
     }
@@ -321,6 +306,7 @@ class StaxXmlParser(
   private def convertObjectWithAttributes(
       parser: XMLEventReader,
       schema: StructType,
+      startElementName: String,
       attributes: Array[Attribute] = Array.empty): InternalRow = {
     // TODO: This method might have to be removed. Some logics duplicate 
`convertObject()`
     val row = new Array[Any](schema.length)
@@ -329,7 +315,7 @@ class StaxXmlParser(
     val attributesMap = convertAttributes(attributes, schema)
 
     // Then, we read elements here.
-    val fieldsMap = convertField(parser, schema) match {
+    val fieldsMap = convertField(parser, schema, startElementName) match {
       case internalRow: InternalRow =>
         Map(schema.map(_.name).zip(internalRow.toSeq(schema)): _*)
       case v if schema.fieldNames.contains(options.valueTag) =>
@@ -363,8 +349,7 @@ class StaxXmlParser(
   private def convertObject(
       parser: XMLEventReader,
       schema: StructType,
-      rootAttributes: Array[Attribute] = Array.empty,
-      isRootAttributesOnly: Boolean = false): InternalRow = {
+      rootAttributes: Array[Attribute] = Array.empty): InternalRow = {
     val row = new Array[Any](schema.length)
     val nameToIndex = getFieldNameToIndex(schema)
     // If there are attributes, then we process them first.
@@ -388,7 +373,7 @@ class StaxXmlParser(
           nameToIndex.get(field) match {
             case Some(index) => schema(index).dataType match {
               case st: StructType =>
-                row(index) = convertObjectWithAttributes(parser, st, 
attributes)
+                row(index) = convertObjectWithAttributes(parser, st, field, 
attributes)
 
               case ArrayType(dt: DataType, _) =>
                 val values = Option(row(index))
@@ -396,21 +381,21 @@ class StaxXmlParser(
                   .getOrElse(ArrayBuffer.empty[Any])
                 val newValue = dt match {
                   case st: StructType =>
-                    convertObjectWithAttributes(parser, st, attributes)
+                    convertObjectWithAttributes(parser, st, field, attributes)
                   case dt: DataType =>
-                    convertField(parser, dt)
+                    convertField(parser, dt, field)
                 }
                 row(index) = values :+ newValue
 
               case dt: DataType =>
-                row(index) = convertField(parser, dt, attributes)
+                row(index) = convertField(parser, dt, field, attributes)
             }
 
             case None =>
               if (hasWildcard) {
                 // Special case: there's an 'any' wildcard element that 
matches anything else
                 // as a string (or array of strings, to parse multiple ones)
-                val newValue = convertField(parser, StringType)
+                val newValue = convertField(parser, StringType, field)
                 val anyIndex = schema.fieldIndex(wildcardColName)
                 schema(wildcardColName).dataType match {
                   case StringType =>
@@ -423,19 +408,21 @@ class StaxXmlParser(
                 }
               } else {
                 StaxXmlParserUtils.skipChildren(parser)
+                StaxXmlParserUtils.skipNextEndElement(parser, field, options)
               }
           }
         } catch {
           case e: SparkUpgradeException => throw e
           case NonFatal(e) =>
+            // TODO: we don't support partial results now
             badRecordException = badRecordException.orElse(Some(e))
         }
 
         case c: Characters if !c.isWhiteSpace =>
           addOrUpdate(row, schema, options.valueTag, c.getData)
 
-        case _: EndElement =>
-          shouldStop = parseAndCheckEndElement(row, schema, parser)
+        case _: EndElement | _: EndDocument =>
+          shouldStop = true
 
         case _ => // do nothing
       }
@@ -599,24 +586,6 @@ class StaxXmlParser(
     }
   }
 
-  @tailrec
-  private def parseAndCheckEndElement(
-      row: Array[Any],
-      schema: StructType,
-      parser: XMLEventReader): Boolean = {
-    parser.peek match {
-      case _: EndElement | _: EndDocument => true
-      case _: StartElement => false
-      case c: Characters if !c.isWhiteSpace =>
-        parser.nextEvent()
-        addOrUpdate(row, schema, options.valueTag, c.getData)
-        parseAndCheckEndElement(row, schema, parser)
-      case _ =>
-        parser.nextEvent()
-        parseAndCheckEndElement(row, schema, parser)
-    }
-  }
-
   private def addOrUpdate(
       row: Array[Any],
       schema: StructType,
@@ -628,17 +597,14 @@ class StaxXmlParser(
         schema(index).dataType match {
           case ArrayType(elementType, _) =>
             val value = convertTo(data, elementType)
-            val result = if (row(index) == null) {
-              ArrayBuffer(value)
-            } else {
-              val genericArrayData = row(index).asInstanceOf[GenericArrayData]
-              if (addToTail) {
-                genericArrayData.toArray(elementType) :+ value
+            val values = Option(row(index))
+              .map(_.asInstanceOf[ArrayBuffer[Any]])
+              .getOrElse(ArrayBuffer.empty[Any])
+            row(index) = if (addToTail) {
+                values :+ value
               } else {
-                value +: genericArrayData.toArray(elementType)
+                value +: values
               }
-            }
-            row(index) = new GenericArrayData(result)
           case dataType =>
             row(index) = convertTo(data, dataType)
         }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala
index 0471cb310d89..a59ea6f460de 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala
@@ -38,9 +38,14 @@ object StaxXmlParserUtils {
   def filteredReader(xml: String): XMLEventReader = {
     val filter = new EventFilter {
       override def accept(event: XMLEvent): Boolean =
-        // Ignore comments and processing instructions
         event.getEventType match {
+          // Ignore comments and processing instructions
           case XMLStreamConstants.COMMENT | 
XMLStreamConstants.PROCESSING_INSTRUCTION => false
+          // unsupported events
+          case XMLStreamConstants.DTD |
+               XMLStreamConstants.ENTITY_DECLARATION |
+               XMLStreamConstants.ENTITY_REFERENCE |
+               XMLStreamConstants.NOTATION_DECLARATION => false
           case _ => true
         }
     }
@@ -121,7 +126,10 @@ object StaxXmlParserUtils {
   /**
    * Convert the current structure of XML document to a XML string.
    */
-  def currentStructureAsString(parser: XMLEventReader): String = {
+  def currentStructureAsString(
+      parser: XMLEventReader,
+      startElementName: String,
+      options: XmlOptions): String = {
     val xmlString = new StringBuilder()
     var indent = 0
     do {
@@ -151,6 +159,7 @@ object StaxXmlParserUtils {
         indent > 0
       case _ => true
     })
+    skipNextEndElement(parser, startElementName, options)
     xmlString.toString()
   }
 
@@ -178,4 +187,21 @@ object StaxXmlParserUtils {
       }
     }
   }
+
+  @tailrec
+  def skipNextEndElement(
+      parser: XMLEventReader,
+      expectedNextEndElementName: String,
+      options: XmlOptions): Unit = {
+    parser.nextEvent() match {
+      case c: Characters if c.isWhiteSpace =>
+        skipNextEndElement(parser, expectedNextEndElementName, options)
+      case endElement: EndElement =>
+        assert(
+          getName(endElement.getName, options) == expectedNextEndElementName,
+          s"Expected EndElement </$expectedNextEndElementName>")
+      case _ => throw new IllegalStateException(
+        s"Expected EndElement </$expectedNextEndElementName>")
+    }
+  }
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
index 59222f56454f..51d5ae532b05 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
@@ -23,7 +23,6 @@ import javax.xml.stream.events._
 import javax.xml.transform.stream.StreamSource
 import javax.xml.validation.Schema
 
-import scala.annotation.tailrec
 import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
 import scala.util.control.Exception._
@@ -157,38 +156,17 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: 
Boolean)
     parser.peek match {
       case _: EndElement => NullType
       case _: StartElement => inferObject(parser)
-      case c: Characters if c.isWhiteSpace =>
-        // When `Characters` is found, we need to look further to decide
-        // if this is really data or space between other elements.
-        val data = c.getData
-        parser.nextEvent()
-        parser.peek match {
-          case _: StartElement => inferObject(parser)
-          case _: EndElement if data.isEmpty => NullType
-          case _: EndElement if options.nullValue == "" => NullType
-          case _: EndElement => StringType
-          case _ => inferField(parser)
-        }
-      case c: Characters if !c.isWhiteSpace =>
-        val characterType = inferFrom(c.getData)
-        parser.nextEvent()
-        parser.peek match {
-          case _: StartElement =>
-            // Some more elements follow;
-            // This is a mix of values and other elements
-            val innerType = inferObject(parser).asInstanceOf[StructType]
-            addOrUpdateValueTagType(innerType, characterType)
-          case _ =>
-            val fieldType = inferField(parser)
-            fieldType match {
-              case st: StructType => addOrUpdateValueTagType(st, characterType)
-              case _: NullType => characterType
-              case _: DataType =>
-                // The field type couldn't be an array type
-                new StructType()
-                .add(options.valueTag, addOrUpdateType(Some(characterType), 
fieldType))
-
-            }
+      case _: Characters =>
+        val structType = inferObject(parser).asInstanceOf[StructType]
+        structType match {
+          case _ if structType.fields.isEmpty =>
+            NullType
+          case simpleType
+              if structType.fields.length == 1
+              && isPrimitiveType(structType.fields.head.dataType)
+              && isValueTagField(structType.fields.head, caseSensitive) =>
+            simpleType.fields.head.dataType
+          case _ => structType
         }
       case e: XMLEvent =>
         throw new IllegalArgumentException(s"Failed to parse data with 
unexpected event $e")
@@ -224,22 +202,6 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: 
Boolean)
     val nameToDataType =
       collection.mutable.TreeMap.empty[String, 
DataType](caseSensitivityOrdering)
 
-    @tailrec
-    def inferAndCheckEndElement(parser: XMLEventReader): Boolean = {
-      parser.peek match {
-        case _: EndElement | _: EndDocument => true
-        case _: StartElement => false
-        case c: Characters if !c.isWhiteSpace =>
-          val characterType = inferFrom(c.getData)
-          parser.nextEvent()
-          addOrUpdateType(nameToDataType, options.valueTag, characterType)
-          inferAndCheckEndElement(parser)
-        case _ =>
-          parser.nextEvent()
-          inferAndCheckEndElement(parser)
-      }
-    }
-
     // If there are attributes, then we should process them first.
     val rootValuesMap =
       StaxXmlParserUtils.convertAttributesToValuesMap(rootAttributes, options)
@@ -253,6 +215,7 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: 
Boolean)
         case e: StartElement =>
           val attributes = e.getAttributes.asScala.toArray
           val valuesMap = 
StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options)
+          val field = StaxXmlParserUtils.getName(e.asStartElement.getName, 
options)
           val inferredType = inferField(parser) match {
             case st: StructType if valuesMap.nonEmpty =>
               // Merge attributes to the field
@@ -267,7 +230,9 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: 
Boolean)
             case dt: DataType if valuesMap.nonEmpty =>
               // We need to manually add the field for value.
               val nestedBuilder = ArrayBuffer[StructField]()
-              nestedBuilder += StructField(options.valueTag, dt, nullable = 
true)
+              if (!dt.isInstanceOf[NullType]) {
+                nestedBuilder += StructField(options.valueTag, dt, nullable = 
true)
+              }
               valuesMap.foreach {
                 case (f, v) =>
                   nestedBuilder += StructField(f, inferFrom(v), nullable = 
true)
@@ -277,16 +242,15 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: 
Boolean)
             case dt: DataType => dt
           }
           // Add the field and datatypes so that we can check if this is 
ArrayType.
-          val field = StaxXmlParserUtils.getName(e.asStartElement.getName, 
options)
           addOrUpdateType(nameToDataType, field, inferredType)
 
         case c: Characters if !c.isWhiteSpace =>
-          // This can be an attribute-only object
+          // This is a value tag
           val valueTagType = inferFrom(c.getData)
           addOrUpdateType(nameToDataType, options.valueTag, valueTagType)
 
-        case _: EndElement =>
-          shouldStop = inferAndCheckEndElement(parser)
+        case _: EndElement | _: EndDocument =>
+          shouldStop = true
 
         case _ => // do nothing
       }
@@ -429,56 +393,6 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: 
Boolean)
     case other => Some(other)
   }
 
-  /**
-   * This helper function merges the data type of value tags and inner 
elements.
-   * It could only be structure data. Consider the following case,
-   * <a>
-   *   value1
-   *   <b>1</b>
-   *   value2
-   * </a>
-   * Input: ''a struct<b int, _VALUE string>'' and ''_VALUE string''
-   * Return: ''a struct<b int, _VALUE array<string>>''
-   * @param objectType inner elements' type
-   * @param valueTagType value tag's type
-   */
-  private[xml] def addOrUpdateValueTagType(
-      objectType: DataType,
-      valueTagType: DataType): DataType = {
-    (objectType, valueTagType) match {
-      case (st: StructType, _) =>
-        val valueTagIndexOpt = st.getFieldIndex(options.valueTag)
-
-        valueTagIndexOpt match {
-          // If the field name exists in the inner elements,
-          // merge the type and infer the combined field as an array type if 
necessary
-          case Some(index) if !st(index).dataType.isInstanceOf[ArrayType] =>
-            updateStructField(
-              st,
-              index,
-              ArrayType(compatibleType(caseSensitive, options.valueTag)(
-                st(index).dataType, valueTagType)))
-          case Some(index) =>
-            updateStructField(st, index, compatibleType(caseSensitive, 
options.valueTag)(
-              st(index).dataType, valueTagType))
-          case None =>
-            st.add(options.valueTag, valueTagType)
-        }
-      case _ =>
-        throw new IllegalStateException(
-          "illegal state when merging value tags types in schema inference"
-        )
-    }
-  }
-
-  private def updateStructField(
-      structType: StructType,
-      index: Int,
-      newType: DataType): StructType = {
-    val newFields: Array[StructField] =
-      structType.fields.updated(index, structType.fields(index).copy(dataType 
= newType))
-    StructType(newFields)
-  }
 
   private def addOrUpdateType(
       nameToDataType: collection.mutable.TreeMap[String, DataType],
@@ -501,6 +415,23 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: 
Boolean)
         newType
     }
   }
+
+  private[xml] def isPrimitiveType(dataType: DataType): Boolean = {
+    dataType match {
+      case _: StructType => false
+      case _: ArrayType => false
+      case _: MapType => false
+      case _ => true
+    }
+  }
+
+  private[xml] def isValueTagField(structField: StructField, caseSensitive: 
Boolean): Boolean = {
+    if (!caseSensitive) {
+      structField.name.toLowerCase(Locale.ROOT) == 
options.valueTag.toLowerCase(Locale.ROOT)
+    } else {
+      structField.name == options.valueTag
+    }
+  }
 }
 
 object XmlInferSchema {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
index 78f9d5285c23..5fdf949a2137 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
@@ -1098,9 +1098,11 @@ class XmlSuite extends QueryTest with SharedSparkSession 
{
     assert(valid.toSeq.toArray.take(schema.length - 1) ===
       Array(Row(10, 10), Row(10, "Ten"), 10.0, 10.0, true,
         "Ten", Array(1, 2), Map("a" -> 123, "b" -> 345)))
-    assert(invalid.toSeq.toArray.take(schema.length - 1) ===
-      Array(null, null, null, null, null,
-        "Ten", Array(2), null))
+    // TODO: we don't support partial results
+    assert(
+      invalid.toSeq.toArray.take(schema.length - 1) ===
+        Array(null, null, null, null, null,
+          null, null, null))
 
     assert(valid.toSeq.toArray.last === null)
     assert(invalid.toSeq.toArray.last.toString.contains(
@@ -1337,7 +1339,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       .xml(getTestResourcePath(resDir + "whitespace_error.xml"))
 
     assert(whitespaceDF.count() === 1)
-    assert(whitespaceDF.take(1).head.getAs[String]("_corrupt_record") !== null)
+    assert(whitespaceDF.take(1).head.getAs[String]("_corrupt_record") === null)
   }
 
   test("struct with only attributes and no value tag does not crash") {
@@ -2479,7 +2481,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       .xml(input)
 
     checkAnswer(df, Seq(
-      Row("\" \"", Row(1, "\" \""), Row(Row(null, " ")))))
+      Row("\" \"", Row("\" \"", 1), Row(Row(" ")))))
   }
 
   test("capture values interspersed between elements - nested comments") {
@@ -2552,7 +2554,9 @@ class XmlSuite extends QueryTest with SharedSparkSession {
          |                value4
          |                <struct3>
          |                    value5
-         |                    <array2>1</array2>
+         |                    <array2>1<!--First comment--> <!--Second 
comment--></array2>
+         |                    <![CDATA[This is a CDATA section containing 
<sample1> text.]]>
+         |                    <![CDATA[This is a CDATA section containing 
<sample2> text.]]>
          |                    value6
          |                    <array2>2</array2>
          |                    value7
@@ -2563,10 +2567,10 @@ class XmlSuite extends QueryTest with 
SharedSparkSession {
          |            </array1>
          |            value10
          |            <array1>
-         |                <struct3>
+         |                <struct3><!--First comment--> <!--Second comment-->
          |                    <array2>3</array2>
          |                    value11
-         |                    <array2>4</array2>
+         |                    <array2>4</array2><!--First comment--> 
<!--Second comment-->
          |                </struct3>
          |                <string>string</string>
          |                value12
@@ -2577,7 +2581,9 @@ class XmlSuite extends QueryTest with SharedSparkSession {
          |        </struct2>
          |        value15
          |    </struct1>
+         |     <!--First comment-->
          |    value16
+         |     <!--Second comment-->
          |</ROW>
          |""".stripMargin
     val input = spark.createDataset(Seq(xmlString))
@@ -2594,14 +2600,22 @@ class XmlSuite extends QueryTest with 
SharedSparkSession {
         Row(
           ArraySeq("value3", "value10", "value13", "value14"),
           Array(
-            Row(
-              ArraySeq("value4", "value8", "value9"),
-              "string",
-              Row(ArraySeq("value5", "value6", "value7"), ArraySeq(1, 2))),
-            Row(
-              ArraySeq("value12"),
-              "string",
-              Row(ArraySeq("value11"), ArraySeq(3, 4)))),
+              Row(
+                ArraySeq("value4", "value8", "value9"),
+                "string",
+                Row(
+                  ArraySeq(
+                    "value5",
+                    "This is a CDATA section containing <sample1> text." +
+                      "\n                    This is a CDATA section 
containing <sample2> text.\n" +
+                      "                    value6",
+                    "value7"
+                  ),
+                  ArraySeq(1, 2)
+                )
+              ),
+              Row(ArraySeq("value12"), "string", Row(ArraySeq("value11"), 
ArraySeq(3, 4)))
+            ),
           3))))
 
     checkAnswer(df, expectedAns)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlParserUtilsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlParserUtilsSuite.scala
index 13a90acb7152..a4ac25b036c4 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlParserUtilsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlParserUtilsSuite.scala
@@ -63,7 +63,7 @@ final class StaxXmlParserUtilsSuite extends SparkFunSuite 
with BeforeAndAfterAll
     val parser = factory.createXMLEventReader(new StringReader(input.toString))
     // Skip until </id>
     StaxXmlParserUtils.skipUntil(parser, XMLStreamConstants.END_ELEMENT)
-    val xmlString = StaxXmlParserUtils.currentStructureAsString(parser)
+    val xmlString = StaxXmlParserUtils.currentStructureAsString(parser, "ROW", 
new XmlOptions())
     val expected = <info>
       <name>Sam Mad Dog 
Smith</name><amount><small>1</small><large>9</large></amount></info>
     assert(xmlString === expected.toString())


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to