cloud-fan commented on code in PR #48596:
URL: https://github.com/apache/spark/pull/48596#discussion_r1821559728


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala:
##########
@@ -0,0 +1,717 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion.PromoteStrings.conf
+import org.apache.spark.sql.catalyst.expressions.{
+  Alias,
+  ArrayJoin,
+  BinaryOperator,
+  CaseWhen,
+  Cast,
+  Coalesce,
+  Concat,
+  CreateArray,
+  CreateMap,
+  DateAdd,
+  DateSub,
+  Elt,
+  ExpectsInputTypes,
+  Expression,
+  Greatest,
+  If,
+  ImplicitCastInputTypes,
+  In,
+  InSubquery,
+  Least,
+  ListQuery,
+  Literal,
+  MapConcat,
+  MapZipWith,
+  NaNvl,
+  RangeFrame,
+  ScalaUDF,
+  Sequence,
+  SpecialFrameBoundary,
+  SpecifiedWindowFrame,
+  SubtractTimestamps,
+  TimeAdd,
+  WindowSpecDefinition
+}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.{AbstractArrayType, 
StringTypeWithCollation}
+import org.apache.spark.sql.types.{
+  AbstractDataType,
+  AnyDataType,
+  AnyTimestampTypeExpression,
+  ArrayType,
+  BinaryType,
+  BooleanType,
+  DataType,
+  DatetimeType,
+  DateType,
+  DateTypeExpression,
+  DecimalType,
+  DoubleType,
+  FloatType,
+  FractionalType,
+  IntegerType,
+  IntegralType,
+  MapType,
+  NullType,
+  StringType,
+  StringTypeExpression,
+  StructType,
+  TimestampNTZType,
+  TimestampType,
+  TimestampTypeExpression
+}
+
+abstract class TypeCoercionHelper {
+
+  /**
+   * A collection of [[Rule]]s that can be used to coerce differing types that participate in
+   * operations into compatible ones.
+   */
+  def typeCoercionRules: List[Rule[LogicalPlan]]
+
+  /**
+   * Find the tightest common type of two types that might be used in a binary expression.
+   * This handles all numeric types except fixed-precision decimals interacting with each other or
+   * with primitive types, because in that case the precision and scale of the result depends on
+   * the operation. Those rules are implemented in [[DecimalPrecision]].
+   *
+   * Declared as a function value (rather than a method) so it can be passed around directly.
+   * Returns None when the two types have no tightest common type.
+   */
+  val findTightestCommonType: (DataType, DataType) => Option[DataType]
+
+  /**
+   * Looking for a widened data type of two given data types with some acceptable loss of precision.
+   * E.g. there is no common type for double and decimal because double's range
+   * is larger than decimal, and yet decimal is more precise than double, but in
+   * union we would cast the decimal into double.
+   *
+   * @return the widened type, or None if the two types cannot be widened to a common type
+   */
+  def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType]
+
+  /**
+   * Looking for a widened data type of a given sequence of data types with some acceptable loss
+   * of precision.
+   * E.g. there is no common type for double and decimal because double's range
+   * is larger than decimal, and yet decimal is more precise than double, but in
+   * union we would cast the decimal into double.
+   *
+   * @return the widened type, or None if the types cannot be widened to a common type
+   */
+  def findWiderCommonType(types: Seq[DataType]): Option[DataType]
+
+  /**
+   * Given an expected data type, try to cast the expression and return the cast expression.
+   *
+   * If the expression already fits the input type, we simply return the expression itself.
+   * If the expression has an incompatible type that cannot be implicitly cast, return None.
+   *
+   * @param e the expression to coerce
+   * @param expectedType the (possibly abstract) type the expression must conform to
+   * @return the expression itself or its cast, or None if no implicit cast is possible
+   */
+  def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression]
+
+  /**
+   * Whether casting `from` as `to` is valid.
+   *
+   * @param from the source data type
+   * @param to the target data type
+   */
+  def canCast(from: DataType, to: DataType): Boolean
+
+  /**
+   * Finds a common complex type for `t1` and `t2` by delegating to `findTypeFunc` for their
+   * component types:
+   *  - arrays: the element types;
+   *  - maps: the key types, then the value types; a key type whose cast could introduce nulls
+   *    ([[Cast.forceNullable]]) is rejected, since map keys carry no nullability flag;
+   *  - structs with the same number of fields: each pair of fields, whose names must match
+   *    according to the session resolver.
+   * Nullability flags of the result are widened whenever either input allows nulls or when
+   * casting a component to the common type may produce nulls.
+   *
+   * @param t1 the first type
+   * @param t2 the second type
+   * @param findTypeFunc strategy used to find a common type for component types
+   * @return the common complex type, or None if the inputs are not two arrays, two maps or two
+   *         structs of equal arity, or if any component pair has no acceptable common type
+   */
+  protected def findTypeForComplex(
+      t1: DataType,
+      t2: DataType,
+      findTypeFunc: (DataType, DataType) => Option[DataType]): Option[DataType] = {
+    (t1, t2) match {
+      case (ArrayType(leftElem, leftContainsNull), ArrayType(rightElem, rightContainsNull)) =>
+        for (elem <- findTypeFunc(leftElem, rightElem)) yield {
+          val containsNull = leftContainsNull || rightContainsNull ||
+            Cast.forceNullable(leftElem, elem) || Cast.forceNullable(rightElem, elem)
+          ArrayType(elem, containsNull)
+        }
+
+      case (MapType(leftKey, leftValue, leftValueNull),
+            MapType(rightKey, rightValue, rightValueNull)) =>
+        for {
+          key <- findTypeFunc(leftKey, rightKey)
+          // Only accept a key type whose casts can never introduce nulls.
+          if !Cast.forceNullable(leftKey, key) && !Cast.forceNullable(rightKey, key)
+          value <- findTypeFunc(leftValue, rightValue)
+        } yield {
+          val valueContainsNull = leftValueNull || rightValueNull ||
+            Cast.forceNullable(leftValue, value) || Cast.forceNullable(rightValue, value)
+          MapType(key, value, valueContainsNull)
+        }
+
+      case (StructType(leftFields), StructType(rightFields))
+          if leftFields.length == rightFields.length =>
+        val resolver = SQLConf.get.resolver
+        val seed: Option[StructType] = Some(new StructType())
+        leftFields.zip(rightFields).foldLeft(seed) {
+          case (Some(merged), (left, right)) if resolver(left.name, right.name) =>
+            findTypeFunc(left.dataType, right.dataType).map { fieldType =>
+              val nullable = left.nullable || right.nullable ||
+                Cast.forceNullable(left.dataType, fieldType) ||
+                Cast.forceNullable(right.dataType, fieldType)
+              merged.add(left.name, fieldType, nullable)
+            }
+          // Either an earlier field pair already failed or the field names do not match.
+          case _ => None
+        }
+
+      case _ => None
+    }
+  }
+
+  /**
+   * Finds a wider type when at least one of the two types is a decimal:
+   *  - decimal and decimal, or decimal and integral: a decimal wide enough for both (if the wider
+   *    decimal type exceeds the system limitation, the decimal type is truncated);
+   *  - decimal and any other fractional type: double.
+   * Returns None for every other combination.
+   */
+  protected def findWiderTypeForDecimal(dt1: DataType, dt2: DataType): Option[DataType] = {
+    // Widens an integral type to its decimal form and merges it with the given decimal.
+    def widenIntegral(integral: IntegralType, decimal: DecimalType): DataType =
+      DecimalPrecision.widerDecimalType(DecimalType.forType(integral), decimal)
+
+    (dt1, dt2) match {
+      case (left: DecimalType, right: DecimalType) =>
+        Some(DecimalPrecision.widerDecimalType(left, right))
+      case (integral: IntegralType, decimal: DecimalType) =>
+        Some(widenIntegral(integral, decimal))
+      case (decimal: DecimalType, integral: IntegralType) =>
+        Some(widenIntegral(integral, decimal))
+      case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) =>
+        Some(DoubleType)
+      case _ =>
+        None
+    }
+  }
+
+  /**
+   * Similar to [[findWiderTypeForTwo]] in that it can handle decimal types, but it never promotes
+   * to string. If the wider decimal type exceeds the system limitation, this rule truncates the
+   * decimal type before returning it.
+   *
+   * Strategies are tried in order and the first one that succeeds wins: the tightest common type,
+   * then decimal widening, then a recursive search for complex (array/map/struct) types.
+   */
+  private[catalyst] def findWiderTypeWithoutStringPromotionForTwo(
+      t1: DataType,
+      t2: DataType): Option[DataType] = {
+    val tightest = findTightestCommonType(t1, t2)
+    tightest.orElse {
+      findWiderTypeForDecimal(t1, t2)
+    }.orElse {
+      // Complex types recurse back into this method for their component types.
+      findTypeForComplex(
+        t1,
+        t2,
+        (left, right) => findWiderTypeWithoutStringPromotionForTwo(left, right))
+    }
+  }
+
+  /**
+   * Widens all of `types` pairwise with [[findWiderTypeWithoutStringPromotionForTwo]], starting
+   * from [[NullType]]. Once any step fails the result is None; an empty sequence yields
+   * `Some(NullType)`.
+   */
+  def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
+    types.foldLeft(Option[DataType](NullType)) { (widest, next) =>
+      widest.flatMap(current => findWiderTypeWithoutStringPromotionForTwo(current, next))
+    }
+  }
+
+  /**
+   * Check whether the given types are equal ignoring nullable, containsNull and valueContainsNull.
+   * Empty and single-element sequences trivially have the same type.
+   */
+  def haveSameType(types: Seq[DataType]): Boolean = types match {
+    case first +: rest => rest.forall(other => DataTypeUtils.sameType(other, first))
+    case _ => true
+  }
+
+  /**
+   * Wraps `expr` in a [[Cast]] to `dt` unless its type already matches `dt` (ignoring
+   * nullability flags, as defined by [[DataTypeUtils.sameType]]).
+   */
+  protected def castIfNotSameType(expr: Expression, dt: DataType): Expression =
+    if (DataTypeUtils.sameType(expr.dataType, dt)) expr else Cast(expr, dt)
+
+  /**
+   * Returns the wider of two datetime types:
+   *  - TIMESTAMP is wider than both DATE and TIMESTAMP_NTZ;
+   *  - TIMESTAMP_NTZ is wider than DATE;
+   *  - two identical types are returned unchanged.
+   * Any other combination is not handled and throws a [[MatchError]].
+   */
+  protected def findWiderDateTimeType(d1: DatetimeType, d2: DatetimeType): DatetimeType =
+    (d1, d2) match {
+      // Identical inputs are trivially their own wider type; without this case they would fall
+      // through every pattern below and fail with a MatchError.
+      case _ if d1 == d2 =>
+        d1
+
+      case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) =>
+        TimestampType
+
+      case (_: TimestampType, _: TimestampNTZType) | (_: TimestampNTZType, _: TimestampType) =>
+        TimestampType
+
+      case (_: TimestampNTZType, _: DateType) | (_: DateType, _: TimestampNTZType) =>
+        TimestampNTZType
+    }
+
+  /**
+   * Type coercion helper that matches against [[In]] and [[InSubquery]] expressions in order to
+   * type coerce LHS and RHS to expected types.
+   *
+   *  - For an unresolved [[InSubquery]] whose value list has as many expressions as the subquery
+   *    has output columns, each (LHS expression, subquery column) pair is widened to a common
+   *    type. If every pair has one, mismatching LHS expressions are cast and mismatching subquery
+   *    columns are re-projected through an aliased cast.
+   *  - For an [[In]] whose list contains an element whose type differs from the value's type,
+   *    all children are cast to their wider common type, if one exists.
+   * Any other expression, or one without a common type, is returned unchanged.
+   */
+  object InTypeCoercion {
+    def apply(expression: Expression): Expression = expression match {
+      // Handle type casting required between value expression and subquery output
+      // in IN subquery.
+      case i @ InSubquery(lhs, query: ListQuery)
+          if !i.resolved && lhs.length == query.plan.output.length =>
+        // LHS is the value expressions of IN subquery.
+        // RHS is the subquery output.
+        val rhs = query.plan.output
+
+        // One entry per (LHS, RHS) pair that has a wider common type. Pairs without one are
+        // dropped, so this can be shorter than `lhs`.
+        val commonTypes = lhs.zip(rhs).flatMap { case (left, right) =>
+          findWiderTypeForTwo(left.dataType, right.dataType)
+        }
+
+        // Only rewrite when every pair found a common type; otherwise the positions in
+        // `commonTypes` would no longer line up with `lhs`/`rhs`. (The guard above already
+        // ensures LHS and RHS have the same number of columns.)
+        if (commonTypes.length == lhs.length) {
+          val castedRhs = rhs.zip(commonTypes).map {
+            case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
+            case (e, _) => e
+          }
+          val newLhs = lhs.zip(commonTypes).map {
+            case (e, dt) if e.dataType != dt => Cast(e, dt)
+            case (e, _) => e
+          }
+
+          InSubquery(newLhs, query.withNewPlan(Project(castedRhs, query.plan)))
+        } else {
+          i
+        }
+
+      case i @ In(a, b) if b.exists(_.dataType != a.dataType) =>
+        findWiderCommonType(i.children.map(_.dataType)) match {
+          case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType)))
+          case None => i
+        }
+
+      case other => other
+    }
+  }
+
+  /**
+   * Type coercion helper that matches against function expression in order to 
type coerce function
+   * argument types to expected types.
+   */
+  object FunctionArgumentTypeCoercion {

Review Comment:
   Why are these coercion objects defined in the helper instead of in their own files? 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to