GitHub user mn-mikke commented on a diff in the pull request:
https://github.com/apache/spark/pull/21687#discussion_r200851845
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
---
@@ -695,6 +695,56 @@ abstract class TernaryExpression extends Expression {
}
}
+/**
+ * A trait resolving the nullable, containsNull and valueContainsNull flags of the output data type.
+ * This logic is usually utilized by expressions combining data from multiple child expressions
+ * of non-primitive types (e.g. [[CaseWhen]]).
+ */
+trait NonPrimitiveTypeMergingExpression extends Expression
+{
+ /**
+ * A collection of data types used for resolving the output type of the expression. By default,
+ * the data types of all child expressions. The collection must not be empty.
+ */
+ @transient
+ lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType)
+
+ /**
+ * A method determining whether the input types are equal, ignoring the nullable, containsNull and
+ * valueContainsNull flags, and thus suitable for resolving the final data type.
+ */
+ def areInputTypesForMergingEqual: Boolean = {
+ inputTypesForMerging.lengthCompare(1) <= 0 || inputTypesForMerging.sliding(2, 1).forall {
+ case Seq(dt1, dt2) => dt1.sameType(dt2)
+ }
+ }
+
+ private def mergeTwoDataTypes(dt1: DataType, dt2: DataType): DataType = (dt1, dt2) match {
+ case (t1, t2) if t1 == t2 => t1
+ case (ArrayType(et1, cn1), ArrayType(et2, cn2)) =>
--- End diff --
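The hunk above is cut off at the `ArrayType` case. The following is a minimal standalone sketch of the two ideas it introduces: checking adjacent input types pairwise with `sliding(2, 1)` while ignoring nullability flags, and merging types by recursing into element types. The type model (`DT`, `Prim`, `Arr`, `MapT`, `TypeMerging`) and the rule of OR-ing the `containsNull`/`valueContainsNull` flags are assumptions for illustration only, not Spark's `DataType` classes or the PR's actual code.

```scala
// Hypothetical, simplified stand-in for Spark's DataType hierarchy (assumption).
sealed trait DT
final case class Prim(name: String) extends DT
final case class Arr(elementType: DT, containsNull: Boolean) extends DT
final case class MapT(keyType: DT, valueType: DT, valueContainsNull: Boolean) extends DT

object TypeMerging {
  // Structural equality that ignores the containsNull / valueContainsNull flags.
  def sameType(a: DT, b: DT): Boolean = (a, b) match {
    case (Arr(e1, _), Arr(e2, _))           => sameType(e1, e2)
    case (MapT(k1, v1, _), MapT(k2, v2, _)) => sameType(k1, k2) && sameType(v1, v2)
    case _                                  => a == b
  }

  // Mirrors the shape of areInputTypesForMergingEqual: zero or one input is
  // trivially "equal"; otherwise every adjacent pair must have the same type.
  def allSameType(types: Seq[DT]): Boolean =
    types.lengthCompare(1) <= 0 || types.sliding(2, 1).forall {
      case Seq(t1, t2) => sameType(t1, t2)
    }

  // Assumed merge rule: recurse into nested types and OR the nullability flags,
  // so the result allows nulls whenever any input does.
  def merge(a: DT, b: DT): DT = (a, b) match {
    case (t1, t2) if t1 == t2         => t1
    case (Arr(e1, cn1), Arr(e2, cn2)) => Arr(merge(e1, e2), cn1 || cn2)
    case (MapT(k1, v1, vcn1), MapT(k2, v2, vcn2)) =>
      MapT(merge(k1, k2), merge(v1, v2), vcn1 || vcn2)
    case _ => throw new IllegalArgumentException(s"Cannot merge $a and $b")
  }

  // Like inputTypesForMerging, the collection must not be empty.
  def mergeAll(types: Seq[DT]): DT = {
    require(types.nonEmpty, "The collection must not be empty")
    types.reduceLeft((x, y) => merge(x, y))
  }
}
```

OR-ing the flags is the conservative choice for an expression such as `CaseWhen`, whose result may come from any branch.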
That sounds like a good idea! Some parts, like `fields1.length == fields2.length` and `resolver(field1.name, field2.name)`, seem unnecessary in this context, but it could work.
@ueshin WDYT about making `findTypeForComplex` public?
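The point about the length and name checks can be illustrated with another hypothetical, simplified model (`StructMerging`, `Struct`, `Field`, not Spark's `StructType`/`StructField`). Once a `sameType`-style check has confirmed that two structs have the same number of fields with matching names and types, a merge can zip the fields positionally and only combine nullability. Exact name comparison is an assumption here; Spark's own comparison may differ, e.g. with respect to case sensitivity.

```scala
object StructMerging {
  // Hypothetical, simplified model (not Spark's StructType/StructField).
  sealed trait DT
  final case class Prim(name: String) extends DT
  final case class Field(name: String, dataType: DT, nullable: Boolean)
  final case class Struct(fields: Seq[Field]) extends DT

  // Field count, names and types must match position by position;
  // nullable flags are ignored. Names are compared exactly (assumption).
  def sameType(a: DT, b: DT): Boolean = (a, b) match {
    case (Struct(fs1), Struct(fs2)) =>
      fs1.length == fs2.length && fs1.zip(fs2).forall { case (f1, f2) =>
        f1.name == f2.name && sameType(f1.dataType, f2.dataType)
      }
    case _ => a == b
  }

  // Validate once up front, then merge without repeating the length/name checks.
  def merge(a: DT, b: DT): DT = {
    require(sameType(a, b), s"Inputs must have the same type: $a vs $b")
    mergeUnchecked(a, b)
  }

  // Relies on sameType(a, b): zipping cannot drop or misalign fields.
  private def mergeUnchecked(a: DT, b: DT): DT = (a, b) match {
    case (Struct(fs1), Struct(fs2)) =>
      Struct(fs1.zip(fs2).map { case (f1, f2) =>
        Field(f1.name, mergeUnchecked(f1.dataType, f2.dataType), f1.nullable || f2.nullable)
      })
    case _ => a
  }
}
```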