sunchao commented on a change in pull request #29792:
URL: https://github.com/apache/spark/pull/29792#discussion_r496883121



##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
##########
@@ -116,82 +132,118 @@ object UnwrapCastInBinaryComparison extends 
Rule[LogicalPlan] {
    * optimizes the expression by moving the cast to the literal side. 
Otherwise if result is not
    * true, this replaces the input binary comparison `exp` with simpler 
expressions.
    */
-  private def simplifyIntegralComparison(
+  private def simplifyNumericComparison(
       exp: BinaryComparison,
       fromExp: Expression,
-      toType: IntegralType,
+      toType: NumericType,
       value: Any): Expression = {
 
     val fromType = fromExp.dataType
-    val (min, max) = getRange(fromType)
-    val (minInToType, maxInToType) = {
-      (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
-    }
     val ordering = toType.ordering.asInstanceOf[Ordering[Any]]
-    val minCmp = ordering.compare(value, minInToType)
-    val maxCmp = ordering.compare(value, maxInToType)
+    val range = getRange(fromType)
 
-    if (maxCmp > 0) {
-      exp match {
-        case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
-          falseIfNotNull(fromExp)
-        case LessThan(_, _) | LessThanOrEqual(_, _) =>
-          trueIfNotNull(fromExp)
-        // make sure the expression is evaluated if it is non-deterministic
-        case EqualNullSafe(_, _) if exp.deterministic =>
-          FalseLiteral
-        case _ => exp
+    if (range.isDefined) {
+      val (min, max) = range.get
+      val (minInToType, maxInToType) = {
+        (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
       }
-    } else if (maxCmp == 0) {
-      exp match {
-        case GreaterThan(_, _) =>
-          falseIfNotNull(fromExp)
-        case LessThanOrEqual(_, _) =>
-          trueIfNotNull(fromExp)
-        case LessThan(_, _) =>
-          Not(EqualTo(fromExp, Literal(max, fromType)))
-        case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
-          EqualTo(fromExp, Literal(max, fromType))
-        case EqualNullSafe(_, _) =>
-          EqualNullSafe(fromExp, Literal(max, fromType))
-        case _ => exp
+      val minCmp = ordering.compare(value, minInToType)
+      val maxCmp = ordering.compare(value, maxInToType)
+
+      if (maxCmp >= 0 || minCmp <= 0) {
+        return if (maxCmp > 0) {
+          exp match {
+            case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) 
=>
+              falseIfNotNull(fromExp)
+            case LessThan(_, _) | LessThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            // make sure the expression is evaluated if it is non-deterministic
+            case EqualNullSafe(_, _) if exp.deterministic =>
+              FalseLiteral
+            case _ => exp
+          }
+        } else if (maxCmp == 0) {
+          exp match {
+            case GreaterThan(_, _) =>
+              falseIfNotNull(fromExp)
+            case LessThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case LessThan(_, _) =>
+              Not(EqualTo(fromExp, Literal(max, fromType)))
+            case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
+              EqualTo(fromExp, Literal(max, fromType))
+            case EqualNullSafe(_, _) =>
+              EqualNullSafe(fromExp, Literal(max, fromType))
+            case _ => exp
+          }
+        } else if (minCmp < 0) {
+          exp match {
+            case GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _) =>
+              falseIfNotNull(fromExp)
+            // make sure the expression is evaluated if it is non-deterministic
+            case EqualNullSafe(_, _) if exp.deterministic =>
+              FalseLiteral
+            case _ => exp
+          }
+        } else { // minCmp == 0
+          exp match {
+            case LessThan(_, _) =>
+              falseIfNotNull(fromExp)
+            case GreaterThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case GreaterThan(_, _) =>
+              Not(EqualTo(fromExp, Literal(min, fromType)))
+            case LessThanOrEqual(_, _) | EqualTo(_, _) =>
+              EqualTo(fromExp, Literal(min, fromType))
+            case EqualNullSafe(_, _) =>
+              EqualNullSafe(fromExp, Literal(min, fromType))
+            case _ => exp
+          }
+        }
       }
-    } else if (minCmp < 0) {
+    }
+
+    // When we reach to this point, it means either there is no min/max for 
the `fromType` (e.g.,
+    // decimal type), or that the literal `value` is within range `(min, 
max)`. For these, we
+    // optimize by moving the cast to the literal side.
+
+    val newValue = Cast(Literal(value), fromType).eval()
+    if (newValue == null) {
+      // This means the cast failed, for instance, due to the value is not 
representable in the
+      // narrower type. In this case we simply return the original expression.
+      return exp
+    }
+    val valueRoundTrip = Cast(Literal(newValue, fromType), toType).eval()

Review comment:
       Casting double to float can result in rounding either up or down. For 
instance, when casting 3.14 from double to float, even though the value is 
still shown as 3.14, the binary representation is rounded up:
   
   3.14 in double:
   ```
   0 10000000000 1001 0001 1110 1011 1000 0101 0001 1110 1011 1000 0101 0001 
1111
   ```
   
   3.14 in float:
   ```
   0 10000000 1001 0001 1110 1011 1000 011
   ```
   Here the sign bit and the exponent (stored in 11 bits for double and 8 bits 
for float) encode the same sign and the same exponent value in both float and 
double. However, in the fraction part, the last of the 23 bits kept by float is 
rounded up from 0 to 1.
   
   After casting back to double, there won't be any rounding up or down - the 
remaining fraction bits are simply padded with 0s:
   ```
   0 10000000000 1001 0001 1110 1011 1000 0110 0000000000000000000000000000
   ```
   




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to