This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a3c17b2e229 [SPARK-45562][SQL][FOLLOW-UP] XML: Make 'rowTag' option check case insensitive
a3c17b2e229 is described below
commit a3c17b2e22969de3d225fc9890023456592f6158
Author: Sandip Agarwala <[email protected]>
AuthorDate: Thu Oct 19 13:23:04 2023 +0900
[SPARK-45562][SQL][FOLLOW-UP] XML: Make 'rowTag' option check case insensitive
### What changes were proposed in this pull request?
[PR 43389](https://github.com/apache/spark/pull/43389) made the `rowTag` option
required for XML read and write. However, the option check was done in a
case-sensitive manner. This PR makes the check case-insensitive.
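For context, a minimal sketch of the mechanism behind the fix (map contents are illustrative only): `XmlOptions` already stores user options in `CaseInsensitiveMap`, so moving the required-option check into `XmlOptions` makes the lookup independent of how the key is cased, whereas the old check in `XmlFileFormat` used the raw, case-sensitive `Map`.

```scala
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

// XmlOptions keeps its parameters as CaseInsensitiveMap[String], so a lookup by
// the canonical key matches the option regardless of the casing the user wrote.
val opts = CaseInsensitiveMap(Map("ROWTAG" -> "book"))
assert(opts.get("rowTag").contains("book"))   // found despite differing case

// A plain Map lookup, by contrast, is case-sensitive and misses the key,
// which is why the old check in XmlFileFormat rejected alternate casings.
assert(Map("ROWTAG" -> "book").get("rowTag").isEmpty)
```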
### Why are the changes needed?
Data source options are case-insensitive, so the required `rowTag` option should be accepted regardless of how the option name is cased.
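For illustration (the file path and row tag below are placeholders, not taken from the PR; a `spark` session is assumed), with this change either spelling of the option name satisfies the `rowTag` requirement when reading XML:

```scala
// Canonical casing of the option name.
val df1 = spark.read.format("xml").option("rowTag", "book").load("books.xml")

// An alternative casing; the required-option check now accepts it as well,
// because the check runs against the case-insensitive option map.
val df2 = spark.read.format("xml").option("ROWTAG", "book").load("books.xml")
```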
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit test.
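The test itself is not shown in this diff; as a rough, hypothetical sketch only, a test covering the change could look like the following (assuming a suite with a `spark` session and a placeholder `xmlFile` path):

```scala
test("SPARK-45562: rowTag option check is case insensitive") {
  // Any casing of the option name should pass the required-'rowTag' check.
  Seq("rowTag", "rowtag", "ROWTAG").foreach { key =>
    val df = spark.read.format("xml").option(key, "book").load(xmlFile)
    assert(df.schema.fieldNames.nonEmpty)
  }
}
```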
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #43416 from sandip-db/xml-rowTagCaseInsensitive.
Authored-by: Sandip Agarwala <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../org/apache/spark/sql/catalyst/xml/XmlOptions.scala | 17 +++++++++++------
.../sql/execution/datasources/xml/XmlFileFormat.scala | 5 ++---
2 files changed, 13 insertions(+), 9 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
index 0dedbec58e1..d2c7b435fe6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
@@ -34,7 +34,8 @@ import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
private[sql] class XmlOptions(
@transient val parameters: CaseInsensitiveMap[String],
defaultTimeZoneId: String,
- defaultColumnNameOfCorruptRecord: String)
+ defaultColumnNameOfCorruptRecord: String,
+ rowTagRequired: Boolean)
extends FileSourceOptions(parameters) with Logging {
import XmlOptions._
@@ -42,11 +43,13 @@ private[sql] class XmlOptions(
def this(
parameters: Map[String, String] = Map.empty,
defaultTimeZoneId: String = SQLConf.get.sessionLocalTimeZone,
- defaultColumnNameOfCorruptRecord: String = SQLConf.get.columnNameOfCorruptRecord) = {
+ defaultColumnNameOfCorruptRecord: String = SQLConf.get.columnNameOfCorruptRecord,
+ rowTagRequired: Boolean = false) = {
this(
CaseInsensitiveMap(parameters),
defaultTimeZoneId,
- defaultColumnNameOfCorruptRecord)
+ defaultColumnNameOfCorruptRecord,
+ rowTagRequired)
}
private def getBool(paramName: String, default: Boolean = false): Boolean = {
@@ -63,7 +66,9 @@ private[sql] class XmlOptions(
}
val compressionCodec =
parameters.get(COMPRESSION).map(CompressionCodecs.getCodecClassName)
- val rowTag = parameters.getOrElse(ROW_TAG, XmlOptions.DEFAULT_ROW_TAG).trim
+ val rowTagOpt = parameters.get(XmlOptions.ROW_TAG)
+ require(!rowTagRequired || rowTagOpt.isDefined, s"'${XmlOptions.ROW_TAG}' option is required.")
+ val rowTag = rowTagOpt.getOrElse(XmlOptions.DEFAULT_ROW_TAG).trim
require(rowTag.nonEmpty, s"'$ROW_TAG' option should not be an empty string.")
require(!rowTag.startsWith("<") && !rowTag.endsWith(">"),
s"'$ROW_TAG' should not include angle brackets")
@@ -223,8 +228,8 @@ private[sql] object XmlOptions extends DataSourceOptions {
newOption(ENCODING, CHARSET)
def apply(parameters: Map[String, String]): XmlOptions =
- new XmlOptions(parameters, SQLConf.get.sessionLocalTimeZone)
+ new XmlOptions(parameters)
def apply(): XmlOptions =
- new XmlOptions(Map.empty, SQLConf.get.sessionLocalTimeZone)
+ new XmlOptions(Map.empty)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
index 4342711b00f..77619299278 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
@@ -42,11 +42,10 @@ class XmlFileFormat extends TextBasedFileFormat with DataSourceRegister {
def getXmlOptions(
sparkSession: SparkSession,
parameters: Map[String, String]): XmlOptions = {
- val rowTagOpt = parameters.get(XmlOptions.ROW_TAG)
- require(rowTagOpt.isDefined, s"'${XmlOptions.ROW_TAG}' option is required.")
new XmlOptions(parameters,
sparkSession.sessionState.conf.sessionLocalTimeZone,
- sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+ sparkSession.sessionState.conf.columnNameOfCorruptRecord,
+ true)
}
override def isSplitable(
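As a follow-up note on the new constructor parameter (a sketch, not code from the PR): `XmlFileFormat` now passes `rowTagRequired = true`, and the `require` inside `XmlOptions` raises an `IllegalArgumentException` when no `rowTag` is supplied, while any casing of the key satisfies it.

```scala
// Sketch only; XmlOptions is private[sql], so this would have to live in a
// package under org.apache.spark.sql to compile.
import org.apache.spark.sql.catalyst.xml.XmlOptions

// No 'rowTag' with rowTagRequired = true: the require(...) fails with
// "requirement failed: 'rowTag' option is required."
// new XmlOptions(Map.empty, "UTC", "_corrupt_record", rowTagRequired = true)

// Any casing of the key passes, since parameters are wrapped in CaseInsensitiveMap.
val ok = new XmlOptions(Map("ROWTAG" -> "book"), "UTC", "_corrupt_record", rowTagRequired = true)
assert(ok.rowTag == "book")
```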
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]