cloud-fan commented on code in PR #37439:
URL: https://github.com/apache/spark/pull/37439#discussion_r955837238
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala:
##########
@@ -346,6 +293,83 @@ object UnwrapCastInBinaryComparison extends
Rule[LogicalPlan] {
}
}
+ private def simplifyIn[IN <: Expression](
+ fromExp: Expression,
+ toType: NumericType,
+ list: Seq[Expression],
+ buildExpr: (ArrayBuffer[Literal], ArrayBuffer[Literal]) => IN):
Option[Expression] = {
+
+ // There are 3 kinds of literals in the list:
+ // 1. null literals
+ // 2. The literals that can be cast to fromExp.dataType
+ // 3. The literals that cannot be cast to fromExp.dataType
+ // Note that:
+ // - null literals are special as we can cast null literals to any data type
+ // - for 3, we have three cases
+ // 1). the literal cannot cast to fromExp.dataType, and there is no
min/max for the fromType,
+ // for instance:
+ // `cast(input[2, decimal(5,2), true] as decimal(10,4)) =
123456.1234`
+ // 2). the literal value is out of fromType range, for instance:
+ // `cast(input[0, smallint, true] as bigint) = 2147483647`
+ // 3). the literal value is rounded up/down after casting to `fromType`,
for instance:
+ // `cast(input[1, float, true] as double) = 3.14`
+ // note that 3.14 will be rounded to 3.14000010... after casting to
float
+
+ val (nullList, canCastList) = (ArrayBuffer[Literal](),
ArrayBuffer[Literal]())
+ var cannotCastCounter = 0
+ val fromType = fromExp.dataType
+ val ordering = toType.ordering.asInstanceOf[Ordering[Any]]
+ val minMaxInToType = getRange(fromType).map {
+ case (min, max) =>
+ (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
+ }
+
+ list.foreach {
+ case lit @ Literal(null, _) => nullList += lit
+ case NonNullLiteral(value, _) =>
+ val minMaxCmp = minMaxInToType.map {
+ case (minInToType, maxInToType) =>
+ (ordering.compare(value, minInToType), ordering.compare(value,
maxInToType))
+ }
+ minMaxCmp match {
+ // the literal value is out of fromType range
+ case Some((minCmp, maxCmp)) if maxCmp > 0 || minCmp < 0 =>
+ cannotCastCounter += 1
+ case _ =>
+ val newValue = Cast(Literal(value), fromType, ansiEnabled =
false).eval()
+ if (newValue == null) {
+ // the literal cannot cast to fromExp.dataType, and there is no
min/max for the
+ // fromType
+ cannotCastCounter += 1
+ } else {
+ val valueRoundTrip = Cast(Literal(newValue, fromType),
toType).eval()
+ val cmp = ordering.compare(value, valueRoundTrip)
+ if (cmp == 0) {
+ canCastList += Literal(newValue, fromType)
+ } else {
+ // the literal value is rounded up/down after casting to
`fromType`
+ cannotCastCounter += 1
+ }
+ }
+ }
+ }
+
+ // return None when list contains only null values.
+ if (canCastList.isEmpty && cannotCastCounter == 0) {
+ None
+ } else {
+ val unwrapExpr = buildExpr(nullList, canCastList)
+ if (cannotCastCounter == 0) {
+ Option(unwrapExpr)
+ } else {
+ // since can not cast literals are all transformed to the same
`falseIfNotNull(fromExp)`,
Review Comment:
need to update the comment
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]