sandip-db commented on code in PR #44601:
URL: https://github.com/apache/spark/pull/44601#discussion_r1443648967


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala:
##########
@@ -596,11 +492,102 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
       // If the field name already exists,
       // merge the type and infer the combined field as an array type if necessary
       case Some(oldType) if !oldType.isInstanceOf[ArrayType] && !newType.isInstanceOf[NullType] =>
-        ArrayType(compatibleType(oldType, newType))
+        ArrayType(compatibleType(caseSensitive, options.valueTag)(oldType, newType))
       case Some(oldType) =>
-        compatibleType(oldType, newType)
+        compatibleType(caseSensitive, options.valueTag)(oldType, newType)
       case None =>
         newType
     }
   }
 }
+
+object XmlInferSchema {
+  def normalize(name: String, caseSensitive: Boolean): String = {
+    if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
+  }
+
+  /**
+   * Returns the most general data type for two given data types.
+   */
+  private[xml] def compatibleType(caseSensitive: Boolean, valueTag: String)
+    (t1: DataType, t2: DataType): DataType = {
+
+    // TODO: Optimise this logic.
+    TypeCoercion.findTightestCommonType(t1, t2).getOrElse {
+      // t1 or t2 is a StructType, ArrayType, or an unexpected type.
+      (t1, t2) match {
+        // Double support larger range than fixed decimal, DecimalType.Maximum should be enough
+        // in most case, also have better precision.
+        case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
+          DoubleType
+
+        case (t1: DecimalType, t2: DecimalType) =>
+          val scale = math.max(t1.scale, t2.scale)
+          val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
+          if (range + scale > 38) {
+            // DecimalType can't support precision > 38
+            DoubleType
+          } else {
+            DecimalType(range + scale, scale)
+          }
+        case (TimestampNTZType, TimestampType) | (TimestampType, TimestampNTZType) =>
+          TimestampType
+
+        case (StructType(fields1), StructType(fields2)) =>
+          val newFields = (fields1 ++ fields2)
+           // normalize field name and pair it with original field
+           .map(field => (normalize(field.name, caseSensitive), field))
+           .groupBy(_._1) // group by normalized field name
+           .map { case (_: String, fields: Array[(String, StructField)]) =>
+             val fieldTypes = fields.map(_._2)
+             val dataType = fieldTypes.map(_.dataType)
+               .reduce(compatibleType(caseSensitive, valueTag))
+             // we pick up the first field name that we've encountered for the field
+             StructField(fields.head._2.name, dataType)
+           }
+           StructType(newFields.toArray.sortBy(_.name))
+
+        case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
+          ArrayType(
+            compatibleType(caseSensitive, valueTag)(
+              elementType1, elementType2), containsNull1 || containsNull2)
+
+        // In XML datasource, since StructType can be compared with ArrayType.
+        // In this case, ArrayType wraps the StructType.
+        case (ArrayType(ty1, _), ty2) =>
+          ArrayType(compatibleType(caseSensitive, valueTag)(ty1, ty2))
+
+        case (ty1, ArrayType(ty2, _)) =>
+          ArrayType(compatibleType(caseSensitive, valueTag)(ty1, ty2))
+
+        // As this library can infer an element with attributes as StructType whereas
+        // some can be inferred as other non-structural data types, this case should be
+        // treated.
+        case (st: StructType, dt: DataType) if st.fieldNames.contains(valueTag) =>
+          val valueIndex = st.fieldNames.indexOf(valueTag)
+          val valueField = st.fields(valueIndex)
+          val valueDataType = compatibleType(caseSensitive, valueTag)(valueField.dataType, dt)
+          st.fields(valueIndex) = StructField(valueTag, valueDataType, nullable = true)
+          st
+
+        case (dt: DataType, st: StructType) if st.fieldNames.contains(valueTag) =>
+          val valueIndex = st.fieldNames.indexOf(valueTag)
+          val valueField = st.fields(valueIndex)
+          val valueDataType = compatibleType(caseSensitive, valueTag)(dt, valueField.dataType)
+          st.fields(valueIndex) = StructField(valueTag, valueDataType, nullable = true)
+          st
+
+        // The case that given `DecimalType` is capable of given `IntegralType` is handled in
+        // `findTightestCommonType`. Both cases below will be executed only when the given
+        // `DecimalType` is not capable of the given `IntegralType`.
+        case (t1: IntegralType, t2: DecimalType) =>
+          compatibleType(caseSensitive, valueTag)(DecimalType.forType(t1), t2)

Review Comment:
   Added new test cases.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to