cloud-fan commented on a change in pull request #29792:
URL: https://github.com/apache/spark/pull/29792#discussion_r500137824



##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
##########
@@ -116,82 +132,118 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
    * optimizes the expression by moving the cast to the literal side. 
Otherwise if result is not
    * true, this replaces the input binary comparison `exp` with simpler 
expressions.
    */
-  private def simplifyIntegralComparison(
+  private def simplifyNumericComparison(
       exp: BinaryComparison,
       fromExp: Expression,
-      toType: IntegralType,
+      toType: NumericType,
       value: Any): Expression = {
 
     val fromType = fromExp.dataType
-    val (min, max) = getRange(fromType)
-    val (minInToType, maxInToType) = {
-      (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
-    }
     val ordering = toType.ordering.asInstanceOf[Ordering[Any]]
-    val minCmp = ordering.compare(value, minInToType)
-    val maxCmp = ordering.compare(value, maxInToType)
+    val range = getRange(fromType)
 
-    if (maxCmp > 0) {
-      exp match {
-        case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
-          falseIfNotNull(fromExp)
-        case LessThan(_, _) | LessThanOrEqual(_, _) =>
-          trueIfNotNull(fromExp)
-        // make sure the expression is evaluated if it is non-deterministic
-        case EqualNullSafe(_, _) if exp.deterministic =>
-          FalseLiteral
-        case _ => exp
+    if (range.isDefined) {
+      val (min, max) = range.get
+      val (minInToType, maxInToType) = {
+        (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
       }
-    } else if (maxCmp == 0) {
-      exp match {
-        case GreaterThan(_, _) =>
-          falseIfNotNull(fromExp)
-        case LessThanOrEqual(_, _) =>
-          trueIfNotNull(fromExp)
-        case LessThan(_, _) =>
-          Not(EqualTo(fromExp, Literal(max, fromType)))
-        case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
-          EqualTo(fromExp, Literal(max, fromType))
-        case EqualNullSafe(_, _) =>
-          EqualNullSafe(fromExp, Literal(max, fromType))
-        case _ => exp
+      val minCmp = ordering.compare(value, minInToType)
+      val maxCmp = ordering.compare(value, maxInToType)
+
+      if (maxCmp >= 0 || minCmp <= 0) {
+        return if (maxCmp > 0) {
+          exp match {
+            case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) 
=>
+              falseIfNotNull(fromExp)
+            case LessThan(_, _) | LessThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            // make sure the expression is evaluated if it is non-deterministic
+            case EqualNullSafe(_, _) if exp.deterministic =>
+              FalseLiteral
+            case _ => exp
+          }
+        } else if (maxCmp == 0) {
+          exp match {
+            case GreaterThan(_, _) =>
+              falseIfNotNull(fromExp)
+            case LessThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case LessThan(_, _) =>
+              Not(EqualTo(fromExp, Literal(max, fromType)))
+            case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
+              EqualTo(fromExp, Literal(max, fromType))
+            case EqualNullSafe(_, _) =>
+              EqualNullSafe(fromExp, Literal(max, fromType))
+            case _ => exp
+          }
+        } else if (minCmp < 0) {
+          exp match {
+            case GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _) =>
+              falseIfNotNull(fromExp)
+            // make sure the expression is evaluated if it is non-deterministic
+            case EqualNullSafe(_, _) if exp.deterministic =>
+              FalseLiteral
+            case _ => exp
+          }
+        } else { // minCmp == 0
+          exp match {
+            case LessThan(_, _) =>
+              falseIfNotNull(fromExp)
+            case GreaterThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case GreaterThan(_, _) =>
+              Not(EqualTo(fromExp, Literal(min, fromType)))
+            case LessThanOrEqual(_, _) | EqualTo(_, _) =>
+              EqualTo(fromExp, Literal(min, fromType))
+            case EqualNullSafe(_, _) =>
+              EqualNullSafe(fromExp, Literal(min, fromType))
+            case _ => exp
+          }
+        }
       }
-    } else if (minCmp < 0) {
+    }
+
+    // When we reach to this point, it means either there is no min/max for 
the `fromType` (e.g.,
+    // decimal type), or that the literal `value` is within range `(min, 
max)`. For these, we

Review comment:
       makes sense.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to