This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 018808236708 [SPARK-46599][SQL] XML: Use TypeCoercion.findTightestCommonType for compatibility check
018808236708 is described below

commit 018808236708bea7a78618abf750bea39be3c9f8
Author: Sandip Agarwala <131817656+sandip...@users.noreply.github.com>
AuthorDate: Sun Jan 7 09:58:13 2024 +0900

    [SPARK-46599][SQL] XML: Use TypeCoercion.findTightestCommonType for compatibility check

    ### What changes were proposed in this pull request?
    Make the following changes to XML schema inference:
    - Use TypeCoercion.findTightestCommonType for the compatibility check.
    - Update DecimalType inference to support scale > 0.
    - Submit the type-merging step as a Spark job manually, so that TypeCoercion can access the SQLConf while the fold functions run in the scheduler event loop thread.
    - Add reduceOption so that each partition returns a single StructType instead of a list of StructTypes.

    ### Why are the changes needed?
    To make the dataType compatibility checks consistent with those of other formats.

    ### Does this PR introduce _any_ user-facing change?
    Yes.

    ### How was this patch tested?
    Existing and new unit tests.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #44601 from sandip-db/xml-typecoercion.

    Authored-by: Sandip Agarwala <131817656+sandip...@users.noreply.github.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../spark/sql/catalyst/xml/XmlInferSchema.scala    | 245 ++++++++++-----------
 .../sql/execution/datasources/xml/XmlSuite.scala   |  91 +++++++-
 2 files changed, 203 insertions(+), 133 deletions(-)
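The patch replaces XML's hand-rolled type widening with catalyst's shared primitive. As a minimal sketch of how that primitive behaves (assuming only that Spark's catalyst module is on the classpath; the commented results follow its numeric-precedence rules):

    import org.apache.spark.sql.catalyst.analysis.TypeCoercion
    import org.apache.spark.sql.types._

    // Numeric types widen to the wider of the two.
    TypeCoercion.findTightestCommonType(IntegerType, LongType)   // Some(LongType)
    // Identical types come back unchanged.
    TypeCoercion.findTightestCommonType(DoubleType, DoubleType)  // Some(DoubleType)
    // No tightest common type: the XML-specific fallback in
    // XmlInferSchema.compatibleType (added below) handles such pairs.
    TypeCoercion.findTightestCommonType(StringType, LongType)    // None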
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
index 9d0c16d95e46..59222f56454f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
@@ -31,9 +31,11 @@ import scala.util.control.NonFatal
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion
 import org.apache.spark.sql.catalyst.expressions.ExprUtils
 import org.apache.spark.sql.catalyst.util.{DateFormatter, PermissiveMode, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
+import org.apache.spark.sql.catalyst.xml.XmlInferSchema.compatibleType
 import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 import org.apache.spark.sql.types._
 
@@ -63,32 +65,6 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
       legacyFormat = FAST_DATE_FORMAT,
       isParsing = true)
 
-  /**
-   * Copied from internal Spark api
-   * [[org.apache.spark.sql.catalyst.analysis.TypeCoercion]]
-   */
-  private val numericPrecedence: IndexedSeq[DataType] =
-    IndexedSeq[DataType](
-      ByteType,
-      ShortType,
-      IntegerType,
-      LongType,
-      FloatType,
-      DoubleType,
-      TimestampType,
-      DecimalType.SYSTEM_DEFAULT)
-
-  private val findTightestCommonTypeOfTwo: (DataType, DataType) => Option[DataType] = {
-    case (t1, t2) if t1 == t2 => Some(t1)
-
-    // Promote numeric types to the highest of the two
-    case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
-      val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
-      Some(numericPrecedence(index))
-
-    case _ => None
-  }
-
   /**
    * Infer the type of a collection of XML records in three stages:
    *   1. Infer the type of each record
@@ -102,13 +78,26 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
       xml
     }
     // perform schema inference on each row and merge afterwards
-    val rootType = schemaData.mapPartitions { iter =>
+    val mergedTypesFromPartitions = schemaData.mapPartitions { iter =>
       val xsdSchema = Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema)
       iter.flatMap { xml =>
         infer(xml, xsdSchema)
+      }.reduceOption(compatibleType(caseSensitive, options.valueTag)).iterator
+    }
+
+    // Here we manually submit a fold-like Spark job, so that we can set the SQLConf when running
+    // the fold functions in the scheduler event loop thread.
+    val existingConf = SQLConf.get
+    var rootType: DataType = StructType(Nil)
+    val foldPartition = (iter: Iterator[DataType]) =>
+      iter.fold(StructType(Nil))(compatibleType(caseSensitive, options.valueTag))
+    val mergeResult = (index: Int, taskResult: DataType) => {
+      rootType = SQLConf.withExistingConf(existingConf) {
+        compatibleType(caseSensitive, options.valueTag)(rootType, taskResult)
       }
-    }.fold(StructType(Seq()))(compatibleType)
+    }
+    xml.sparkContext.runJob(mergedTypesFromPartitions, foldPartition, mergeResult)
 
     canonicalizeType(rootType) match {
       case Some(st: StructType) => st
@@ -339,21 +328,10 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
       return None
     }
 
-    try {
-      // The conversion can fail when the `field` is not a form of number.
-      val bigDecimal = decimalParser(signSafeValue)
-      // Because many other formats do not support decimal, it reduces the cases for
-      // decimals by disallowing values having scale (e.g. `1.1`).
-      if (bigDecimal.scale <= 0) {
-        // `DecimalType` conversion can fail when
-        // 1. The precision is bigger than 38.
-        // 2. scale is bigger than precision.
-        return Some(DecimalType(bigDecimal.precision, bigDecimal.scale))
-      }
-    } catch {
-      case _ : Exception =>
+    allCatch opt {
+      val bigDecimal = decimalParser(value)
+      DecimalType(Math.max(bigDecimal.precision, bigDecimal.scale), bigDecimal.scale)
     }
-    None
   }
 
   private def isDouble(value: String): Boolean = {
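The new decimal rule above is easiest to see with concrete values. A standalone sketch under the same semantics (hypothetical helper name, not the patch's code; it relies on java.math.BigDecimal plus Spark's DecimalType, whose constructor rejects precision above 38, which is what pushes oversized literals back to double):

    import scala.util.Try
    import org.apache.spark.sql.types.DecimalType

    // Pad precision up to the scale, so fractions like ".01" (precision 1,
    // scale 2 in java.math.BigDecimal terms) become a valid DecimalType(2, 2).
    def inferDecimal(value: String): Option[DecimalType] = Try {
      val bd = new java.math.BigDecimal(value)
      DecimalType(math.max(bd.precision, bd.scale), bd.scale)
    }.toOption

    inferDecimal("92233720368547758070")  // Some(DecimalType(20,0)): 20 digits, scale 0
    inferDecimal(".01")                   // Some(DecimalType(2,2)): scale > 0 now allowed
    inferDecimal("1" + "0" * 38)          // None: precision 39 > 38, falls back to double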
@@ -451,88 +429,6 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
     case other => Some(other)
   }
 
-  /**
-   * Returns the most general data type for two given data types.
-   */
-  private[xml] def compatibleType(t1: DataType, t2: DataType): DataType = {
-
-    def normalize(name: String): String = {
-      if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
-    }
-
-    // TODO: Optimise this logic.
-    findTightestCommonTypeOfTwo(t1, t2).getOrElse {
-      // t1 or t2 is a StructType, ArrayType, or an unexpected type.
-      (t1, t2) match {
-        // Double support larger range than fixed decimal, DecimalType.Maximum should be enough
-        // in most case, also have better precision.
-        case (DoubleType, _: DecimalType) =>
-          DoubleType
-        case (_: DecimalType, DoubleType) =>
-          DoubleType
-        case (t1: DecimalType, t2: DecimalType) =>
-          val scale = math.max(t1.scale, t2.scale)
-          val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
-          if (range + scale > 38) {
-            // DecimalType can't support precision > 38
-            DoubleType
-          } else {
-            DecimalType(range + scale, scale)
-          }
-        case (TimestampNTZType, TimestampType) | (TimestampType, TimestampNTZType) =>
-          TimestampType
-
-        case (StructType(fields1), StructType(fields2)) =>
-          val newFields = (fields1 ++ fields2)
-            // normalize field name and pair it with original field
-            .map(field => (normalize(field.name), field))
-            .groupBy(_._1) // group by normalized field name
-            .map { case (_: String, fields: Array[(String, StructField)]) =>
-              val fieldTypes = fields.map(_._2)
-              val dataType = fieldTypes.map(_.dataType).reduce(compatibleType)
-              // we pick up the first field name that we've encountered for the field
-              StructField(fields.head._2.name, dataType)
-            }
-          StructType(newFields.toArray.sortBy(_.name))
-
-        case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
-          ArrayType(
-            compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
-
-        // In XML datasource, since StructType can be compared with ArrayType.
-        // In this case, ArrayType wraps the StructType.
-        case (ArrayType(ty1, _), ty2) =>
-          ArrayType(compatibleType(ty1, ty2))
-
-        case (ty1, ArrayType(ty2, _)) =>
-          ArrayType(compatibleType(ty1, ty2))
-
-        // As this library can infer an element with attributes as StructType whereas
-        // some can be inferred as other non-structural data types, this case should be
-        // treated.
-        case (st: StructType, dt: DataType) if st.fieldNames.contains(options.valueTag) =>
-          val valueIndex = st.fieldNames.indexOf(options.valueTag)
-          val valueField = st.fields(valueIndex)
-          val valueDataType = compatibleType(valueField.dataType, dt)
-          st.fields(valueIndex) = StructField(options.valueTag, valueDataType, nullable = true)
-          st
-
-        case (dt: DataType, st: StructType) if st.fieldNames.contains(options.valueTag) =>
-          val valueIndex = st.fieldNames.indexOf(options.valueTag)
-          val valueField = st.fields(valueIndex)
-          val valueDataType = compatibleType(dt, valueField.dataType)
-          st.fields(valueIndex) = StructField(options.valueTag, valueDataType, nullable = true)
-          st
-
-        // TODO: These null type checks should be in `findTightestCommonTypeOfTwo`.
-        case (_, NullType) => t1
-        case (NullType, _) => t2
-        // strings and every string is a XML object.
-        case (_, _) => StringType
-      }
-    }
-  }
-
   /**
    * This helper function merges the data type of value tags and inner elements.
    * It could only be structure data. Consider the following case,
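The struct branch of the method removed above (reintroduced in the companion object later in this diff) is the subtle part: fields are grouped by case-normalized name, their types merged recursively, and the first spelling encountered wins. A self-contained sketch of just that step (hypothetical helper, not the patch's code; the merge function stands in for compatibleType):

    import java.util.Locale
    import org.apache.spark.sql.types._

    def mergeStructs(s1: StructType, s2: StructType, caseSensitive: Boolean)(
        merge: (DataType, DataType) => DataType): StructType = {
      def normalize(name: String): String =
        if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
      val fields = (s1.fields ++ s2.fields)
        .groupBy(f => normalize(f.name))   // group by case-normalized name
        .map { case (_, fs) =>
          // keep the first-seen spelling, widen the types pairwise
          StructField(fs.head.name, fs.map(_.dataType).reduce(merge))
        }
      StructType(fields.toArray.sortBy(_.name))
    }

    // With caseSensitive = false, "ID" and "id" collapse into one field named "ID":
    mergeStructs(
      StructType(Seq(StructField("ID", LongType))),
      StructType(Seq(StructField("id", DoubleType))),
      caseSensitive = false)((_, _) => DoubleType)   // => struct<ID:double>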
@@ -560,9 +456,11 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
         updateStructField(
           st,
           index,
-          ArrayType(compatibleType(st(index).dataType, valueTagType)))
+          ArrayType(compatibleType(caseSensitive, options.valueTag)(
+            st(index).dataType, valueTagType)))
       case Some(index) =>
-        updateStructField(st, index, compatibleType(st(index).dataType, valueTagType))
+        updateStructField(st, index, compatibleType(caseSensitive, options.valueTag)(
+          st(index).dataType, valueTagType))
       case None => st.add(options.valueTag, valueTagType)
     }
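The value-tag branches shown next exist because the same element can appear as a bare value in one record and as a struct (attributes plus a value tag) in another. A rough standalone illustration of the merge they perform, assuming the default value tag "_VALUE" (hypothetical helper; the real logic lives in compatibleType below, which is private[xml]):

    import org.apache.spark.sql.types._

    // When a struct (element with attributes) meets a non-struct type, only the
    // value-tag field is widened; the attribute fields are left untouched.
    def mergeValueTag(st: StructType, dt: DataType, valueTag: String)(
        widen: (DataType, DataType) => DataType): StructType = {
      val i = st.fieldNames.indexOf(valueTag)
      StructType(st.fields.updated(i, StructField(valueTag, widen(st(i).dataType, dt))))
    }

    // <a foo="x">1</a> merged with <a>hello</a>: _VALUE widens from long to string.
    mergeValueTag(
      StructType(Seq(StructField("_foo", StringType), StructField("_VALUE", LongType))),
      StringType,
      valueTag = "_VALUE")((_, _) => StringType)   // => struct<_foo:string,_VALUE:string>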
@@ -596,11 +494,102 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
       // If the field name already exists,
       // merge the type and infer the combined field as an array type if necessary
       case Some(oldType) if !oldType.isInstanceOf[ArrayType] && !newType.isInstanceOf[NullType] =>
-        ArrayType(compatibleType(oldType, newType))
+        ArrayType(compatibleType(caseSensitive, options.valueTag)(oldType, newType))
       case Some(oldType) =>
-        compatibleType(oldType, newType)
+        compatibleType(caseSensitive, options.valueTag)(oldType, newType)
       case None =>
         newType
     }
   }
 }
+
+object XmlInferSchema {
+  def normalize(name: String, caseSensitive: Boolean): String = {
+    if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
+  }
+
+  /**
+   * Returns the most general data type for two given data types.
+   */
+  private[xml] def compatibleType(caseSensitive: Boolean, valueTag: String)
+    (t1: DataType, t2: DataType): DataType = {
+
+    // TODO: Optimise this logic.
+    TypeCoercion.findTightestCommonType(t1, t2).getOrElse {
+      // t1 or t2 is a StructType, ArrayType, or an unexpected type.
+      (t1, t2) match {
+        // Double supports a larger range than fixed decimal; DecimalType.Maximum should be
+        // enough in most cases, and it also has better precision.
+        case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
+          DoubleType
+
+        case (t1: DecimalType, t2: DecimalType) =>
+          val scale = math.max(t1.scale, t2.scale)
+          val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
+          if (range + scale > 38) {
+            // DecimalType can't support precision > 38
+            DoubleType
+          } else {
+            DecimalType(range + scale, scale)
+          }
+        case (TimestampNTZType, TimestampType) | (TimestampType, TimestampNTZType) =>
+          TimestampType
+
+        case (StructType(fields1), StructType(fields2)) =>
+          val newFields = (fields1 ++ fields2)
+            // normalize field name and pair it with original field
+            .map(field => (normalize(field.name, caseSensitive), field))
+            .groupBy(_._1) // group by normalized field name
+            .map { case (_: String, fields: Array[(String, StructField)]) =>
+              val fieldTypes = fields.map(_._2)
+              val dataType = fieldTypes.map(_.dataType)
+                .reduce(compatibleType(caseSensitive, valueTag))
+              // we pick up the first field name that we've encountered for the field
+              StructField(fields.head._2.name, dataType)
+            }
+          StructType(newFields.toArray.sortBy(_.name))
+
+        case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
+          ArrayType(
+            compatibleType(caseSensitive, valueTag)(
+              elementType1, elementType2), containsNull1 || containsNull2)
+
+        // In the XML datasource, a StructType can be compared with an ArrayType;
+        // in that case, the ArrayType wraps the StructType.
+        case (ArrayType(ty1, _), ty2) =>
+          ArrayType(compatibleType(caseSensitive, valueTag)(ty1, ty2))
+
+        case (ty1, ArrayType(ty2, _)) =>
+          ArrayType(compatibleType(caseSensitive, valueTag)(ty1, ty2))
+
+        // This library can infer an element with attributes as a StructType, while the
+        // same element elsewhere may be inferred as a non-structural data type, so that
+        // case must be handled here.
+        case (st: StructType, dt: DataType) if st.fieldNames.contains(valueTag) =>
+          val valueIndex = st.fieldNames.indexOf(valueTag)
+          val valueField = st.fields(valueIndex)
+          val valueDataType = compatibleType(caseSensitive, valueTag)(valueField.dataType, dt)
+          st.fields(valueIndex) = StructField(valueTag, valueDataType, nullable = true)
+          st
+
+        case (dt: DataType, st: StructType) if st.fieldNames.contains(valueTag) =>
+          val valueIndex = st.fieldNames.indexOf(valueTag)
+          val valueField = st.fields(valueIndex)
+          val valueDataType = compatibleType(caseSensitive, valueTag)(dt, valueField.dataType)
+          st.fields(valueIndex) = StructField(valueTag, valueDataType, nullable = true)
+          st
+
+        // The case where the given `DecimalType` can hold the given `IntegralType` is
+        // handled in `findTightestCommonType`. The two cases below run only when the
+        // `DecimalType` cannot hold the `IntegralType`.
+        case (t1: IntegralType, t2: DecimalType) =>
+          compatibleType(caseSensitive, valueTag)(DecimalType.forType(t1), t2)
+        case (t1: DecimalType, t2: IntegralType) =>
+          compatibleType(caseSensitive, valueTag)(t1, DecimalType.forType(t2))
+
+        // Everything else falls back to StringType: any XML value reads as a string.
+        case (_, _) => StringType
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
index 4b9a95856afb..78f9d5285c23 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
@@ -2127,7 +2127,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
         <string>this is a simple string.</string>
         <integer>10</integer>
         <long>21474836470</long>
-        <decimal>92233720368547758070</decimal>
+        <bigInteger>92233720368547758070</bigInteger>
         <double>1.7976931348623157</double>
         <boolean>true</boolean>
         <null>null</null>
@@ -2137,7 +2137,8 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     val dfWithNodecimal = spark.read
       .option("nullValue", "null")
       .xml(primitiveFieldAndType)
-    assert(dfWithNodecimal.schema("decimal").dataType === DoubleType)
+    assert(dfWithNodecimal.schema("bigInteger").dataType === DoubleType)
+    assert(dfWithNodecimal.schema("double").dataType === DoubleType)
 
     val df = spark.read
       .option("nullValue", "null")
@@ -2145,9 +2146,9 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       .xml(primitiveFieldAndType)
 
     val expectedSchema = StructType(
+      StructField("bigInteger", DecimalType(20, 0), true) ::
       StructField("boolean", BooleanType, true) ::
-      StructField("decimal", DecimalType(20, 0), true) ::
-      StructField("double", DoubleType, true) ::
+      StructField("double", DecimalType(17, 16), true) ::
       StructField("integer", LongType, true) ::
       StructField("long", LongType, true) ::
       StructField("null", StringType, true) ::
@@ -2157,8 +2158,9 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
     checkAnswer(
       df,
-      Row(true,
+      Row(
+        new java.math.BigDecimal("92233720368547758070"),
+        true,
         1.7976931348623157,
         10,
         21474836470L,
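Translated into user-visible behavior, the updated assertions above read roughly as follows; a hedged sketch, assuming a live SparkSession named spark and a hypothetical sample file path (the suite builds its input inline instead):

    // With prefersDecimal=true, literals that fit now infer as decimals:
    //   "92233720368547758070"  -> DecimalType(20, 0)
    //   "1.7976931348623157"    -> DecimalType(17, 16)  (17 digits, 16 after the point)
    // while oversized numbers still fall back to double.
    val df = spark.read
      .option("nullValue", "null")
      .option("prefersDecimal", "true")
      .xml("/tmp/primitive-field-and-type.xml")  // hypothetical path
    df.printSchema()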
@@ -2604,4 +2606,83 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
     checkAnswer(df, expectedAns)
   }
+
+  test("Find compatible types even if inferred DecimalType is not capable of other IntegralType") {
+    val mixedIntegerAndDoubleRecords = Seq(
+      """<ROW><a>3</a><b>1.1</b></ROW>""",
+      s"""<ROW><a>3.1</a><b>0.${"0" * 38}1</b></ROW>""").toDS()
+    val xmlDF = spark.read
+      .option("prefersDecimal", "true")
+      .option("rowTag", "ROW")
+      .xml(mixedIntegerAndDoubleRecords)
+
+    // The values in the `a` field will be decimals, as they fit. The values in the
+    // `b` field will be doubles, as `1.0E-39D` does not fit in a decimal.
+    val expectedSchema = StructType(
+      StructField("a", DecimalType(21, 1), true) ::
+      StructField("b", DoubleType, true) :: Nil)
+
+    assert(xmlDF.schema === expectedSchema)
+    checkAnswer(
+      xmlDF,
+      Row(BigDecimal("3"), 1.1D) ::
+      Row(BigDecimal("3.1"), 1.0E-39D) :: Nil
+    )
+  }
+
+  def bigIntegerRecords: Dataset[String] =
+    spark.createDataset(spark.sparkContext.parallelize(
+      s"""<ROW><a>1${"0" * 38}</a><b>92233720368547758070</b></ROW>""" :: Nil))(Encoders.STRING)
+
+  test("Infer big integers correctly even when they do not fit in decimal") {
+    val df = spark.read
+      .option("rowTag", "ROW")
+      .option("prefersDecimal", "true")
+      .xml(bigIntegerRecords)
+
+    // The value in the `a` field will be a double, as it does not fit in a decimal.
+    // The value in the `b` field, `92233720368547758070`, fits in DecimalType(20, 0).
+    val expectedSchema = StructType(
+      StructField("a", DoubleType, true) ::
+      StructField("b", DecimalType(20, 0), true) :: Nil)
+
+    assert(df.schema === expectedSchema)
+    checkAnswer(df, Row(1.0E38D, BigDecimal("92233720368547758070")))
+  }
+
+  def floatingValueRecords: Dataset[String] =
+    spark.createDataset(spark.sparkContext.parallelize(
+      s"""<ROW><a>0.${"0" * 38}1</a><b>.01</b></ROW>""" :: Nil))(Encoders.STRING)
+
+  test("Infer floating-point values correctly even when they do not fit in decimal") {
+    val df = spark.read
+      .option("prefersDecimal", "true")
+      .option("rowTag", "ROW")
+      .xml(floatingValueRecords)
+
+    // The value in the `a` field will be a double, as it does not fit in a decimal.
+    // The value in the `b` field, `.01`, becomes a decimal with precision equal to its scale.
+    val expectedSchema = StructType(
+      StructField("a", DoubleType, true) ::
+      StructField("b", DecimalType(2, 2), true) :: Nil)
+
+    assert(df.schema === expectedSchema)
+    checkAnswer(df, Row(1.0E-39D, BigDecimal("0.01")))
+
+    val mergedDF = spark.read
+      .option("prefersDecimal", "true")
+      .option("rowTag", "ROW")
+      .xml(floatingValueRecords.union(bigIntegerRecords))
+
+    val expectedMergedSchema = StructType(
+      StructField("a", DoubleType, true) ::
+      StructField("b", DecimalType(22, 2), true) :: Nil)
+
+    assert(expectedMergedSchema === mergedDF.schema)
+    checkAnswer(
+      mergedDF,
+      Row(1.0E-39D, BigDecimal("0.01")) ::
+      Row(1.0E38D, BigDecimal("92233720368547758070")) :: Nil
+    )
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org