This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 4d63ca6394f [SPARK-45562][SQL] XML: Make 'rowTag' a required option
4d63ca6394f is described below

commit 4d63ca6394fe8692e1f9bceb93606a86b88b5dc1
Author:     Sandip Agarwala <131817656+sandip...@users.noreply.github.com>
AuthorDate: Tue Oct 17 20:38:02 2023 +0900

    [SPARK-45562][SQL] XML: Make 'rowTag' a required option

    ### What changes were proposed in this pull request?
    Users can specify the `rowTag` option, the name of the XML element that maps
    to a `DataFrame` row. With a `rowTag` that does not occur in the input, no
    schema is inferred and no `DataFrame` rows are generated.

    Currently, not specifying the `rowTag` option results in picking up its
    default value of `ROW`, which rarely matches a real XML element. This yields
    an empty DataFrame and confuses new users.

    This PR makes `rowTag` a required option for both read and write (see the
    usage sketch appended after the diff below). The XML built-in functions
    (`from_xml`/`schema_of_xml`) ignore the `rowTag` option.

    ### Why are the changes needed?
    See above

    ### Does this PR introduce _any_ user-facing change?
    No

    ### How was this patch tested?
    New unit tests

    ### Was this patch authored or co-authored using generative AI tooling?
    No

    Closes #43389 from sandip-db/xml-rowTag.

    Authored-by: Sandip Agarwala <131817656+sandip...@users.noreply.github.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../apache/spark/sql/catalyst/xml/XmlOptions.scala |   4 +-
 .../execution/datasources/xml/XmlFileFormat.scala  |   2 +
 .../execution/datasources/xml/JavaXmlSuite.java    |  10 +-
 .../sql/execution/datasources/xml/XmlSuite.scala   | 125 +++++++++++++++------
 .../xml/parsers/StaxXmlGeneratorSuite.scala        |   4 +-
 5 files changed, 103 insertions(+), 42 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
index d0cfff87279..0dedbec58e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
@@ -63,8 +63,8 @@ private[sql] class XmlOptions(
   }
 
   val compressionCodec = parameters.get(COMPRESSION).map(CompressionCodecs.getCodecClassName)
-  val rowTag = parameters.getOrElse(ROW_TAG, XmlOptions.DEFAULT_ROW_TAG)
-  require(rowTag.nonEmpty, s"'$ROW_TAG' option should not be empty string.")
+  val rowTag = parameters.getOrElse(ROW_TAG, XmlOptions.DEFAULT_ROW_TAG).trim
+  require(rowTag.nonEmpty, s"'$ROW_TAG' option should not be an empty string.")
   require(!rowTag.startsWith("<") && !rowTag.endsWith(">"),
     s"'$ROW_TAG' should not include angle brackets")
   val rootTag = parameters.getOrElse(ROOT_TAG, XmlOptions.DEFAULT_ROOT_TAG)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
index baacf7f0748..4342711b00f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
@@ -42,6 +42,8 @@ class XmlFileFormat extends TextBasedFileFormat with DataSourceRegister {
   def getXmlOptions(
       sparkSession: SparkSession,
       parameters: Map[String, String]): XmlOptions = {
+    val rowTagOpt = parameters.get(XmlOptions.ROW_TAG)
+    require(rowTagOpt.isDefined, s"'${XmlOptions.ROW_TAG}' option is required.")
     new XmlOptions(parameters,
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/execution/datasources/xml/JavaXmlSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/execution/datasources/xml/JavaXmlSuite.java
index b3f39180843..c773459dc4c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/execution/datasources/xml/JavaXmlSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/execution/datasources/xml/JavaXmlSuite.java
@@ -82,7 +82,7 @@ public final class JavaXmlSuite {
     public void testXmlParser() {
         Map<String, String> options = new HashMap<>();
         options.put("rowTag", booksFileTag);
-        Dataset<Row> df = spark.read().options(options).format("xml").load(booksFile);
+        Dataset<Row> df = spark.read().options(options).xml(booksFile);
         String prefix = XmlOptions.DEFAULT_ATTRIBUTE_PREFIX();
         long result = df.select(prefix + "id").count();
         Assertions.assertEquals(result, numBooks);
@@ -92,7 +92,7 @@ public final class JavaXmlSuite {
     public void testLoad() {
         Map<String, String> options = new HashMap<>();
         options.put("rowTag", booksFileTag);
-        Dataset<Row> df = spark.read().options(options).format("xml").load(booksFile);
+        Dataset<Row> df = spark.read().options(options).xml(booksFile);
         long result = df.select("description").count();
         Assertions.assertEquals(result, numBooks);
     }
@@ -103,10 +103,10 @@ public final class JavaXmlSuite {
         options.put("rowTag", booksFileTag);
         Path booksPath = getEmptyTempDir().resolve("booksFile");
 
-        Dataset<Row> df = spark.read().options(options).format("xml").load(booksFile);
-        df.select("price", "description").write().format("xml").save(booksPath.toString());
+        Dataset<Row> df = spark.read().options(options).xml(booksFile);
+        df.select("price", "description").write().options(options).xml(booksPath.toString());
 
-        Dataset<Row> newDf = spark.read().format("xml").load(booksPath.toString());
+        Dataset<Row> newDf = spark.read().options(options).xml(booksPath.toString());
         long result = newDf.select("price").count();
         Assertions.assertEquals(result, numBooks);
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
index c5892abf3f8..23223b3e94e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
@@ -65,6 +65,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("DSL test") {
     val results = spark.read.format("xml")
+      .option("rowTag", "ROW")
       .option("multiLine", "true")
       .load(getTestResourcePath(resDir + "cars.xml"))
       .select("year")
@@ -75,6 +76,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("DSL test with xml having unbalanced datatypes") {
     val results = spark.read
+      .option("rowTag", "ROW")
       .option("treatEmptyValuesAsNulls", "true")
       .option("multiLine", "true")
       .xml(getTestResourcePath(resDir + "gps-empty-field.xml"))
@@ -84,6 +86,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("DSL test with mixed elements (attributes, no child)") {
     val results = spark.read
+      .option("rowTag", "ROW")
       .xml(getTestResourcePath(resDir + "cars-mixed-attr-no-child.xml"))
       .select("date")
       .collect()
@@ -129,6 +132,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("DSL test for iso-8859-1 encoded file") {
     val dataFrame = spark.read
+      .option("rowTag", "ROW")
       .option("charset", StandardCharsets.ISO_8859_1.name)
       .xml(getTestResourcePath(resDir + "cars-iso-8859-1.xml"))
     assert(dataFrame.select("year").collect().length === 3)
@@ -142,6 +146,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("DSL test compressed file") {
     val results = spark.read
+      .option("rowTag", "ROW")
      .xml(getTestResourcePath(resDir + "cars.xml.gz"))
       .select("year")
       .collect()
@@ -151,6 +156,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("DSL test splittable compressed file") {
     val results = spark.read
+      .option("rowTag", "ROW")
       .xml(getTestResourcePath(resDir + "cars.xml.bz2"))
       .select("year")
       .collect()
@@ -162,6 +168,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     // val exception = intercept[UnsupportedCharsetException] {
     val exception = intercept[SparkException] {
       spark.read
+        .option("rowTag", "ROW")
         .option("charset", "1-9588-osi")
         .xml(getTestResourcePath(resDir + "cars.xml"))
         .select("year")
@@ -175,7 +182,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     spark.sql(s"""
          |CREATE TEMPORARY VIEW carsTable1
          |USING org.apache.spark.sql.execution.datasources.xml
-         |OPTIONS (path "${getTestResourcePath(resDir + "cars.xml")}")
+         |OPTIONS (rowTag "ROW", path "${getTestResourcePath(resDir + "cars.xml")}")
       """.stripMargin.replaceAll("\n", " "))
 
     assert(spark.sql("SELECT year FROM carsTable1").collect().length === 3)
@@ -185,7 +192,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     spark.sql(s"""
          |CREATE TEMPORARY VIEW carsTable2
          |USING xml
-         |OPTIONS (path "${getTestResourcePath(resDir + "cars.xml")}")
+         |OPTIONS (rowTag "ROW", path "${getTestResourcePath(resDir + "cars.xml")}")
       """.stripMargin.replaceAll("\n", " "))
 
     assert(spark.sql("SELECT year FROM carsTable2").collect().length === 3)
@@ -193,6 +200,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("DSL test for parsing a malformed XML file") {
     val results = spark.read
+      .option("rowTag", "ROW")
       .option("mode", DropMalformedMode.name)
       .xml(getTestResourcePath(resDir + "cars-malformed.xml"))
@@ -201,6 +209,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("DSL test for dropping malformed rows") {
     val cars = spark.read
+      .option("rowTag", "ROW")
       .option("mode", DropMalformedMode.name)
       .xml(getTestResourcePath(resDir + "cars-malformed.xml"))
@@ -211,6 +220,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
   test("DSL test for failing fast") {
     val exceptionInParse = intercept[SparkException] {
       spark.read
+        .option("rowTag", "ROW")
         .option("mode", FailFastMode.name)
         .xml(getTestResourcePath(resDir + "cars-malformed.xml"))
         .collect()
@@ -245,6 +255,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("DSL test for permissive mode for corrupt records") {
     val carsDf = spark.read
+      .option("rowTag", "ROW")
       .option("mode", PermissiveMode.name)
       .option("columnNameOfCorruptRecord", "_malformed_records")
       .xml(getTestResourcePath(resDir + "cars-malformed.xml"))
@@ -268,6 +279,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("DSL test with empty file and known schema") {
     val results = spark.read
+      .option("rowTag", "ROW")
       .schema(buildSchema(field("column", StringType, false)))
       .xml(getTestResourcePath(resDir + "empty.xml"))
       .count()
@@ -283,6 +295,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       field("model"),
       field("comment"))
     val results = spark.read.schema(schema)
+      .option("rowTag", "ROW")
       .xml(getTestResourcePath(resDir + "cars-unbalanced-elements.xml"))
       .count()
@@ -293,8 +306,8 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     spark.sql(s"""
          |CREATE TEMPORARY VIEW carsTable3
          |(year double, make string, model string, comments string, grp string)
-         |USING org.apache.spark.sql.execution.datasources.xml
-         |OPTIONS (path "${getTestResourcePath(resDir + "empty.xml")}")
+         |USING xml
+         |OPTIONS (rowTag "ROW", path "${getTestResourcePath(resDir + "empty.xml")}")
       """.stripMargin.replaceAll("\n", " "))
 
     assert(spark.sql("SELECT count(*) FROM carsTable3").collect().head(0) === 0)
@@ -304,15 +317,15 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     val tempPath = getEmptyTempDir()
     spark.sql(s"""
          |CREATE TEMPORARY VIEW booksTableIO
-         |USING org.apache.spark.sql.execution.datasources.xml
+         |USING xml
          |OPTIONS (path "${getTestResourcePath(resDir + "books.xml")}", rowTag "book")
       """.stripMargin.replaceAll("\n", " "))
     spark.sql(s"""
          |CREATE TEMPORARY VIEW booksTableEmpty
          |(author string, description string, genre string,
          |id string, price double, publish_date string, title string)
-         |USING org.apache.spark.sql.execution.datasources.xml
-         |OPTIONS (path "$tempPath")
+         |USING xml
+         |OPTIONS (rowTag "ROW", path "$tempPath")
       """.stripMargin.replaceAll("\n", " "))
 
     assert(spark.sql("SELECT * FROM booksTableIO").collect().length === 12)
@@ -329,16 +342,18 @@ class XmlSuite extends QueryTest with SharedSparkSession {
   test("DSL save with gzip compression codec") {
     val copyFilePath = getEmptyTempDir().resolve("cars-copy.xml")
 
-    val cars = spark.read.xml(getTestResourcePath(resDir + "cars.xml"))
+    val cars = spark.read
+      .option("rowTag", "ROW")
+      .xml(getTestResourcePath(resDir + "cars.xml"))
     cars.write
       .mode(SaveMode.Overwrite)
-      .options(Map("compression" -> classOf[GzipCodec].getName))
+      .options(Map("rowTag" -> "ROW", "compression" -> classOf[GzipCodec].getName))
       .xml(copyFilePath.toString)
     // Check that the part file has a .gz extension
     assert(Files.list(copyFilePath).iterator().asScala
       .count(_.getFileName.toString().endsWith(".xml.gz")) === 1)
 
-    val carsCopy = spark.read.xml(copyFilePath.toString)
+    val carsCopy = spark.read.option("rowTag", "ROW").xml(copyFilePath.toString)
 
     assert(carsCopy.count() === cars.count())
     assert(carsCopy.collect().map(_.toString).toSet === cars.collect().map(_.toString).toSet)
@@ -347,17 +362,19 @@ class XmlSuite extends QueryTest with SharedSparkSession {
   test("DSL save with gzip compression codec by shorten name") {
     val copyFilePath = getEmptyTempDir().resolve("cars-copy.xml")
 
-    val cars = spark.read.xml(getTestResourcePath(resDir + "cars.xml"))
+    val cars = spark.read
+      .option("rowTag", "ROW")
+      .xml(getTestResourcePath(resDir + "cars.xml"))
     cars.write
       .mode(SaveMode.Overwrite)
-      .options(Map("compression" -> "gZiP"))
+      .options(Map("rowTag" -> "ROW", "compression" -> "gZiP"))
       .xml(copyFilePath.toString)
     // Check that the part file has a .gz extension
     assert(Files.list(copyFilePath).iterator().asScala
       .count(_.getFileName.toString().endsWith(".xml.gz")) === 1)
 
-    val carsCopy = spark.read.xml(copyFilePath.toString)
+    val carsCopy = spark.read.option("rowTag", "ROW").xml(copyFilePath.toString)
 
     assert(carsCopy.count() === cars.count())
     assert(carsCopy.collect().map(_.toString).toSet === cars.collect().map(_.toString).toSet)
@@ -413,7 +430,9 @@ class XmlSuite extends QueryTest with SharedSparkSession {
   test("DSL save with item") {
     val tempPath = getEmptyTempDir().resolve("items-temp.xml")
     val items = spark.createDataFrame(Seq(Tuple1(Array(Array(3, 4))))).toDF("thing").repartition(1)
-    items.write.option("arrayElementName", "foo").xml(tempPath.toString)
+    items.write
+      .option("rowTag", "ROW")
+      .option("arrayElementName", "foo").xml(tempPath.toString)
 
     val xmlFile = Files.list(tempPath).iterator.asScala
@@ -474,7 +493,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     val data = spark.sparkContext.parallelize(
       List(List(List("aa", "bb"), List("aa", "bb"))).map(Row(_)))
     val df = spark.createDataFrame(data, schema)
-    df.write.xml(copyFilePath.toString)
+    df.write.option("rowTag", "ROW").xml(copyFilePath.toString)
 
     // When [[ArrayType]] has [[ArrayType]] as elements, it is confusing what is the element
     // name for XML file. Now, it is "item" by default. So, "item" field is additionally added
@@ -482,7 +501,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     val schemaCopy = buildSchema(
       structArray("a",
         field(XmlOptions.DEFAULT_ARRAY_ELEMENT_NAME, ArrayType(StringType))))
-    val dfCopy = spark.read.xml(copyFilePath.toString)
+    val dfCopy = spark.read.option("rowTag", "ROW").xml(copyFilePath.toString)
 
     assert(dfCopy.count() === df.count())
     assert(dfCopy.schema === schemaCopy)
@@ -518,9 +537,9 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     val data = spark.sparkContext.parallelize(Seq(row))
     val df = spark.createDataFrame(data, schema)
 
-    df.write.xml(copyFilePath.toString)
+    df.write.option("rowTag", "ROW").xml(copyFilePath.toString)
 
-    val dfCopy = spark.read.schema(schema)
+    val dfCopy = spark.read.option("rowTag", "ROW").schema(schema)
       .xml(copyFilePath.toString)
 
     assert(dfCopy.collect() === df.collect())
@@ -685,7 +704,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       field("comment"),
       field("color"),
       field("year", IntegerType))
-    val results = spark.read.schema(schema)
+    val results = spark.read.option("rowTag", "ROW").schema(schema)
       .xml(getTestResourcePath(resDir + "cars-unbalanced-elements.xml"))
       .count()
@@ -693,7 +712,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
   }
 
   test("DSL test inferred schema passed through") {
-    val dataFrame = spark.read.xml(getTestResourcePath(resDir + "cars.xml"))
+    val dataFrame = spark.read.option("rowTag", "ROW").xml(getTestResourcePath(resDir + "cars.xml"))
 
     val results = dataFrame
       .select("comment", "year")
@@ -706,7 +725,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     val schema = buildSchema(
       field("name", StringType, false),
       field("age"))
-    val results = spark.read.schema(schema)
+    val results = spark.read.option("rowTag", "ROW").schema(schema)
       .xml(getTestResourcePath(resDir + "null-numbers.xml"))
       .select("name", "age")
       .collect()
@@ -721,6 +740,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       field("name", StringType, false),
       field("age", IntegerType))
     val results = spark.read.schema(schema)
+      .option("rowTag", "ROW")
       .option("treatEmptyValuesAsNulls", true)
       .xml(getTestResourcePath(resDir + "null-numbers.xml"))
       .select("name", "age")
@@ -808,6 +828,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
         field("a", IntegerType)))
     val result = spark.read.schema(schema)
+      .option("rowTag", "ROW")
       .xml(getTestResourcePath(resDir + "simple-nested-objects.xml"))
       .select("c.a", "c.b")
       .collect()
@@ -858,7 +879,9 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("Skip and project currently XML files without indentation") {
-    val df = spark.read.xml(getTestResourcePath(resDir + "cars-no-indentation.xml"))
+    val df = spark.read
+      .option("rowTag", "ROW")
+      .xml(getTestResourcePath(resDir + "cars-no-indentation.xml"))
     val results = df.select("model").collect()
     val years = results.map(_(0)).toSet
     assert(years === Set("S", "E350", "Volt"))
@@ -880,10 +903,11 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     val messageOne = intercept[IllegalArgumentException] {
       spark.read.option("rowTag", "").xml(getTestResourcePath(resDir + "cars.xml"))
     }.getMessage
-    assert(messageOne === "requirement failed: 'rowTag' option should not be empty string.")
+    assert(messageOne === "requirement failed: 'rowTag' option should not be an empty string.")
 
     val messageThree = intercept[IllegalArgumentException] {
-      spark.read.option("valueTag", "").xml(getTestResourcePath(resDir + "cars.xml"))
+      spark.read.option("rowTag", "ROW")
+        .option("valueTag", "").xml(getTestResourcePath(resDir + "cars.xml"))
     }.getMessage
     assert(messageThree === "requirement failed: 'valueTag' option should not be empty string.")
   }
@@ -895,18 +919,21 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     assert(messageOne === "requirement failed: 'rowTag' should not include angle brackets")
 
     val messageTwo = intercept[IllegalArgumentException] {
-      spark.read.option("rowTag", "<ROW").xml(getTestResourcePath(resDir + "cars.xml"))
+      spark.read.option("rowTag", "ROW")
+        .option("rowTag", "<ROW").xml(getTestResourcePath(resDir + "cars.xml"))
     }.getMessage
     assert(
      messageTwo === "requirement failed: 'rowTag' should not include angle brackets")
 
     val messageThree = intercept[IllegalArgumentException] {
-      spark.read.option("rootTag", "ROWSET>").xml(getTestResourcePath(resDir + "cars.xml"))
+      spark.read.option("rowTag", "ROW")
+        .option("rootTag", "ROWSET>").xml(getTestResourcePath(resDir + "cars.xml"))
     }.getMessage
     assert(messageThree === "requirement failed: 'rootTag' should not include angle brackets")
 
     val messageFour = intercept[IllegalArgumentException] {
-      spark.read.option("rootTag", "<ROWSET").xml(getTestResourcePath(resDir + "cars.xml"))
+      spark.read.option("rowTag", "ROW")
+        .option("rootTag", "<ROWSET").xml(getTestResourcePath(resDir + "cars.xml"))
     }.getMessage
     assert(messageFour === "requirement failed: 'rootTag' should not include angle brackets")
   }
@@ -914,6 +941,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
   test("valueTag and attributePrefix should not be the same.") {
     val messageOne = intercept[IllegalArgumentException] {
       spark.read
+        .option("rowTag", "ROW")
         .option("valueTag", "#abc")
         .option("attributePrefix", "#abc")
         .xml(getTestResourcePath(resDir + "cars.xml"))
@@ -924,6 +952,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("nullValue and treatEmptyValuesAsNulls test") {
     val resultsOne = spark.read
+      .option("rowTag", "ROW")
       .option("treatEmptyValuesAsNulls", "true")
       .xml(getTestResourcePath(resDir + "gps-empty-field.xml"))
     assert(resultsOne.selectExpr("extensions.TrackPointExtension").head().getStruct(0) !== null)
@@ -934,6 +963,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     assert(resultsOne.collect().length === 2)
 
     val resultsTwo = spark.read
+      .option("rowTag", "ROW")
       .option("nullValue", "2013-01-24T06:18:43Z")
       .xml(getTestResourcePath(resDir + "gps-empty-field.xml"))
     assert(resultsTwo.selectExpr("time").head().getStruct(0) === null)
@@ -1015,7 +1045,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       field("non-empty-tag", IntegerType),
       field("self-closing-tag", IntegerType))
 
-    val result = spark.read.schema(schema)
+    val result = spark.read.option("rowTag", "ROW").schema(schema)
       .xml(getTestResourcePath(resDir + "self-closing-tag.xml"))
       .collect()
@@ -1054,6 +1084,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       field("integer_map", MapType(StringType, IntegerType)),
       field("_malformed_records", StringType))
     val results = spark.read
+      .option("rowTag", "ROW")
       .option("mode", "PERMISSIVE")
       .option("columnNameOfCorruptRecord", "_malformed_records")
       .schema(schema)
@@ -1178,7 +1209,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       "<ROW><year>2015</year><make>Chevy</make><model>Volt</model><comment>No</comment></ROW>")
     val xmlRDD = spark.sparkContext.parallelize(data)
     val ds = spark.createDataset(xmlRDD)(Encoders.STRING)
-    assert(spark.read.xml(ds).collect().length === 3)
+    assert(spark.read.option("rowTag", "ROW").xml(ds).collect().length === 3)
   }
 
   import testImplicits._
@@ -1308,10 +1339,11 @@ class XmlSuite extends QueryTest with SharedSparkSession {
   test("rootTag with simple attributes") {
     val xmlPath = getEmptyTempDir().resolve("simple_attributes")
     val df = spark.createDataFrame(Seq((42, "foo"))).toDF("number", "value").repartition(1)
-    df.write.
-      option("rootTag", "root foo='bar' bing=\"baz\"").
-      option("declaration", "").
-      xml(xmlPath.toString)
+    df.write
+      .option("rowTag", "ROW")
+      .option("rootTag", "root foo='bar' bing=\"baz\"")
+      .option("declaration", "")
+      .xml(xmlPath.toString)
 
     val xmlFile =
       Files.list(xmlPath).iterator.asScala.filter(_.getFileName.toString.startsWith("part-")).next()
@@ -1651,10 +1683,12 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     val results = Seq(
       // user specified schema
       spark.read
+        .option("rowTag", "ROW")
         .schema(schema)
         .xml(getTestResourcePath(resDir + "root-level-value.xml")).collect(),
       // schema inference
       spark.read
+        .option("rowTag", "ROW")
         .xml(getTestResourcePath(resDir + "root-level-value.xml")).collect())
     results.foreach { result =>
       assert(result.length === 3)
@@ -1677,10 +1711,12 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     val dfs = Seq(
       // user specified schema
       spark.read
+        .option("rowTag", "ROW")
         .schema(schema)
         .xml(getTestResourcePath(resDir + "root-level-value-none.xml")),
       // schema inference
       spark.read
+        .option("rowTag", "ROW")
         .xml(getTestResourcePath(resDir + "root-level-value-none.xml"))
     )
     dfs.foreach { df =>
@@ -1720,4 +1756,27 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     assert(result.select("decoded._VALUE").head().getLong(0) === 123456L)
     assert(result.select("decoded._attr").head().getString(0) === "attr1")
   }
+
+  test("Test XML Options Error Messages") {
+    def checkXmlOptionErrorMessage(
+      parameters: Map[String, String] = Map.empty,
+      msg: String): Unit = {
+      val e = intercept[IllegalArgumentException] {
+        spark.read
+          .options(parameters)
+          .xml(getTestResourcePath(resDir + "ages.xml"))
+          .collect()
+      }
+      assert(e.getMessage.contains(msg))
+    }
+
+    checkXmlOptionErrorMessage(Map.empty, "'rowTag' option is required.")
+    checkXmlOptionErrorMessage(Map("rowTag" -> ""),
+      "'rowTag' option should not be an empty string.")
+    checkXmlOptionErrorMessage(Map("rowTag" -> " "),
+      "'rowTag' option should not be an empty string.")
+    checkXmlOptionErrorMessage(Map("rowTag" -> "person",
+      "declaration" -> s"<${XmlOptions.DEFAULT_DECLARATION}>"),
+      "'declaration' should not include angle brackets")
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlGeneratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlGeneratorSuite.scala
index 176cfd98563..1798d32d8a2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlGeneratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlGeneratorSuite.scala
@@ -69,9 +69,9 @@ final class StaxXmlGeneratorSuite extends SharedSparkSession {
     val df = dataset.toDF().orderBy("booleanDatum")
     val targetFile =
       Files.createTempDirectory("StaxXmlGeneratorSuite").resolve("roundtrip.xml").toString
-    df.write.format("xml").save(targetFile)
+    df.write.option("rowTag", "ROW").xml(targetFile)
     val newDf =
-      spark.read.schema(df.schema).format("xml").load(targetFile).orderBy("booleanDatum")
+      spark.read.option("rowTag", "ROW").schema(df.schema).xml(targetFile).orderBy("booleanDatum")
     assert(df.collect().toSeq === newDf.collect().toSeq)
   }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
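
Usage sketch for the change described in the commit message above. This is illustrative only: the input path "data/books.xml", its <book> row element, the output path, and the application name are assumptions, not part of the commit. What the commit guarantees is that reads and writes through the XML data source now fail fast with an IllegalArgumentException ("requirement failed: 'rowTag' option is required.") when 'rowTag' is omitted, so callers must pass it explicitly.

    import org.apache.spark.sql.SparkSession

    object RowTagRequiredSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")
          .appName("rowTag-required-sketch")
          .getOrCreate()

        // After this change, omitting 'rowTag' fails fast:
        //   spark.read.xml("data/books.xml")
        //   => IllegalArgumentException: requirement failed: 'rowTag' option is required.

        // Reads must name the XML element that maps to one DataFrame row.
        val books = spark.read
          .option("rowTag", "book")
          .xml("data/books.xml")

        // Writes now require 'rowTag' as well.
        books.write
          .mode("overwrite")
          .option("rowTag", "book")
          .xml("/tmp/books-roundtrip")

        spark.stop()
      }
    }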