This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 018808236708 [SPARK-46599][SQL] XML: Use TypeCoercion.findTightestCommonType for compatibility check
018808236708 is described below

commit 018808236708bea7a78618abf750bea39be3c9f8
Author: Sandip Agarwala <131817656+sandip...@users.noreply.github.com>
AuthorDate: Sun Jan 7 09:58:13 2024 +0900

    [SPARK-46599][SQL] XML: Use TypeCoercion.findTightestCommonType for compatibility check
    
    ### What changes were proposed in this pull request?
    Make the following changes to the XML schema inference:
    - Use TypeCoercion.findTightestCommonType for the compatibility check.
    - Update DecimalType inference to support scale > 0.
    - Create a Spark job for the final schema merge so that TypeCoercion can access the SQLConf.
    - Add reduceOption so that each partition returns a single StructType instead of a list of StructTypes (see the sketch below).
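
    For context, a minimal, self-contained sketch of the new merge flow (not the actual XML code path): a hypothetical `mergeType` stands in for `XmlInferSchema.compatibleType`, the per-record types are made up, per-partition results are collapsed with `reduceOption`, and a manually submitted job folds them on the driver under the captured SQLConf.

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.analysis.TypeCoercion
    import org.apache.spark.sql.internal.SQLConf
    import org.apache.spark.sql.types._

    object XmlSchemaMergeSketch {
      // Hypothetical stand-in for XmlInferSchema.compatibleType: take the tightest
      // common type, fall back to StringType, and treat an empty struct as the seed.
      def mergeType(t1: DataType, t2: DataType): DataType = (t1, t2) match {
        case (s: StructType, t) if s.fields.isEmpty => t
        case (t, s: StructType) if s.fields.isEmpty => t
        case _ => TypeCoercion.findTightestCommonType(t1, t2).getOrElse(StringType)
      }

      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").getOrCreate()
        // Stand-in for the per-record types produced by schema inference.
        val perRecordTypes: RDD[DataType] = spark.sparkContext.parallelize(
          Seq[DataType](LongType, DoubleType, LongType), numSlices = 2)

        // Each partition collapses its record types into at most one merged type.
        val perPartitionTypes = perRecordTypes.mapPartitions { iter =>
          iter.reduceOption(mergeType _).iterator
        }

        // A manually submitted job folds the per-partition results; the result
        // handler runs on the driver with the captured SQLConf set.
        val existingConf = SQLConf.get
        var rootType: DataType = StructType(Nil)
        spark.sparkContext.runJob(
          perPartitionTypes,
          (iter: Iterator[DataType]) => iter.fold(StructType(Nil): DataType)(mergeType _),
          (_: Int, taskResult: DataType) => rootType = SQLConf.withExistingConf(existingConf) {
            mergeType(rootType, taskResult)
          })

        println(rootType) // DoubleType: the widest of the inferred numeric types
        spark.stop()
      }
    }

    In the real patch the merge function is XmlInferSchema.compatibleType(caseSensitive, valueTag), which additionally handles structs, arrays, decimals, and value tags.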
    
    ### Why are the changes needed?
    To make XML's dataType compatibility checks consistent with the other data source formats.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. XML schema inference now merges types via TypeCoercion.findTightestCommonType and can infer DecimalType with scale > 0, so inferred schemas may differ from previous releases.
    
    ### How was this patch tested?
    Existing and new unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #44601 from sandip-db/xml-typecoercion.
    
    Authored-by: Sandip Agarwala <131817656+sandip...@users.noreply.github.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../spark/sql/catalyst/xml/XmlInferSchema.scala    | 245 ++++++++++-----------
 .../sql/execution/datasources/xml/XmlSuite.scala   |  91 +++++++-
 2 files changed, 203 insertions(+), 133 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
index 9d0c16d95e46..59222f56454f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
@@ -31,9 +31,11 @@ import scala.util.control.NonFatal
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion
 import org.apache.spark.sql.catalyst.expressions.ExprUtils
 import org.apache.spark.sql.catalyst.util.{DateFormatter, PermissiveMode, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
+import org.apache.spark.sql.catalyst.xml.XmlInferSchema.compatibleType
 import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 import org.apache.spark.sql.types._
 
@@ -63,32 +65,6 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
     legacyFormat = FAST_DATE_FORMAT,
     isParsing = true)
 
-  /**
-   * Copied from internal Spark api
-   * [[org.apache.spark.sql.catalyst.analysis.TypeCoercion]]
-   */
-  private val numericPrecedence: IndexedSeq[DataType] =
-    IndexedSeq[DataType](
-      ByteType,
-      ShortType,
-      IntegerType,
-      LongType,
-      FloatType,
-      DoubleType,
-      TimestampType,
-      DecimalType.SYSTEM_DEFAULT)
-
-  private val findTightestCommonTypeOfTwo: (DataType, DataType) => Option[DataType] = {
-    case (t1, t2) if t1 == t2 => Some(t1)
-
-    // Promote numeric types to the highest of the two
-    case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
-      val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
-      Some(numericPrecedence(index))
-
-    case _ => None
-  }
-
   /**
    * Infer the type of a collection of XML records in three stages:
    *   1. Infer the type of each record
@@ -102,13 +78,26 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
       xml
     }
     // perform schema inference on each row and merge afterwards
-    val rootType = schemaData.mapPartitions { iter =>
+    val mergedTypesFromPartitions = schemaData.mapPartitions { iter =>
       val xsdSchema = Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema)
 
       iter.flatMap { xml =>
         infer(xml, xsdSchema)
+      }.reduceOption(compatibleType(caseSensitive, options.valueTag)).iterator
+    }
+
+    // Here we manually submit a fold-like Spark job, so that we can set the SQLConf when running
+    // the fold functions in the scheduler event loop thread.
+    val existingConf = SQLConf.get
+    var rootType: DataType = StructType(Nil)
+    val foldPartition = (iter: Iterator[DataType]) =>
+      iter.fold(StructType(Nil))(compatibleType(caseSensitive, options.valueTag))
+    val mergeResult = (index: Int, taskResult: DataType) => {
+      rootType = SQLConf.withExistingConf(existingConf) {
+        compatibleType(caseSensitive, options.valueTag)(rootType, taskResult)
       }
-    }.fold(StructType(Seq()))(compatibleType)
+    }
+    xml.sparkContext.runJob(mergedTypesFromPartitions, foldPartition, mergeResult)
 
     canonicalizeType(rootType) match {
       case Some(st: StructType) => st
@@ -339,21 +328,10 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
       return None
     }
 
-    try {
-      // The conversion can fail when the `field` is not a form of number.
-      val bigDecimal = decimalParser(signSafeValue)
-      // Because many other formats do not support decimal, it reduces the cases for
-      // decimals by disallowing values having scale (e.g. `1.1`).
-      if (bigDecimal.scale <= 0) {
-        // `DecimalType` conversion can fail when
-        //   1. The precision is bigger than 38.
-        //   2. scale is bigger than precision.
-        return Some(DecimalType(bigDecimal.precision, bigDecimal.scale))
-      }
-    } catch {
-      case _ : Exception =>
+    allCatch opt {
+      val bigDecimal = decimalParser(value)
+      DecimalType(Math.max(bigDecimal.precision, bigDecimal.scale), bigDecimal.scale)
     }
-    None
   }
 
   private def isDouble(value: String): Boolean = {
@@ -451,88 +429,6 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
     case other => Some(other)
   }
 
-  /**
-   * Returns the most general data type for two given data types.
-   */
-  private[xml] def compatibleType(t1: DataType, t2: DataType): DataType = {
-
-    def normalize(name: String): String = {
-      if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
-    }
-
-    // TODO: Optimise this logic.
-    findTightestCommonTypeOfTwo(t1, t2).getOrElse {
-      // t1 or t2 is a StructType, ArrayType, or an unexpected type.
-      (t1, t2) match {
-        // Double support larger range than fixed decimal, DecimalType.Maximum should be enough
-        // in most case, also have better precision.
-        case (DoubleType, _: DecimalType) =>
-          DoubleType
-        case (_: DecimalType, DoubleType) =>
-          DoubleType
-        case (t1: DecimalType, t2: DecimalType) =>
-          val scale = math.max(t1.scale, t2.scale)
-          val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
-          if (range + scale > 38) {
-            // DecimalType can't support precision > 38
-            DoubleType
-          } else {
-            DecimalType(range + scale, scale)
-          }
-        case (TimestampNTZType, TimestampType) | (TimestampType, TimestampNTZType) =>
-          TimestampType
-
-        case (StructType(fields1), StructType(fields2)) =>
-          val newFields = (fields1 ++ fields2)
-            // normalize field name and pair it with original field
-            .map(field => (normalize(field.name), field))
-            .groupBy(_._1) // group by normalized field name
-            .map { case (_: String, fields: Array[(String, StructField)]) =>
-              val fieldTypes = fields.map(_._2)
-              val dataType = fieldTypes.map(_.dataType).reduce(compatibleType)
-              // we pick up the first field name that we've encountered for the field
-              StructField(fields.head._2.name, dataType)
-          }
-          StructType(newFields.toArray.sortBy(_.name))
-
-        case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
-          ArrayType(
-            compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
-
-        // In XML datasource, since StructType can be compared with ArrayType.
-        // In this case, ArrayType wraps the StructType.
-        case (ArrayType(ty1, _), ty2) =>
-          ArrayType(compatibleType(ty1, ty2))
-
-        case (ty1, ArrayType(ty2, _)) =>
-          ArrayType(compatibleType(ty1, ty2))
-
-        // As this library can infer an element with attributes as StructType whereas
-        // some can be inferred as other non-structural data types, this case should be
-        // treated.
-        case (st: StructType, dt: DataType) if st.fieldNames.contains(options.valueTag) =>
-          val valueIndex = st.fieldNames.indexOf(options.valueTag)
-          val valueField = st.fields(valueIndex)
-          val valueDataType = compatibleType(valueField.dataType, dt)
-          st.fields(valueIndex) = StructField(options.valueTag, valueDataType, nullable = true)
-          st
-
-        case (dt: DataType, st: StructType) if st.fieldNames.contains(options.valueTag) =>
-          val valueIndex = st.fieldNames.indexOf(options.valueTag)
-          val valueField = st.fields(valueIndex)
-          val valueDataType = compatibleType(dt, valueField.dataType)
-          st.fields(valueIndex) = StructField(options.valueTag, valueDataType, nullable = true)
-          st
-
-        // TODO: These null type checks should be in `findTightestCommonTypeOfTwo`.
-        case (_, NullType) => t1
-        case (NullType, _) => t2
-        // strings and every string is a XML object.
-        case (_, _) => StringType
-      }
-    }
-  }
-
   /**
   * This helper function merges the data type of value tags and inner elements.
    * It could only be structure data. Consider the following case,
@@ -560,9 +456,11 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
             updateStructField(
               st,
               index,
-              ArrayType(compatibleType(st(index).dataType, valueTagType)))
+              ArrayType(compatibleType(caseSensitive, options.valueTag)(
+                st(index).dataType, valueTagType)))
           case Some(index) =>
-            updateStructField(st, index, compatibleType(st(index).dataType, valueTagType))
+            updateStructField(st, index, compatibleType(caseSensitive, options.valueTag)(
+              st(index).dataType, valueTagType))
           case None =>
             st.add(options.valueTag, valueTagType)
         }
@@ -596,11 +494,102 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
       // If the field name already exists,
       // merge the type and infer the combined field as an array type if necessary
       case Some(oldType) if !oldType.isInstanceOf[ArrayType] && !newType.isInstanceOf[NullType] =>
-        ArrayType(compatibleType(oldType, newType))
+        ArrayType(compatibleType(caseSensitive, options.valueTag)(oldType, newType))
       case Some(oldType) =>
-        compatibleType(oldType, newType)
+        compatibleType(caseSensitive, options.valueTag)(oldType, newType)
       case None =>
         newType
     }
   }
 }
+
+object XmlInferSchema {
+  def normalize(name: String, caseSensitive: Boolean): String = {
+    if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
+  }
+
+  /**
+   * Returns the most general data type for two given data types.
+   */
+  private[xml] def compatibleType(caseSensitive: Boolean, valueTag: String)
+    (t1: DataType, t2: DataType): DataType = {
+
+    // TODO: Optimise this logic.
+    TypeCoercion.findTightestCommonType(t1, t2).getOrElse {
+      // t1 or t2 is a StructType, ArrayType, or an unexpected type.
+      (t1, t2) match {
+        // Double support larger range than fixed decimal, DecimalType.Maximum should be enough
+        // in most case, also have better precision.
+        case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
+          DoubleType
+
+        case (t1: DecimalType, t2: DecimalType) =>
+          val scale = math.max(t1.scale, t2.scale)
+          val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
+          if (range + scale > 38) {
+            // DecimalType can't support precision > 38
+            DoubleType
+          } else {
+            DecimalType(range + scale, scale)
+          }
+        case (TimestampNTZType, TimestampType) | (TimestampType, TimestampNTZType) =>
+          TimestampType
+
+        case (StructType(fields1), StructType(fields2)) =>
+          val newFields = (fields1 ++ fields2)
+           // normalize field name and pair it with original field
+           .map(field => (normalize(field.name, caseSensitive), field))
+           .groupBy(_._1) // group by normalized field name
+           .map { case (_: String, fields: Array[(String, StructField)]) =>
+             val fieldTypes = fields.map(_._2)
+             val dataType = fieldTypes.map(_.dataType)
+               .reduce(compatibleType(caseSensitive, valueTag))
+             // we pick up the first field name that we've encountered for the field
+             StructField(fields.head._2.name, dataType)
+           }
+           StructType(newFields.toArray.sortBy(_.name))
+
+        case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
+          ArrayType(
+            compatibleType(caseSensitive, valueTag)(
+              elementType1, elementType2), containsNull1 || containsNull2)
+
+        // In XML datasource, since StructType can be compared with ArrayType.
+        // In this case, ArrayType wraps the StructType.
+        case (ArrayType(ty1, _), ty2) =>
+          ArrayType(compatibleType(caseSensitive, valueTag)(ty1, ty2))
+
+        case (ty1, ArrayType(ty2, _)) =>
+          ArrayType(compatibleType(caseSensitive, valueTag)(ty1, ty2))
+
+        // As this library can infer an element with attributes as StructType whereas
+        // some can be inferred as other non-structural data types, this case should be
+        // treated.
+        case (st: StructType, dt: DataType) if st.fieldNames.contains(valueTag) =>
+          val valueIndex = st.fieldNames.indexOf(valueTag)
+          val valueField = st.fields(valueIndex)
+          val valueDataType = compatibleType(caseSensitive, valueTag)(valueField.dataType, dt)
+          st.fields(valueIndex) = StructField(valueTag, valueDataType, nullable = true)
+          st
+
+        case (dt: DataType, st: StructType) if st.fieldNames.contains(valueTag) =>
+          val valueIndex = st.fieldNames.indexOf(valueTag)
+          val valueField = st.fields(valueIndex)
+          val valueDataType = compatibleType(caseSensitive, valueTag)(dt, valueField.dataType)
+          st.fields(valueIndex) = StructField(valueTag, valueDataType, nullable = true)
+          st
+
+        // The case that given `DecimalType` is capable of given `IntegralType` is handled in
+        // `findTightestCommonType`. Both cases below will be executed only when the given
+        // `DecimalType` is not capable of the given `IntegralType`.
+        case (t1: IntegralType, t2: DecimalType) =>
+          compatibleType(caseSensitive, valueTag)(DecimalType.forType(t1), t2)
+        case (t1: DecimalType, t2: IntegralType) =>
+          compatibleType(caseSensitive, valueTag)(t1, DecimalType.forType(t2))
+
+        // strings and every string is a XML object.
+        case (_, _) => StringType
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
index 4b9a95856afb..78f9d5285c23 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
@@ -2127,7 +2127,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
         <string>this is a simple string.</string>
         <integer>10</integer>
         <long>21474836470</long>
-        <decimal>92233720368547758070</decimal>
+        <bigInteger>92233720368547758070</bigInteger>
         <double>1.7976931348623157</double>
         <boolean>true</boolean>
         <null>null</null>
@@ -2137,7 +2137,8 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     val dfWithNodecimal = spark.read
       .option("nullValue", "null")
       .xml(primitiveFieldAndType)
-    assert(dfWithNodecimal.schema("decimal").dataType === DoubleType)
+    assert(dfWithNodecimal.schema("bigInteger").dataType === DoubleType)
+    assert(dfWithNodecimal.schema("double").dataType === DoubleType)
 
     val df = spark.read
       .option("nullValue", "null")
@@ -2145,9 +2146,9 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       .xml(primitiveFieldAndType)
 
     val expectedSchema = StructType(
+      StructField("bigInteger", DecimalType(20, 0), true) ::
       StructField("boolean", BooleanType, true) ::
-      StructField("decimal", DecimalType(20, 0), true) ::
-      StructField("double", DoubleType, true) ::
+      StructField("double", DecimalType(17, 16), true) ::
       StructField("integer", LongType, true) ::
       StructField("long", LongType, true) ::
       StructField("null", StringType, true) ::
@@ -2157,8 +2158,9 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
     checkAnswer(
       df,
-      Row(true,
+      Row(
         new java.math.BigDecimal("92233720368547758070"),
+        true,
         1.7976931348623157,
         10,
         21474836470L,
@@ -2604,4 +2606,83 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
     checkAnswer(df, expectedAns)
   }
+
+  test("Find compatible types even if inferred DecimalType is not capable of 
other IntegralType") {
+    val mixedIntegerAndDoubleRecords = Seq(
+      """<ROW><a>3</a><b>1.1</b></ROW>""",
+      s"""<ROW><a>3.1</a><b>0.${"0" * 38}1</b></ROW>""").toDS()
+    val xmlDF = spark.read
+      .option("prefersDecimal", "true")
+      .option("rowTag", "ROW")
+      .xml(mixedIntegerAndDoubleRecords)
+
+    // The values in `a` field will be decimals as they fit in decimal. For `b` field,
+    // they will be doubles as `1.0E-39D` does not fit.
+    val expectedSchema = StructType(
+      StructField("a", DecimalType(21, 1), true) ::
+        StructField("b", DoubleType, true) :: Nil)
+
+    assert(xmlDF.schema === expectedSchema)
+    checkAnswer(
+      xmlDF,
+      Row(BigDecimal("3"), 1.1D) ::
+        Row(BigDecimal("3.1"), 1.0E-39D) :: Nil
+    )
+  }
+
+  def bigIntegerRecords: Dataset[String] =
+    spark.createDataset(spark.sparkContext.parallelize(
+      s"""<ROW><a>1${"0" * 38}</a><b>92233720368547758070</b></ROW>""" :: 
Nil))(Encoders.STRING)
+
+  test("Infer big integers correctly even when it does not fit in decimal") {
+    val df = spark.read
+      .option("rowTag", "ROW")
+      .option("prefersDecimal", "true")
+      .xml(bigIntegerRecords)
+
+    // The value in `a` field will be a double as it does not fit in decimal. For `b` field,
+    // it will be a decimal as `92233720368547758070`.
+    val expectedSchema = StructType(
+      StructField("a", DoubleType, true) ::
+        StructField("b", DecimalType(20, 0), true) :: Nil)
+
+    assert(df.schema === expectedSchema)
+    checkAnswer(df, Row(1.0E38D, BigDecimal("92233720368547758070")))
+  }
+
+  def floatingValueRecords: Dataset[String] =
+    spark.createDataset(spark.sparkContext.parallelize(
+      s"""<ROW><a>0.${"0" * 38}1</a><b>.01</b></ROW>""" :: 
Nil))(Encoders.STRING)
+
+  test("Infer floating-point values correctly even when it does not fit in 
decimal") {
+    val df = spark.read
+      .option("prefersDecimal", "true")
+      .option("rowTag", "ROW")
+      .xml(floatingValueRecords)
+
+    // The value in `a` field will be a double as it does not fit in decimal. For `b` field,
+    // it will be a decimal as `0.01` by having a precision equal to the scale.
+    val expectedSchema = StructType(
+      StructField("a", DoubleType, true) ::
+        StructField("b", DecimalType(2, 2), true) :: Nil)
+
+    assert(df.schema === expectedSchema)
+    checkAnswer(df, Row(1.0E-39D, BigDecimal("0.01")))
+
+    val mergedDF = spark.read
+      .option("prefersDecimal", "true")
+      .option("rowTag", "ROW")
+      .xml(floatingValueRecords.union(bigIntegerRecords))
+
+    val expectedMergedSchema = StructType(
+      StructField("a", DoubleType, true) ::
+        StructField("b", DecimalType(22, 2), true) :: Nil)
+
+    assert(expectedMergedSchema === mergedDF.schema)
+    checkAnswer(
+      mergedDF,
+      Row(1.0E-39D, BigDecimal("0.01")) ::
+        Row(1.0E38D, BigDecimal("92233720368547758070")) :: Nil
+    )
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
