This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3c16fae [SPARK-31027][SQL] Refactor DataSourceStrategy to be more
extendable
3c16fae is described below
commit 3c16fae5c1369387d3730ce9851cb5e6fbd229c7
Author: DB Tsai <[email protected]>
AuthorDate: Wed Mar 4 23:41:49 2020 +0900
[SPARK-31027][SQL] Refactor DataSourceStrategy to be more extendable
### What changes were proposed in this pull request?
Refactor `DataSourceStrategy.scala` and `DataSourceStrategySuite.scala` so
they're more extendable, to implement nested predicate pushdown.
### Why are the changes needed?
To support nested predicate pushdown.
### Does this PR introduce any user-facing change?
No.
### How was this patch tested?
Existing tests and new tests.
Closes #27778 from dbtsai/SPARK-31027.
Authored-by: DB Tsai <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
---
.../execution/datasources/DataSourceStrategy.scala | 105 ++++++++++--------
.../datasources/DataSourceStrategySuite.scala | 121 +++++++++++++--------
2 files changed, 133 insertions(+), 93 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2d902b5..1641b66 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -449,60 +449,60 @@ object DataSourceStrategy {
}
private def translateLeafNodeFilter(predicate: Expression): Option[Filter] =
predicate match {
- case expressions.EqualTo(a: Attribute, Literal(v, t)) =>
- Some(sources.EqualTo(a.name, convertToScala(v, t)))
- case expressions.EqualTo(Literal(v, t), a: Attribute) =>
- Some(sources.EqualTo(a.name, convertToScala(v, t)))
-
- case expressions.EqualNullSafe(a: Attribute, Literal(v, t)) =>
- Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))
- case expressions.EqualNullSafe(Literal(v, t), a: Attribute) =>
- Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))
-
- case expressions.GreaterThan(a: Attribute, Literal(v, t)) =>
- Some(sources.GreaterThan(a.name, convertToScala(v, t)))
- case expressions.GreaterThan(Literal(v, t), a: Attribute) =>
- Some(sources.LessThan(a.name, convertToScala(v, t)))
-
- case expressions.LessThan(a: Attribute, Literal(v, t)) =>
- Some(sources.LessThan(a.name, convertToScala(v, t)))
- case expressions.LessThan(Literal(v, t), a: Attribute) =>
- Some(sources.GreaterThan(a.name, convertToScala(v, t)))
-
- case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, t)) =>
- Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))
- case expressions.GreaterThanOrEqual(Literal(v, t), a: Attribute) =>
- Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))
-
- case expressions.LessThanOrEqual(a: Attribute, Literal(v, t)) =>
- Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))
- case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) =>
- Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))
-
- case expressions.InSet(a: Attribute, set) =>
- val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
- Some(sources.In(a.name, set.toArray.map(toScala)))
+ case expressions.EqualTo(PushableColumn(name), Literal(v, t)) =>
+ Some(sources.EqualTo(name, convertToScala(v, t)))
+ case expressions.EqualTo(Literal(v, t), PushableColumn(name)) =>
+ Some(sources.EqualTo(name, convertToScala(v, t)))
+
+ case expressions.EqualNullSafe(PushableColumn(name), Literal(v, t)) =>
+ Some(sources.EqualNullSafe(name, convertToScala(v, t)))
+ case expressions.EqualNullSafe(Literal(v, t), PushableColumn(name)) =>
+ Some(sources.EqualNullSafe(name, convertToScala(v, t)))
+
+ case expressions.GreaterThan(PushableColumn(name), Literal(v, t)) =>
+ Some(sources.GreaterThan(name, convertToScala(v, t)))
+ case expressions.GreaterThan(Literal(v, t), PushableColumn(name)) =>
+ Some(sources.LessThan(name, convertToScala(v, t)))
+
+ case expressions.LessThan(PushableColumn(name), Literal(v, t)) =>
+ Some(sources.LessThan(name, convertToScala(v, t)))
+ case expressions.LessThan(Literal(v, t), PushableColumn(name)) =>
+ Some(sources.GreaterThan(name, convertToScala(v, t)))
+
+ case expressions.GreaterThanOrEqual(PushableColumn(name), Literal(v, t)) =>
+ Some(sources.GreaterThanOrEqual(name, convertToScala(v, t)))
+ case expressions.GreaterThanOrEqual(Literal(v, t), PushableColumn(name)) =>
+ Some(sources.LessThanOrEqual(name, convertToScala(v, t)))
+
+ case expressions.LessThanOrEqual(PushableColumn(name), Literal(v, t)) =>
+ Some(sources.LessThanOrEqual(name, convertToScala(v, t)))
+ case expressions.LessThanOrEqual(Literal(v, t), PushableColumn(name)) =>
+ Some(sources.GreaterThanOrEqual(name, convertToScala(v, t)))
+
+ case expressions.InSet(e @ PushableColumn(name), set) =>
+ val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
+ Some(sources.In(name, set.toArray.map(toScala)))
// Because we only convert In to InSet in Optimizer when there are more
than certain
// items. So it is possible we still get an In expression here that needs
to be pushed
// down.
- case expressions.In(a: Attribute, list) if
list.forall(_.isInstanceOf[Literal]) =>
+ case expressions.In(e @ PushableColumn(name), list) if
list.forall(_.isInstanceOf[Literal]) =>
val hSet = list.map(_.eval(EmptyRow))
- val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
- Some(sources.In(a.name, hSet.toArray.map(toScala)))
+ val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
+ Some(sources.In(name, hSet.toArray.map(toScala)))
- case expressions.IsNull(a: Attribute) =>
- Some(sources.IsNull(a.name))
- case expressions.IsNotNull(a: Attribute) =>
- Some(sources.IsNotNull(a.name))
- case expressions.StartsWith(a: Attribute, Literal(v: UTF8String,
StringType)) =>
- Some(sources.StringStartsWith(a.name, v.toString))
+ case expressions.IsNull(PushableColumn(name)) =>
+ Some(sources.IsNull(name))
+ case expressions.IsNotNull(PushableColumn(name)) =>
+ Some(sources.IsNotNull(name))
+ case expressions.StartsWith(PushableColumn(name), Literal(v: UTF8String,
StringType)) =>
+ Some(sources.StringStartsWith(name, v.toString))
- case expressions.EndsWith(a: Attribute, Literal(v: UTF8String,
StringType)) =>
- Some(sources.StringEndsWith(a.name, v.toString))
+ case expressions.EndsWith(PushableColumn(name), Literal(v: UTF8String,
StringType)) =>
+ Some(sources.StringEndsWith(name, v.toString))
- case expressions.Contains(a: Attribute, Literal(v: UTF8String,
StringType)) =>
- Some(sources.StringContains(a.name, v.toString))
+ case expressions.Contains(PushableColumn(name), Literal(v: UTF8String,
StringType)) =>
+ Some(sources.StringContains(name, v.toString))
case expressions.Literal(true, BooleanType) =>
Some(sources.AlwaysTrue)
@@ -646,3 +646,16 @@ object DataSourceStrategy {
}
}
}
+
+/**
+ * Find the column name of an expression that can be pushed down.
+ */
+object PushableColumn {
+ def unapply(e: Expression): Option[String] = {
+ def helper(e: Expression) = e match {
+ case a: Attribute => Some(a.name)
+ case _ => None
+ }
+ helper(e)
+ }
+}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
index b76db70..7bd3213 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
@@ -22,68 +22,82 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.sources
import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField,
StructType}
class DataSourceStrategySuite extends PlanTest with SharedSparkSession {
+ val attrInts = Seq(
+ 'cint.int
+ ).zip(Seq(
+ "cint"
+ ))
- test("translate simple expression") {
- val attrInt = 'cint.int
- val attrStr = 'cstr.string
+ val attrStrs = Seq(
+ 'cstr.string
+ ).zip(Seq(
+ "cstr"
+ ))
+
+ test("translate simple expression") { attrInts.zip(attrStrs)
+ .foreach { case ((attrInt, intColName), (attrStr, strColName)) =>
- testTranslateFilter(EqualTo(attrInt, 1), Some(sources.EqualTo("cint", 1)))
- testTranslateFilter(EqualTo(1, attrInt), Some(sources.EqualTo("cint", 1)))
+ testTranslateFilter(EqualTo(attrInt, 1), Some(sources.EqualTo(intColName,
1)))
+ testTranslateFilter(EqualTo(1, attrInt), Some(sources.EqualTo(intColName,
1)))
testTranslateFilter(EqualNullSafe(attrStr, Literal(null)),
- Some(sources.EqualNullSafe("cstr", null)))
+ Some(sources.EqualNullSafe(strColName, null)))
testTranslateFilter(EqualNullSafe(Literal(null), attrStr),
- Some(sources.EqualNullSafe("cstr", null)))
+ Some(sources.EqualNullSafe(strColName, null)))
- testTranslateFilter(GreaterThan(attrInt, 1),
Some(sources.GreaterThan("cint", 1)))
- testTranslateFilter(GreaterThan(1, attrInt), Some(sources.LessThan("cint",
1)))
+ testTranslateFilter(GreaterThan(attrInt, 1),
Some(sources.GreaterThan(intColName, 1)))
+ testTranslateFilter(GreaterThan(1, attrInt),
Some(sources.LessThan(intColName, 1)))
- testTranslateFilter(LessThan(attrInt, 1), Some(sources.LessThan("cint",
1)))
- testTranslateFilter(LessThan(1, attrInt), Some(sources.GreaterThan("cint",
1)))
+ testTranslateFilter(LessThan(attrInt, 1),
Some(sources.LessThan(intColName, 1)))
+ testTranslateFilter(LessThan(1, attrInt),
Some(sources.GreaterThan(intColName, 1)))
- testTranslateFilter(GreaterThanOrEqual(attrInt, 1),
Some(sources.GreaterThanOrEqual("cint", 1)))
- testTranslateFilter(GreaterThanOrEqual(1, attrInt),
Some(sources.LessThanOrEqual("cint", 1)))
+ testTranslateFilter(GreaterThanOrEqual(attrInt, 1),
+ Some(sources.GreaterThanOrEqual(intColName, 1)))
+ testTranslateFilter(GreaterThanOrEqual(1, attrInt),
+ Some(sources.LessThanOrEqual(intColName, 1)))
- testTranslateFilter(LessThanOrEqual(attrInt, 1),
Some(sources.LessThanOrEqual("cint", 1)))
- testTranslateFilter(LessThanOrEqual(1, attrInt),
Some(sources.GreaterThanOrEqual("cint", 1)))
+ testTranslateFilter(LessThanOrEqual(attrInt, 1),
+ Some(sources.LessThanOrEqual(intColName, 1)))
+ testTranslateFilter(LessThanOrEqual(1, attrInt),
+ Some(sources.GreaterThanOrEqual(intColName, 1)))
- testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), Some(sources.In("cint",
Array(1, 2, 3))))
+ testTranslateFilter(InSet(attrInt, Set(1, 2, 3)),
Some(sources.In(intColName, Array(1, 2, 3))))
- testTranslateFilter(In(attrInt, Seq(1, 2, 3)), Some(sources.In("cint",
Array(1, 2, 3))))
+ testTranslateFilter(In(attrInt, Seq(1, 2, 3)), Some(sources.In(intColName,
Array(1, 2, 3))))
- testTranslateFilter(IsNull(attrInt), Some(sources.IsNull("cint")))
- testTranslateFilter(IsNotNull(attrInt), Some(sources.IsNotNull("cint")))
+ testTranslateFilter(IsNull(attrInt), Some(sources.IsNull(intColName)))
+ testTranslateFilter(IsNotNull(attrInt),
Some(sources.IsNotNull(intColName)))
// cint > 1 AND cint < 10
testTranslateFilter(And(
GreaterThan(attrInt, 1),
LessThan(attrInt, 10)),
Some(sources.And(
- sources.GreaterThan("cint", 1),
- sources.LessThan("cint", 10))))
+ sources.GreaterThan(intColName, 1),
+ sources.LessThan(intColName, 10))))
// cint >= 8 OR cint <= 2
testTranslateFilter(Or(
GreaterThanOrEqual(attrInt, 8),
LessThanOrEqual(attrInt, 2)),
Some(sources.Or(
- sources.GreaterThanOrEqual("cint", 8),
- sources.LessThanOrEqual("cint", 2))))
+ sources.GreaterThanOrEqual(intColName, 8),
+ sources.LessThanOrEqual(intColName, 2))))
testTranslateFilter(Not(GreaterThanOrEqual(attrInt, 8)),
- Some(sources.Not(sources.GreaterThanOrEqual("cint", 8))))
+ Some(sources.Not(sources.GreaterThanOrEqual(intColName, 8))))
- testTranslateFilter(StartsWith(attrStr, "a"),
Some(sources.StringStartsWith("cstr", "a")))
+ testTranslateFilter(StartsWith(attrStr, "a"),
Some(sources.StringStartsWith(strColName, "a")))
- testTranslateFilter(EndsWith(attrStr, "a"),
Some(sources.StringEndsWith("cstr", "a")))
+ testTranslateFilter(EndsWith(attrStr, "a"),
Some(sources.StringEndsWith(strColName, "a")))
- testTranslateFilter(Contains(attrStr, "a"),
Some(sources.StringContains("cstr", "a")))
- }
+ testTranslateFilter(Contains(attrStr, "a"),
Some(sources.StringContains(strColName, "a")))
+ }}
- test("translate complex expression") {
- val attrInt = 'cint.int
+ test("translate complex expression") { attrInts.foreach { case (attrInt,
intColName) =>
// ABS(cint) - 2 <= 1
testTranslateFilter(LessThanOrEqual(
@@ -102,11 +116,11 @@ class DataSourceStrategySuite extends PlanTest with
SharedSparkSession {
LessThan(attrInt, 100))),
Some(sources.Or(
sources.And(
- sources.GreaterThan("cint", 1),
- sources.LessThan("cint", 10)),
+ sources.GreaterThan(intColName, 1),
+ sources.LessThan(intColName, 10)),
sources.And(
- sources.GreaterThan("cint", 50),
- sources.LessThan("cint", 100)))))
+ sources.GreaterThan(intColName, 50),
+ sources.LessThan(intColName, 100)))))
// SPARK-22548 Incorrect nested AND expression pushed down to JDBC data
source
// (cint > 1 AND ABS(cint) < 10) OR (cint < 50 AND cint > 100)
@@ -142,11 +156,11 @@ class DataSourceStrategySuite extends PlanTest with
SharedSparkSession {
LessThan(attrInt, -10))),
Some(sources.Or(
sources.Or(
- sources.EqualTo("cint", 1),
- sources.EqualTo("cint", 10)),
+ sources.EqualTo(intColName, 1),
+ sources.EqualTo(intColName, 10)),
sources.Or(
- sources.GreaterThan("cint", 0),
- sources.LessThan("cint", -10)))))
+ sources.GreaterThan(intColName, 0),
+ sources.LessThan(intColName, -10)))))
// (cint = 1 OR ABS(cint) = 10) OR (cint > 0 OR cint < -10)
testTranslateFilter(Or(
@@ -173,11 +187,11 @@ class DataSourceStrategySuite extends PlanTest with
SharedSparkSession {
IsNotNull(attrInt))),
Some(sources.And(
sources.And(
- sources.GreaterThan("cint", 1),
- sources.LessThan("cint", 10)),
+ sources.GreaterThan(intColName, 1),
+ sources.LessThan(intColName, 10)),
sources.And(
- sources.EqualTo("cint", 6),
- sources.IsNotNull("cint")))))
+ sources.EqualTo(intColName, 6),
+ sources.IsNotNull(intColName)))))
// (cint > 1 AND cint < 10) AND (ABS(cint) = 6 AND cint IS NOT NULL)
testTranslateFilter(And(
@@ -201,11 +215,11 @@ class DataSourceStrategySuite extends PlanTest with
SharedSparkSession {
IsNotNull(attrInt))),
Some(sources.And(
sources.Or(
- sources.GreaterThan("cint", 1),
- sources.LessThan("cint", 10)),
+ sources.GreaterThan(intColName, 1),
+ sources.LessThan(intColName, 10)),
sources.Or(
- sources.EqualTo("cint", 6),
- sources.IsNotNull("cint")))))
+ sources.EqualTo(intColName, 6),
+ sources.IsNotNull(intColName)))))
// (cint > 1 OR cint < 10) AND (cint = 6 OR cint IS NOT NULL)
testTranslateFilter(And(
@@ -217,7 +231,7 @@ class DataSourceStrategySuite extends PlanTest with
SharedSparkSession {
// Functions such as 'Abs' are not supported
EqualTo(Abs(attrInt), 6),
IsNotNull(attrInt))), None)
- }
+ }}
test("SPARK-26865 DataSourceV2Strategy should push normalized filters") {
val attrInt = 'cint.int
@@ -226,6 +240,19 @@ class DataSourceStrategySuite extends PlanTest with
SharedSparkSession {
}
}
+ test("SPARK-31027 test `PushableColumn.unapply` that finds the column name
of " +
+ "an expression that can be pushed down") {
+ attrInts.foreach { case (attrInt, colName) =>
+ assert(PushableColumn.unapply(attrInt) === Some(colName))
+ }
+ attrStrs.foreach { case (attrStr, colName) =>
+ assert(PushableColumn.unapply(attrStr) === Some(colName))
+ }
+
+ // `Abs(col)` can not be pushed down, so it returns `None`
+ assert(PushableColumn.unapply(Abs('col.int)) === None)
+ }
+
/**
* Translate the given Catalyst [[Expression]] into data source
[[sources.Filter]]
* then verify against the given [[sources.Filter]].
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]