cfmcgrady commented on code in PR #37439:
URL: https://github.com/apache/spark/pull/37439#discussion_r957967041


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala:
##########
@@ -346,6 +293,82 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
     }
   }
 
+  private def simplifyIn[IN <: Expression](
+      fromExp: Expression,
+      toType: NumericType,
+      list: Seq[Expression],
+      buildExpr: (ArrayBuffer[Literal], ArrayBuffer[Literal]) => IN): Option[Expression] = {
+
+    // There are 3 kinds of literals in the list:
+    // 1. null literals
+    // 2. the literals that can cast to fromExp.dataType
+    // 3. the literals that cannot cast to fromExp.dataType
+    // Note that:
+    // - null literals are special as we can cast null literals to any data type
+    // - for 3, we have three cases
+    //   1). the literal cannot cast to fromExp.dataType, and there is no min/max for the
+    //       fromType, for instance:
+    //         `cast(input[2, decimal(5,2), true] as decimal(10,4)) = 123456.1234`
+    //   2). the literal value is out of fromType range, for instance:
+    //         `cast(input[0, smallint, true] as bigint) = 2147483647`
+    //   3). the literal value is rounded up/down after casting to `fromType`, for instance:
+    //         `cast(input[1, float, true] as double) = 3.14`
+    //       note that 3.14 will be rounded to 3.14000010... after casting to float
+
+    val (nullList, canCastList) = (ArrayBuffer[Literal](), ArrayBuffer[Literal]())
+    var containsCannotCastLiteral = false
+    val fromType = fromExp.dataType
+    val ordering = toType.ordering.asInstanceOf[Ordering[Any]]
+    val minMaxInToType = getRange(fromType).map {
+      case (min, max) =>
+        (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
+    }
+
+    list.foreach {
+      case lit @ Literal(null, _) => nullList += lit
+      case NonNullLiteral(value, _) =>
+        val minMaxCmp = minMaxInToType.map {
+          case (minInToType, maxInToType) =>
+            (ordering.compare(value, minInToType), ordering.compare(value, maxInToType))
+        }
+        minMaxCmp match {
+          // the literal value is out of fromType range
+          case Some((minCmp, maxCmp)) if maxCmp > 0 || minCmp < 0 =>
+            containsCannotCastLiteral = true
+          case _ =>
+            val newValue = Cast(Literal(value), fromType, ansiEnabled = false).eval()
+            if (newValue == null) {
+              // the literal cannot cast to fromExp.dataType, and there is no min/max
+              // for the fromType
+              containsCannotCastLiteral = true
+            } else {
+              val valueRoundTrip = Cast(Literal(newValue, fromType), toType).eval()
+              val cmp = ordering.compare(value, valueRoundTrip)
+              if (cmp == 0) {
+                canCastList += Literal(newValue, fromType)
+              } else {
+                // the literal value is rounded up/down after casting to `fromType`
+                containsCannotCastLiteral = true
+              }
+            }
+        }
+    }
+
+    // return None when the list contains only null values.
+    if (canCastList.isEmpty && !containsCannotCastLiteral) {
+      None
+    } else {
+      val unwrapExpr = buildExpr(nullList, canCastList)
+      if (!containsCannotCastLiteral) {
+        Option(unwrapExpr)
+      } else {
+        // the list contains a literal that cannot be cast to `fromExp.dataType`
+        Option(Or(falseIfNotNull(fromExp), unwrapExpr))

Review Comment:
   1. this rule is similar to optimizing `EqualTo(Cast(fromExp, toType), Literal(value, toType))`: when the literal is out of the fromType range, the `EqualTo` is optimized to `falseIfNotNull(fromExp)`
   
   ```
   cast(input[0, smallint, true] as int) = 32768  =>  isnull(input[0, smallint, true]) AND null
   ```
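   
   For context, `falseIfNotNull` is the small helper defined earlier in this same file; a minimal sketch of its shape (matching the `isnull(...) AND null` pattern above):
   
   ```
   // sketch of the falseIfNotNull helper: it preserves three-valued logic,
   // evaluating to null for a null input and to false for any non-null input,
   // so a null row is never turned into a spurious match
   def falseIfNotNull(e: Expression): Expression =
     And(IsNull(e), Literal(null, BooleanType))
   ```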
   
   2. the `And` expression will be optimized by `ReplaceNullWithFalseInPredicate`/`BooleanSimplification` and then pushed down to the data source in the next optimizer batch
   
   
https://github.com/apache/spark/blob/b8f694f643186e91c31ce4420c8567c3e3ecad4e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala#L41-L43
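   
   A minimal standalone sketch of those two cleanup steps (hypothetical snippet; only the expression shape mirrors the logged plans below):
   
   ```
   import org.apache.spark.sql.catalyst.expressions._
   import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, ReplaceNullWithFalseInPredicate}
   import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation}
   import org.apache.spark.sql.types._
   
   val a = AttributeReference("a", DecimalType(3, 0))()
   // (isnull(a) AND null) OR a IN (1, 2) -- the shape produced by simplifyIn
   val cond = Or(
     And(IsNull(a), Literal(null, BooleanType)),
     In(a, Seq(Literal(Decimal(1), DecimalType(3, 0)), Literal(Decimal(2), DecimalType(3, 0)))))
   val plan = Filter(cond, LocalRelation(a))
   
   val step1 = ReplaceNullWithFalseInPredicate(plan)  // null => false inside the predicate
   val step2 = BooleanSimplification(step1)           // (isnull(a) AND false) OR ... => a IN (1, 2)
   ```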
   
   
   ```
   === Applying Rule org.apache.spark.sql.catalyst.optimizer.UnwrapCastInBinaryComparison ===
    GlobalLimit 21                                                           GlobalLimit 21
    +- LocalLimit 21                                                         +- LocalLimit 21
       +- Project [cast(a#13 as string) AS a#17]                                +- Project [cast(a#13 as string) AS a#17]
   !      +- Filter cast(a#13 as decimal(12,2)) IN (100000.00,1.00,2.00)           +- Filter ((isnull(a#13) AND null) OR a#13 IN (1,2))
             +- Relation spark_catalog.default.t1[a#13] parquet                       +- Relation spark_catalog.default.t1[a#13] parquet
   
   11:13:15.807 WARN org.apache.spark.sql.catalyst.rules.PlanChangeLogger:
   === Applying Rule org.apache.spark.sql.catalyst.optimizer.ReplaceNullWithFalseInPredicate ===
    GlobalLimit 21                                                 GlobalLimit 21
    +- LocalLimit 21                                               +- LocalLimit 21
       +- Project [cast(a#13 as string) AS a#17]                      +- Project [cast(a#13 as string) AS a#17]
   !      +- Filter ((isnull(a#13) AND null) OR a#13 IN (1,2))           +- Filter ((isnull(a#13) AND false) OR a#13 IN (1,2))
             +- Relation spark_catalog.default.t1[a#13] parquet             +- Relation spark_catalog.default.t1[a#13] parquet
   
   11:13:15.820 WARN org.apache.spark.sql.catalyst.rules.PlanChangeLogger:
   === Applying Rule org.apache.spark.sql.catalyst.optimizer.BooleanSimplification ===
    GlobalLimit 21                                                 GlobalLimit 21
    +- LocalLimit 21                                               +- LocalLimit 21
       +- Project [cast(a#13 as string) AS a#17]                      +- Project [cast(a#13 as string) AS a#17]
   !      +- Filter ((isnull(a#13) AND false) OR a#13 IN (1,2))          +- Filter a#13 IN (1,2)
             +- Relation spark_catalog.default.t1[a#13] parquet             +- Relation spark_catalog.default.t1[a#13] parquet
   ```
   
   > `a` is of `decimal(3, 0)` type
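   
   A hypothetical end-to-end reproduction of the plans above (table name and data are assumptions; only the column type and the IN list follow the logs):
   
   ```
   spark.sql("CREATE TABLE t1 (a DECIMAL(3, 0)) USING parquet")
   spark.sql("INSERT INTO t1 VALUES (1), (2), (100)")
   // 100000.00 does not fit in decimal(3, 0), so simplifyIn emits the
   // falseIfNotNull(a) guard and keeps only the castable literals (1, 2)
   spark.sql("SELECT CAST(a AS STRING) AS a FROM t1 WHERE a IN (100000.00, 1.00, 2.00)").show()
   ```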
   


