This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new cfde117  [SPARK-35316][SQL] UnwrapCastInBinaryComparison support 
In/InSet predicate
cfde117 is described below

commit cfde117c6fec758c44f77fe9f9379ba38940b51a
Author: Fu Chen <[email protected]>
AuthorDate: Thu Jun 3 14:45:17 2021 +0000

    [SPARK-35316][SQL] UnwrapCastInBinaryComparison support In/InSet predicate
    
    ### What changes were proposed in this pull request?
    
    This PR adds In/InSet predicate support to `UnwrapCastInBinaryComparison`.
    
    The current implementation doesn't push down filters for `In/InSet` expressions that
contain a `Cast`.
    
    For instance:
    
    ```scala
    spark.range(50).selectExpr("cast(id as int) as 
id").write.mode("overwrite").parquet("/tmp/parquet/t1")
    spark.read.parquet("/tmp/parquet/t1").where("id in (1L, 2L, 4L)").explain
    ```
    
    before this pr:
    
    ```
    == Physical Plan ==
    *(1) Filter cast(id#5 as bigint) IN (1,2,4)
    +- *(1) ColumnarToRow
       +- FileScan parquet [id#5] Batched: true, DataFilters: [cast(id#5 as 
bigint) IN (1,2,4)], Format: Parquet, Location: InMemoryFileIndex(1 
paths)[file:/tmp/parquet/t1], PartitionFilters: [], PushedFilters: [], 
ReadSchema: struct<id:int>
    ```
    
    after this pr:
    
    ```
    == Physical Plan ==
    *(1) Filter id#95 IN (1,2,4)
    +- *(1) ColumnarToRow
       +- FileScan parquet [id#95] Batched: true, DataFilters: [id#95 IN 
(1,2,4)], Format: Parquet, Location: InMemoryFileIndex(1 
paths)[file:/tmp/parquet/t1], PartitionFilters: [], PushedFilters: [In(id, 
[1,2,4])], ReadSchema: struct<id:int>
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    ### How was this patch tested?
    
    New test.
    
    Closes #32488 from cfmcgrady/SPARK-35316.
    
    Authored-by: Fu Chen <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../optimizer/UnwrapCastInBinaryComparison.scala   | 114 +++++++++++++++++++--
 .../UnwrapCastInBinaryComparisonSuite.scala        |  50 +++++++++
 2 files changed, 156 insertions(+), 8 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
index 9f72751..097d810 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
@@ -17,19 +17,23 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import scala.collection.immutable.HashSet
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.BINARY_COMPARISON
+import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_COMPARISON, IN, 
INSET}
 import org.apache.spark.sql.types._
 
 /**
- * Unwrap casts in binary comparison operations with patterns like following:
+ * Unwrap casts in binary comparison or `In/InSet` operations with patterns 
like following:
  *
- * `BinaryComparison(Cast(fromExp, toType), Literal(value, toType))`
- *   or
- * `BinaryComparison(Literal(value, toType), Cast(fromExp, toType))`
+ * - `BinaryComparison(Cast(fromExp, toType), Literal(value, toType))`
+ * - `BinaryComparison(Literal(value, toType), Cast(fromExp, toType))`
+ * - `In(Cast(fromExp, toType), Seq(Literal(v1, toType), Literal(v2, toType), 
...)`
+ * - `InSet(Cast(fromExp, toType), Set(v1, v2, ...))`
  *
  * This rule optimizes expressions with the above pattern by either replacing 
the cast with simpler
  * constructs, or moving the cast from the expression side to the literal 
side, which enables them
@@ -86,13 +90,22 @@ import org.apache.spark.sql.types._
  * Further, the above `if(isnull(fromExp), null, false)` is represented using 
conjunction
  * `and(isnull(fromExp), null)`, to enable further optimization and filter 
pushdown to data sources.
  * Similarly, `if(isnull(fromExp), null, true)` is represented with 
`or(isnotnull(fromExp), null)`.
+ *
+ * For `In/InSet` operation, first the rule transform the expression to Equals:
+ * `Seq(
+ *   EqualTo(Cast(fromExp, toType), Literal(v1, toType)),
+ *   EqualTo(Cast(fromExp, toType), Literal(v2, toType)),
+ *   ...
+ * )`
+ * and using the same rule with `BinaryComparison` show as before to optimize 
each `EqualTo`.
  */
 object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = 
plan.transformWithPruning(
-    _.containsPattern(BINARY_COMPARISON), ruleId) {
+    _.containsAnyPattern(BINARY_COMPARISON, IN, INSET), ruleId) {
     case l: LogicalPlan =>
-      
l.transformExpressionsUpWithPruning(_.containsPattern(BINARY_COMPARISON), 
ruleId) {
-        case e @ BinaryComparison(_, _) => unwrapCast(e)
+      l.transformExpressionsUpWithPruning(
+        _.containsAnyPattern(BINARY_COMPARISON, IN, INSET), ruleId) {
+        case e @ (BinaryComparison(_, _) | In(_, _) | InSet(_, _)) => 
unwrapCast(e)
       }
   }
 
@@ -121,6 +134,91 @@ object UnwrapCastInBinaryComparison extends 
Rule[LogicalPlan] {
         if canImplicitlyCast(fromExp, toType, literalType) =>
       simplifyNumericComparison(be, fromExp, toType, value)
 
+    // As the analyzer makes sure that the list of In is already of the same 
data type, then the
+    // rule can simply check the first literal in `in.list` can implicitly 
cast to `toType` or not,
+    // and note that:
+    // 1. this rule doesn't convert in when `in.list` is empty or `in.list` 
contains only null
+    // values.
+    // 2. this rule only handles the case when both `fromExp` and value in 
`in.list` are of numeric
+    // type.
+    case in @ In(Cast(fromExp, toType: NumericType, _), list @ Seq(firstLit, 
_*))
+      if canImplicitlyCast(fromExp, toType, firstLit.dataType) =>
+
+      // There are 3 kinds of literals in the list:
+      // 1. null literals
+      // 2. The literals that can cast to fromExp.dataType
+      // 3. The literals that cannot cast to fromExp.dataType
+      // null literals is special as we can cast null literals to any data 
type.
+      val (nullList, canCastList, cannotCastList) =
+        (ArrayBuffer[Literal](), ArrayBuffer[Literal](), 
ArrayBuffer[Expression]())
+      list.foreach {
+        case lit @ Literal(null, _) => nullList += lit
+        case lit @ NonNullLiteral(_, _) =>
+          unwrapCast(EqualTo(in.value, lit)) match {
+            case EqualTo(_, unwrapLit: Literal) => canCastList += unwrapLit
+            case e @ And(IsNull(_), Literal(null, BooleanType)) => 
cannotCastList += e
+            case _ => throw new IllegalStateException("Illegal unwrap cast 
result found.")
+          }
+        case _ => throw new IllegalStateException("Illegal value found in 
in.list.")
+      }
+
+      // return original expression when in.list contains only null values.
+      if (canCastList.isEmpty && cannotCastList.isEmpty) {
+        exp
+      } else {
+        // cast null value to fromExp.dataType, to make sure the new return 
list is in the same data
+        // type.
+        val newList = nullList.map(lit => Cast(lit, fromExp.dataType)) ++ 
canCastList
+        val unwrapIn = In(fromExp, newList.toSeq)
+        cannotCastList.headOption match {
+          case None => unwrapIn
+          // since `cannotCastList` are all the same,
+          // convert to a single value `And(IsNull(_), Literal(null, 
BooleanType))`.
+          case Some(falseIfNotNull @ And(IsNull(_), Literal(null, 
BooleanType)))
+              if cannotCastList.map(_.canonicalized).distinct.length == 1 =>
+            Or(falseIfNotNull, unwrapIn)
+          case _ => exp
+        }
+      }
+
+    // The same with `In` expression, the analyzer makes sure that the hset of 
InSet is already of
+    // the same data type, so simply check `fromExp.dataType` can implicitly 
cast to `toType` and
+    // both `fromExp.dataType` and `toType` is numeric type or not.
+    case inSet @ InSet(Cast(fromExp, toType: NumericType, _), hset)
+      if hset.nonEmpty && canImplicitlyCast(fromExp, toType, toType) =>
+
+      // The same with `In`, there are 3 kinds of literals in the hset:
+      // 1. null literals
+      // 2. The literals that can cast to fromExp.dataType
+      // 3. The literals that cannot cast to fromExp.dataType
+      var (nullSet, canCastSet, cannotCastSet) =
+        (HashSet[Any](), HashSet[Any](), HashSet[Expression]())
+      hset.map(value => Literal.create(value, toType))
+        .foreach {
+          case lit @ Literal(null, _) => nullSet += lit.value
+          case lit @ NonNullLiteral(_, _) =>
+            unwrapCast(EqualTo(inSet.child, lit)) match {
+              case EqualTo(_, unwrapLit: Literal) => canCastSet += 
unwrapLit.value
+              case e @ And(IsNull(_), Literal(null, BooleanType)) => 
cannotCastSet += e
+              case _ => throw new IllegalStateException("Illegal unwrap cast 
result found.")
+            }
+          case _ => throw new IllegalStateException("Illegal value found in 
hset.")
+        }
+
+      if (canCastSet.isEmpty && cannotCastSet.isEmpty) {
+        exp
+      } else {
+        val unwrapInSet = InSet(fromExp, nullSet ++ canCastSet)
+        cannotCastSet.headOption match {
+          case None => unwrapInSet
+          // since `cannotCastList` are all the same,
+          // convert to a single value `And(IsNull(_), Literal(null, 
BooleanType))`.
+          case Some(falseIfNotNull @ And(IsNull(_), Literal(null, 
BooleanType)))
+            if cannotCastSet.map(_.canonicalized).size == 1 => 
Or(falseIfNotNull, unwrapInSet)
+          case _ => exp
+        }
+      }
+
     case _ => exp
   }
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
index 0afb166..e5df1ab 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import scala.collection.immutable.HashSet
+
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
 import org.apache.spark.sql.catalyst.expressions._
@@ -233,6 +235,54 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest 
with ExpressionEvalHelp
     assert(getRange(DecimalType(5, 2)).isEmpty)
   }
 
+  test("SPARK-35316: unwrap should support In/InSet predicate.") {
+    val longLit = Literal.create(null, LongType)
+    val intLit = Literal.create(null, IntegerType)
+    val shortLit = Literal.create(null, ShortType)
+
+    def checkInAndInSet(in: In, expected: Expression): Unit = {
+      assertEquivalent(in, expected)
+      val toInSet = (in: In) => InSet(in.value, HashSet() ++ 
in.list.map(_.eval()))
+      val expectedInSet = expected match {
+        case expectedIn: In =>
+          toInSet(expectedIn)
+        case Or(falseIfNotNull: And, expectedIn: In) =>
+          Or(falseIfNotNull, toInSet(expectedIn))
+      }
+      assertEquivalent(toInSet(in), expectedInSet)
+    }
+
+    checkInAndInSet(
+      In(Cast(f, LongType), Seq(1.toLong, 2.toLong, 3.toLong)),
+      f.in(1.toShort, 2.toShort, 3.toShort))
+
+    // in.list contains the value which out of `fromType` range
+    checkInAndInSet(
+      In(Cast(f, LongType), Seq(1.toLong, Int.MaxValue.toLong, Long.MaxValue)),
+      Or(falseIfNotNull(f), f.in(1.toShort)))
+
+    // in.list only contains the value which out of `fromType` range
+    checkInAndInSet(
+      In(Cast(f, LongType), Seq(Int.MaxValue.toLong, Long.MaxValue)),
+      Or(falseIfNotNull(f), f.in()))
+
+    // in.list is empty
+    checkInAndInSet(
+      In(Cast(f, IntegerType), Seq.empty), Cast(f, IntegerType).in())
+
+    // in.list contains null value
+    checkInAndInSet(
+      In(Cast(f, IntegerType), Seq(intLit)), In(Cast(f, IntegerType), 
Seq(intLit)))
+    checkInAndInSet(
+      In(Cast(f, IntegerType), Seq(intLit, intLit)), In(Cast(f, IntegerType), 
Seq(intLit, intLit)))
+    checkInAndInSet(
+      In(Cast(f, IntegerType), Seq(intLit, 1)), f.in(shortLit, 1.toShort))
+    checkInAndInSet(
+      In(Cast(f, LongType), Seq(longLit, 1.toLong, Long.MaxValue)),
+      Or(falseIfNotNull(f), f.in(shortLit, 1.toShort))
+    )
+  }
+
   private def castInt(e: Expression): Expression = Cast(e, IntegerType)
   private def castDouble(e: Expression): Expression = Cast(e, DoubleType)
   private def castDecimal2(e: Expression): Expression = Cast(e, 
DecimalType(10, 4))

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to