cfmcgrady commented on a change in pull request #32488:
URL: https://github.com/apache/spark/pull/32488#discussion_r642853036



##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
##########
@@ -121,6 +131,77 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
         if canImplicitlyCast(fromExp, toType, literalType) =>
       simplifyNumericComparison(be, fromExp, toType, value)
 
+    // As the analyzer makes sure that the list of In is already of the same data type, then the
+    // rule can simply check the first literal in `in.list` can implicitly cast to `toType` or not,
+    // and note that:
+    // 1. this rule doesn't convert in when `in.list` is empty.
+    // 2. this rule only handles the case when both `fromExp` and value in `in.list` are of numeric
+    // type.
+    case in @ In(Cast(fromExp, toType: NumericType, _), list @ Seq(firstLit, _*))
+        if canImplicitlyCast(fromExp, toType, firstLit.dataType) && in.inSetConvertible =>
+      val (newValueList, expr) =
+        list.map(lit => unwrapCast(EqualTo(in.value, lit)))
+          .partition {
+            case EqualTo(_, _: Literal) => true
+            case And(IsNull(_), Literal(null, BooleanType)) => false
+            case _ => throw new IllegalStateException("Illegal unwrap cast result found.")
+          }
+
+      val (nonNullValueList, nullValueList) = newValueList.partition {
+        case EqualTo(_, NonNullLiteral(_, _: NumericType)) => true
+        case EqualTo(_, Literal(null, _)) => false
+        case _ => throw new IllegalStateException("Illegal unwrap cast result found.")
+      }
+      // make sure the new return list have the same dataType.
+      val newList = {
+        if (nonNullValueList.nonEmpty) {
+          // cast the null value to the dataType of nonNullValueList
+          // when the nonNullValueList is nonEmpty.
+          nullValueList.map {
+            case EqualTo(_, lit) =>
+              Cast(lit, nonNullValueList.head.asInstanceOf[EqualTo].left.dataType)

Review comment:
       For instance:
   
   ```
   // x is of type short
   // x in (null, 1)
   In(Cast(x, IntegerType), Seq(Literal(null, IntegerType), Literal(1, IntegerType)))
   ```
   
   When `literal == null`, `unwrapCast` returns `EqualTo(x, literal)` unchanged:
   
   
https://github.com/apache/spark/blob/fe09def3231600e52cb58aaba5f72af33ab4dc33/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala#L210-L215
   
   ```
   EqualTo(x, Literal(null, IntegerType)) => EqualTo(x, Literal(null, IntegerType))
   ```
   
   When `literal == 1`, however, it moves the cast from the expression side to the literal side:
   
   ```
   EqualTo(x, Literal(1, IntegerType)) => EqualTo(x, Literal(1, ShortType))
   ```
   
   Note that the two results now have different literal types, so we cast the 
values in `nullValueList` from their data type (IntegerType) to the data type of 
`nonNullValueList` (ShortType). This makes sure all of the values in the returned 
`in.list` have the same data type.
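
To make the type alignment concrete, here is a minimal, self-contained sketch. It does not use Spark's Catalyst classes: `Attr`, `Lit`, `Cast`, `unwrapLiteral` and `alignList` are hypothetical stand-ins defined only for this illustration, and the range checks that the real rule performs are left out. It only shows why the null literals need a cast once the non-null literals have been narrowed. Keeping the original list order is an assumption of the sketch, not something taken from the PR.

```scala
// Toy model, NOT Spark's Catalyst API: just enough structure to show the idea.
sealed trait DataType
case object ShortType extends DataType
case object IntegerType extends DataType

sealed trait Expr { def dataType: DataType }
final case class Attr(name: String, dataType: DataType) extends Expr
// value = None models a null literal.
final case class Lit(value: Option[Int], dataType: DataType) extends Expr
final case class Cast(child: Expr, dataType: DataType) extends Expr

object UnwrapInListSketch {
  // Mimics the behaviour described in the comment for `cast(x as T) = lit`:
  // a null literal comes back untouched, a non-null literal takes x's type.
  def unwrapLiteral(from: Expr, lit: Lit): Lit = lit.value match {
    case None    => lit
    case Some(_) => lit.copy(dataType = from.dataType)
  }

  // Aligns the unwrapped list: if there is at least one non-null literal,
  // every null literal is cast to that literal's (narrowed) type.
  def alignList(from: Expr, list: Seq[Lit]): Seq[Expr] = {
    val unwrapped = list.map(unwrapLiteral(from, _))
    unwrapped.find(_.value.isDefined) match {
      case Some(firstNonNull) =>
        unwrapped.map { l =>
          if (l.value.isEmpty) Cast(l, firstNonNull.dataType) else l
        }
      case None =>
        unwrapped // only nulls: the types already agree
    }
  }

  def main(args: Array[String]): Unit = {
    val x = Attr("x", ShortType)
    // x in (null, 1), with both literals typed as IntegerType by the analyzer
    val aligned = alignList(x, Seq(Lit(None, IntegerType), Lit(Some(1), IntegerType)))
    aligned.foreach(e => println(s"$e has type ${e.dataType}"))
  }
}
```

Without the `Cast` around the null literal, the resulting list would mix IntegerType and ShortType entries, which is exactly the mismatch the change avoids.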
   



