This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new cfde117 [SPARK-35316][SQL] UnwrapCastInBinaryComparison support
In/InSet predicate
cfde117 is described below
commit cfde117c6fec758c44f77fe9f9379ba38940b51a
Author: Fu Chen <[email protected]>
AuthorDate: Thu Jun 3 14:45:17 2021 +0000
[SPARK-35316][SQL] UnwrapCastInBinaryComparison support In/InSet predicate
### What changes were proposed in this pull request?
This PR adds In/InSet predicate support to `UnwrapCastInBinaryComparison`.
The current implementation doesn't push down filters for `In`/`InSet` predicates
that contain a `Cast`.
For instance:
```scala
spark.range(50).selectExpr("cast(id as int) as
id").write.mode("overwrite").parquet("/tmp/parquet/t1")
spark.read.parquet("/tmp/parquet/t1").where("id in (1L, 2L, 4L)").explain
```
before this pr:
```
== Physical Plan ==
*(1) Filter cast(id#5 as bigint) IN (1,2,4)
+- *(1) ColumnarToRow
+- FileScan parquet [id#5] Batched: true, DataFilters: [cast(id#5 as
bigint) IN (1,2,4)], Format: Parquet, Location: InMemoryFileIndex(1
paths)[file:/tmp/parquet/t1], PartitionFilters: [], PushedFilters: [],
ReadSchema: struct<id:int>
```
after this pr:
```
== Physical Plan ==
*(1) Filter id#95 IN (1,2,4)
+- *(1) ColumnarToRow
+- FileScan parquet [id#95] Batched: true, DataFilters: [id#95 IN
(1,2,4)], Format: Parquet, Location: InMemoryFileIndex(1
paths)[file:/tmp/parquet/t1], PartitionFilters: [], PushedFilters: [In(id,
[1,2,4])], ReadSchema: struct<id:int>
```
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New test.
Closes #32488 from cfmcgrady/SPARK-35316.
Authored-by: Fu Chen <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../optimizer/UnwrapCastInBinaryComparison.scala | 114 +++++++++++++++++++--
.../UnwrapCastInBinaryComparisonSuite.scala | 50 +++++++++
2 files changed, 156 insertions(+), 8 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
index 9f72751..097d810 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
@@ -17,19 +17,23 @@
package org.apache.spark.sql.catalyst.optimizer
+import scala.collection.immutable.HashSet
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.BINARY_COMPARISON
+import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_COMPARISON, IN,
INSET}
import org.apache.spark.sql.types._
/**
- * Unwrap casts in binary comparison operations with patterns like following:
+ * Unwrap casts in binary comparison or `In/InSet` operations with patterns
like following:
*
- * `BinaryComparison(Cast(fromExp, toType), Literal(value, toType))`
- * or
- * `BinaryComparison(Literal(value, toType), Cast(fromExp, toType))`
+ * - `BinaryComparison(Cast(fromExp, toType), Literal(value, toType))`
+ * - `BinaryComparison(Literal(value, toType), Cast(fromExp, toType))`
+ * - `In(Cast(fromExp, toType), Seq(Literal(v1, toType), Literal(v2, toType),
...)`
+ * - `InSet(Cast(fromExp, toType), Set(v1, v2, ...))`
*
* This rule optimizes expressions with the above pattern by either replacing
the cast with simpler
* constructs, or moving the cast from the expression side to the literal
side, which enables them
@@ -86,13 +90,22 @@ import org.apache.spark.sql.types._
* Further, the above `if(isnull(fromExp), null, false)` is represented using
conjunction
* `and(isnull(fromExp), null)`, to enable further optimization and filter
pushdown to data sources.
* Similarly, `if(isnull(fromExp), null, true)` is represented with
`or(isnotnull(fromExp), null)`.
+ *
+ * For `In/InSet` operation, first the rule transform the expression to Equals:
+ * `Seq(
+ * EqualTo(Cast(fromExp, toType), Literal(v1, toType)),
+ * EqualTo(Cast(fromExp, toType), Literal(v2, toType)),
+ * ...
+ * )`
+ * and using the same rule with `BinaryComparison` show as before to optimize
each `EqualTo`.
*/
object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan =
plan.transformWithPruning(
- _.containsPattern(BINARY_COMPARISON), ruleId) {
+ _.containsAnyPattern(BINARY_COMPARISON, IN, INSET), ruleId) {
case l: LogicalPlan =>
-
l.transformExpressionsUpWithPruning(_.containsPattern(BINARY_COMPARISON),
ruleId) {
- case e @ BinaryComparison(_, _) => unwrapCast(e)
+ l.transformExpressionsUpWithPruning(
+ _.containsAnyPattern(BINARY_COMPARISON, IN, INSET), ruleId) {
+ case e @ (BinaryComparison(_, _) | In(_, _) | InSet(_, _)) =>
unwrapCast(e)
}
}
@@ -121,6 +134,91 @@ object UnwrapCastInBinaryComparison extends
Rule[LogicalPlan] {
if canImplicitlyCast(fromExp, toType, literalType) =>
simplifyNumericComparison(be, fromExp, toType, value)
+ // As the analyzer makes sure that the list of In is already of the same
data type, then the
+ // rule can simply check the first literal in `in.list` can implicitly
cast to `toType` or not,
+ // and note that:
+ // 1. this rule doesn't convert in when `in.list` is empty or `in.list`
contains only null
+ // values.
+ // 2. this rule only handles the case when both `fromExp` and value in
`in.list` are of numeric
+ // type.
+ case in @ In(Cast(fromExp, toType: NumericType, _), list @ Seq(firstLit,
_*))
+ if canImplicitlyCast(fromExp, toType, firstLit.dataType) =>
+
+ // There are 3 kinds of literals in the list:
+ // 1. null literals
+ // 2. The literals that can cast to fromExp.dataType
+ // 3. The literals that cannot cast to fromExp.dataType
+ // null literals is special as we can cast null literals to any data
type.
+ val (nullList, canCastList, cannotCastList) =
+ (ArrayBuffer[Literal](), ArrayBuffer[Literal](),
ArrayBuffer[Expression]())
+ list.foreach {
+ case lit @ Literal(null, _) => nullList += lit
+ case lit @ NonNullLiteral(_, _) =>
+ unwrapCast(EqualTo(in.value, lit)) match {
+ case EqualTo(_, unwrapLit: Literal) => canCastList += unwrapLit
+ case e @ And(IsNull(_), Literal(null, BooleanType)) =>
cannotCastList += e
+ case _ => throw new IllegalStateException("Illegal unwrap cast
result found.")
+ }
+ case _ => throw new IllegalStateException("Illegal value found in
in.list.")
+ }
+
+ // return original expression when in.list contains only null values.
+ if (canCastList.isEmpty && cannotCastList.isEmpty) {
+ exp
+ } else {
+ // cast null value to fromExp.dataType, to make sure the new return
list is in the same data
+ // type.
+ val newList = nullList.map(lit => Cast(lit, fromExp.dataType)) ++
canCastList
+ val unwrapIn = In(fromExp, newList.toSeq)
+ cannotCastList.headOption match {
+ case None => unwrapIn
+ // since `cannotCastList` are all the same,
+ // convert to a single value `And(IsNull(_), Literal(null,
BooleanType))`.
+ case Some(falseIfNotNull @ And(IsNull(_), Literal(null,
BooleanType)))
+ if cannotCastList.map(_.canonicalized).distinct.length == 1 =>
+ Or(falseIfNotNull, unwrapIn)
+ case _ => exp
+ }
+ }
+
+ // The same with `In` expression, the analyzer makes sure that the hset of
InSet is already of
+ // the same data type, so simply check `fromExp.dataType` can implicitly
cast to `toType` and
+ // both `fromExp.dataType` and `toType` is numeric type or not.
+ case inSet @ InSet(Cast(fromExp, toType: NumericType, _), hset)
+ if hset.nonEmpty && canImplicitlyCast(fromExp, toType, toType) =>
+
+ // The same with `In`, there are 3 kinds of literals in the hset:
+ // 1. null literals
+ // 2. The literals that can cast to fromExp.dataType
+ // 3. The literals that cannot cast to fromExp.dataType
+ var (nullSet, canCastSet, cannotCastSet) =
+ (HashSet[Any](), HashSet[Any](), HashSet[Expression]())
+ hset.map(value => Literal.create(value, toType))
+ .foreach {
+ case lit @ Literal(null, _) => nullSet += lit.value
+ case lit @ NonNullLiteral(_, _) =>
+ unwrapCast(EqualTo(inSet.child, lit)) match {
+ case EqualTo(_, unwrapLit: Literal) => canCastSet +=
unwrapLit.value
+ case e @ And(IsNull(_), Literal(null, BooleanType)) =>
cannotCastSet += e
+ case _ => throw new IllegalStateException("Illegal unwrap cast
result found.")
+ }
+ case _ => throw new IllegalStateException("Illegal value found in
hset.")
+ }
+
+ if (canCastSet.isEmpty && cannotCastSet.isEmpty) {
+ exp
+ } else {
+ val unwrapInSet = InSet(fromExp, nullSet ++ canCastSet)
+ cannotCastSet.headOption match {
+ case None => unwrapInSet
+ // since `cannotCastList` are all the same,
+ // convert to a single value `And(IsNull(_), Literal(null,
BooleanType))`.
+ case Some(falseIfNotNull @ And(IsNull(_), Literal(null,
BooleanType)))
+ if cannotCastSet.map(_.canonicalized).size == 1 =>
Or(falseIfNotNull, unwrapInSet)
+ case _ => exp
+ }
+ }
+
case _ => exp
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
index 0afb166..e5df1ab 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.optimizer
+import scala.collection.immutable.HashSet
+
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
import org.apache.spark.sql.catalyst.expressions._
@@ -233,6 +235,54 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest
with ExpressionEvalHelp
assert(getRange(DecimalType(5, 2)).isEmpty)
}
+ test("SPARK-35316: unwrap should support In/InSet predicate.") {
+ val longLit = Literal.create(null, LongType)
+ val intLit = Literal.create(null, IntegerType)
+ val shortLit = Literal.create(null, ShortType)
+
+ def checkInAndInSet(in: In, expected: Expression): Unit = {
+ assertEquivalent(in, expected)
+ val toInSet = (in: In) => InSet(in.value, HashSet() ++
in.list.map(_.eval()))
+ val expectedInSet = expected match {
+ case expectedIn: In =>
+ toInSet(expectedIn)
+ case Or(falseIfNotNull: And, expectedIn: In) =>
+ Or(falseIfNotNull, toInSet(expectedIn))
+ }
+ assertEquivalent(toInSet(in), expectedInSet)
+ }
+
+ checkInAndInSet(
+ In(Cast(f, LongType), Seq(1.toLong, 2.toLong, 3.toLong)),
+ f.in(1.toShort, 2.toShort, 3.toShort))
+
+ // in.list contains the value which out of `fromType` range
+ checkInAndInSet(
+ In(Cast(f, LongType), Seq(1.toLong, Int.MaxValue.toLong, Long.MaxValue)),
+ Or(falseIfNotNull(f), f.in(1.toShort)))
+
+ // in.list only contains the value which out of `fromType` range
+ checkInAndInSet(
+ In(Cast(f, LongType), Seq(Int.MaxValue.toLong, Long.MaxValue)),
+ Or(falseIfNotNull(f), f.in()))
+
+ // in.list is empty
+ checkInAndInSet(
+ In(Cast(f, IntegerType), Seq.empty), Cast(f, IntegerType).in())
+
+ // in.list contains null value
+ checkInAndInSet(
+ In(Cast(f, IntegerType), Seq(intLit)), In(Cast(f, IntegerType),
Seq(intLit)))
+ checkInAndInSet(
+ In(Cast(f, IntegerType), Seq(intLit, intLit)), In(Cast(f, IntegerType),
Seq(intLit, intLit)))
+ checkInAndInSet(
+ In(Cast(f, IntegerType), Seq(intLit, 1)), f.in(shortLit, 1.toShort))
+ checkInAndInSet(
+ In(Cast(f, LongType), Seq(longLit, 1.toLong, Long.MaxValue)),
+ Or(falseIfNotNull(f), f.in(shortLit, 1.toShort))
+ )
+ }
+
private def castInt(e: Expression): Expression = Cast(e, IntegerType)
private def castDouble(e: Expression): Expression = Cast(e, DoubleType)
private def castDecimal2(e: Expression): Expression = Cast(e,
DecimalType(10, 4))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]