cfmcgrady commented on code in PR #37439:
URL: https://github.com/apache/spark/pull/37439#discussion_r957967041
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala:
##########
@@ -346,6 +293,82 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
}
}
+  private def simplifyIn[IN <: Expression](
+      fromExp: Expression,
+      toType: NumericType,
+      list: Seq[Expression],
+      buildExpr: (ArrayBuffer[Literal], ArrayBuffer[Literal]) => IN): Option[Expression] = {
+
+    // There are 3 kinds of literals in the list:
+    // 1. null literals
+    // 2. the literals that can be cast to fromExp.dataType
+    // 3. the literals that cannot be cast to fromExp.dataType
+    // Note that:
+    // - null literals are special, as we can cast null literals to any data type
+    // - for 3, we have three cases:
+    //   1). the literal cannot be cast to fromExp.dataType and there is no min/max for the
+    //       fromType, for instance:
+    //         `cast(input[2, decimal(5,2), true] as decimal(10,4)) = 123456.1234`
+    //   2). the literal value is out of the fromType range, for instance:
+    //         `cast(input[0, smallint, true] as bigint) = 2147483647`
+    //   3). the literal value is rounded up/down after casting to `fromType`, for instance:
+    //         `cast(input[1, float, true] as double) = 3.14`
+    //       note that 3.14 will be rounded to 3.14000010... after casting to float
+
+    val (nullList, canCastList) = (ArrayBuffer[Literal](), ArrayBuffer[Literal]())
+    var containsCannotCastLiteral = false
+    val fromType = fromExp.dataType
+    val ordering = toType.ordering.asInstanceOf[Ordering[Any]]
+    val minMaxInToType = getRange(fromType).map {
+      case (min, max) =>
+        (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
+    }
+
+    list.foreach {
+      case lit @ Literal(null, _) => nullList += lit
+      case NonNullLiteral(value, _) =>
+        val minMaxCmp = minMaxInToType.map {
+          case (minInToType, maxInToType) =>
+            (ordering.compare(value, minInToType), ordering.compare(value, maxInToType))
+        }
+        minMaxCmp match {
+          // the literal value is out of the fromType range
+          case Some((minCmp, maxCmp)) if maxCmp > 0 || minCmp < 0 =>
+            containsCannotCastLiteral = true
+          case _ =>
+            val newValue = Cast(Literal(value), fromType, ansiEnabled = false).eval()
+            if (newValue == null) {
+              // the literal cannot be cast to fromExp.dataType and there is no min/max for the
+              // fromType
+              containsCannotCastLiteral = true
+            } else {
+              val valueRoundTrip = Cast(Literal(newValue, fromType), toType).eval()
+              val cmp = ordering.compare(value, valueRoundTrip)
+              if (cmp == 0) {
+                canCastList += Literal(newValue, fromType)
+              } else {
+                // the literal value is rounded up/down after casting to `fromType`
+                containsCannotCastLiteral = true
+              }
+            }
+        }
+    }
+
+    // return None when the list contains only null values
+    if (canCastList.isEmpty && !containsCannotCastLiteral) {
+      None
+    } else {
+      val unwrapExpr = buildExpr(nullList, canCastList)
+      if (!containsCannotCastLiteral) {
+        Option(unwrapExpr)
+      } else {
+        // the list contains a literal that cannot be cast to `fromExp.dataType`
+        Option(Or(falseIfNotNull(fromExp), unwrapExpr))
Review Comment:
1. This is similar to how we optimize `EqualTo(Cast(fromExp, toType), Literal(value, toType))`: when the literal is out of the fromType range, the `EqualTo` is optimized to `falseIfNotNull(fromExp)`:
```
cast(input[0, smallint, true] as int) = 32768 => isnull(input[0, smallint, true]) AND null
```
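   For reference, here is a minimal spark-shell sketch of that rewrite (the table `t2` and column `b` are made-up names for illustration, not from this PR):

```scala
// Hypothetical example: 32768 is one above the SMALLINT max (32767), so
// `CAST(b AS INT) = 32768` can never be true for a non-null b, and the rule
// rewrites the comparison to `isnull(b) AND null`.
spark.sql("CREATE TABLE t2 (b SMALLINT) USING parquet")
spark.sql("SELECT * FROM t2 WHERE CAST(b AS INT) = 32768").explain(true)
```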
2. The `And` expression is then optimized away by `ReplaceNullWithFalseInPredicate`/`BooleanSimplification`, and the remaining predicate can be pushed down into data sources in the next optimizer batch:
https://github.com/apache/spark/blob/b8f694f643186e91c31ce4420c8567c3e3ecad4e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala#L41-L43
```
=== Applying Rule org.apache.spark.sql.catalyst.optimizer.UnwrapCastInBinaryComparison ===
GlobalLimit 21
+- LocalLimit 21
   +- Project [cast(a#13 as string) AS a#17]
      +- Filter cast(a#13 as decimal(12,2)) IN (100000.00,1.00,2.00)
         +- Relation spark_catalog.default.t1[a#13] parquet
=>
GlobalLimit 21
+- LocalLimit 21
   +- Project [cast(a#13 as string) AS a#17]
      +- Filter ((isnull(a#13) AND null) OR a#13 IN (1,2))
         +- Relation spark_catalog.default.t1[a#13] parquet

11:13:15.807 WARN org.apache.spark.sql.catalyst.rules.PlanChangeLogger:
=== Applying Rule org.apache.spark.sql.catalyst.optimizer.ReplaceNullWithFalseInPredicate ===
GlobalLimit 21
+- LocalLimit 21
   +- Project [cast(a#13 as string) AS a#17]
      +- Filter ((isnull(a#13) AND null) OR a#13 IN (1,2))
         +- Relation spark_catalog.default.t1[a#13] parquet
=>
GlobalLimit 21
+- LocalLimit 21
   +- Project [cast(a#13 as string) AS a#17]
      +- Filter ((isnull(a#13) AND false) OR a#13 IN (1,2))
         +- Relation spark_catalog.default.t1[a#13] parquet

11:13:15.820 WARN org.apache.spark.sql.catalyst.rules.PlanChangeLogger:
=== Applying Rule org.apache.spark.sql.catalyst.optimizer.BooleanSimplification ===
GlobalLimit 21
+- LocalLimit 21
   +- Project [cast(a#13 as string) AS a#17]
      +- Filter ((isnull(a#13) AND false) OR a#13 IN (1,2))
         +- Relation spark_catalog.default.t1[a#13] parquet
=>
GlobalLimit 21
+- LocalLimit 21
   +- Project [cast(a#13 as string) AS a#17]
      +- Filter a#13 IN (1,2)
         +- Relation spark_catalog.default.t1[a#13] parquet
```
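To see step 2 in isolation, here is a minimal sketch modeled on the Catalyst optimizer test DSL (not code from this PR; the plan is hand-built to match the shape above):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, ReplaceNullWithFalseInPredicate}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.BooleanType

val rel = LocalRelation('a.int)
// Hand-built filter with the shape produced by simplifyIn:
// (isnull(a) AND null) OR a IN (1, 2)
val plan = rel.where(('a.isNull && Literal(null, BooleanType)) || 'a.in(1, 2)).analyze
// ReplaceNullWithFalseInPredicate turns the null in predicate position into false...
val step1 = ReplaceNullWithFalseInPredicate(plan)
// ...and BooleanSimplification reduces (isnull(a) AND false) OR a IN (1, 2) to a IN (1, 2).
val step2 = BooleanSimplification(step1)
```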
> a is of `decimal(3, 0)` type
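Given that, a hypothetical repro of the logged plan (the DDL is my guess; only `t1`, `a#13`, and the literals are actually visible in the log):

```scala
// Hypothetical repro; the exact DDL is an assumption based on the log and
// the note above that a is of decimal(3, 0) type.
spark.sql("CREATE TABLE t1 (a DECIMAL(3, 0)) USING parquet")
// 100000.00 cannot fit in DECIMAL(3, 0), so simplifyIn emits
// (isnull(a) AND null) OR a IN (1, 2), which the later rules reduce to a IN (1, 2).
spark.sql(
  "SELECT CAST(a AS STRING) AS a FROM t1 " +
  "WHERE CAST(a AS DECIMAL(12, 2)) IN (100000.00, 1.00, 2.00)").explain(true)
```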