This is an automated email from the ASF dual-hosted git repository. dbtsai pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push: new 70c86e6 [SPARK-31027][SQL] Refactor DataSourceStrategy to be more extendable 70c86e6 is described below commit 70c86e6b166434ce6d3420f122a33d933728fa91 Author: DB Tsai <d_t...@apple.com> AuthorDate: Wed Mar 4 23:41:49 2020 +0900 [SPARK-31027][SQL] Refactor DataSourceStrategy to be more extendable ### What changes were proposed in this pull request? Refactor `DataSourceStrategy.scala` and `DataSourceStrategySuite.scala` so it's more extendable to implement nested predicate pushdown. ### Why are the changes needed? To support nested predicate pushdown. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Existing tests and new tests. Closes #27778 from dbtsai/SPARK-31027. Authored-by: DB Tsai <d_t...@apple.com> Signed-off-by: HyukjinKwon <gurwls...@apache.org> --- .../execution/datasources/DataSourceStrategy.scala | 105 ++++++++++-------- .../datasources/DataSourceStrategySuite.scala | 121 +++++++++++++-------- 2 files changed, 133 insertions(+), 93 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2d902b5..1641b66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -449,60 +449,60 @@ object DataSourceStrategy { } private def translateLeafNodeFilter(predicate: Expression): Option[Filter] = predicate match { - case expressions.EqualTo(a: Attribute, Literal(v, t)) => - Some(sources.EqualTo(a.name, convertToScala(v, t))) - case expressions.EqualTo(Literal(v, t), a: Attribute) => - Some(sources.EqualTo(a.name, convertToScala(v, t))) - - case expressions.EqualNullSafe(a: Attribute, Literal(v, t)) => - Some(sources.EqualNullSafe(a.name, convertToScala(v, t))) - case expressions.EqualNullSafe(Literal(v, t), a: Attribute) => - Some(sources.EqualNullSafe(a.name, convertToScala(v, t))) - - case expressions.GreaterThan(a: Attribute, Literal(v, t)) => - Some(sources.GreaterThan(a.name, convertToScala(v, t))) - case expressions.GreaterThan(Literal(v, t), a: Attribute) => - Some(sources.LessThan(a.name, convertToScala(v, t))) - - case expressions.LessThan(a: Attribute, Literal(v, t)) => - Some(sources.LessThan(a.name, convertToScala(v, t))) - case expressions.LessThan(Literal(v, t), a: Attribute) => - Some(sources.GreaterThan(a.name, convertToScala(v, t))) - - case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, t)) => - Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t))) - case expressions.GreaterThanOrEqual(Literal(v, t), a: Attribute) => - Some(sources.LessThanOrEqual(a.name, convertToScala(v, t))) - - case expressions.LessThanOrEqual(a: Attribute, Literal(v, t)) => - Some(sources.LessThanOrEqual(a.name, convertToScala(v, t))) - case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) => - Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t))) - - case expressions.InSet(a: Attribute, set) => - val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) - Some(sources.In(a.name, set.toArray.map(toScala))) + case expressions.EqualTo(PushableColumn(name), Literal(v, t)) => + Some(sources.EqualTo(name, convertToScala(v, t))) + case expressions.EqualTo(Literal(v, t), PushableColumn(name)) => + Some(sources.EqualTo(name, convertToScala(v, t))) + + case expressions.EqualNullSafe(PushableColumn(name), Literal(v, t)) => + Some(sources.EqualNullSafe(name, convertToScala(v, t))) + case expressions.EqualNullSafe(Literal(v, t), PushableColumn(name)) => + Some(sources.EqualNullSafe(name, convertToScala(v, t))) + + case expressions.GreaterThan(PushableColumn(name), Literal(v, t)) => + Some(sources.GreaterThan(name, convertToScala(v, t))) + case expressions.GreaterThan(Literal(v, t), PushableColumn(name)) => + Some(sources.LessThan(name, convertToScala(v, t))) + + case expressions.LessThan(PushableColumn(name), Literal(v, t)) => + Some(sources.LessThan(name, convertToScala(v, t))) + case expressions.LessThan(Literal(v, t), PushableColumn(name)) => + Some(sources.GreaterThan(name, convertToScala(v, t))) + + case expressions.GreaterThanOrEqual(PushableColumn(name), Literal(v, t)) => + Some(sources.GreaterThanOrEqual(name, convertToScala(v, t))) + case expressions.GreaterThanOrEqual(Literal(v, t), PushableColumn(name)) => + Some(sources.LessThanOrEqual(name, convertToScala(v, t))) + + case expressions.LessThanOrEqual(PushableColumn(name), Literal(v, t)) => + Some(sources.LessThanOrEqual(name, convertToScala(v, t))) + case expressions.LessThanOrEqual(Literal(v, t), PushableColumn(name)) => + Some(sources.GreaterThanOrEqual(name, convertToScala(v, t))) + + case expressions.InSet(e @ PushableColumn(name), set) => + val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType) + Some(sources.In(name, set.toArray.map(toScala))) // Because we only convert In to InSet in Optimizer when there are more than certain // items. So it is possible we still get an In expression here that needs to be pushed // down. - case expressions.In(a: Attribute, list) if list.forall(_.isInstanceOf[Literal]) => + case expressions.In(e @ PushableColumn(name), list) if list.forall(_.isInstanceOf[Literal]) => val hSet = list.map(_.eval(EmptyRow)) - val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) - Some(sources.In(a.name, hSet.toArray.map(toScala))) + val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType) + Some(sources.In(name, hSet.toArray.map(toScala))) - case expressions.IsNull(a: Attribute) => - Some(sources.IsNull(a.name)) - case expressions.IsNotNull(a: Attribute) => - Some(sources.IsNotNull(a.name)) - case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) => - Some(sources.StringStartsWith(a.name, v.toString)) + case expressions.IsNull(PushableColumn(name)) => + Some(sources.IsNull(name)) + case expressions.IsNotNull(PushableColumn(name)) => + Some(sources.IsNotNull(name)) + case expressions.StartsWith(PushableColumn(name), Literal(v: UTF8String, StringType)) => + Some(sources.StringStartsWith(name, v.toString)) - case expressions.EndsWith(a: Attribute, Literal(v: UTF8String, StringType)) => - Some(sources.StringEndsWith(a.name, v.toString)) + case expressions.EndsWith(PushableColumn(name), Literal(v: UTF8String, StringType)) => + Some(sources.StringEndsWith(name, v.toString)) - case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) => - Some(sources.StringContains(a.name, v.toString)) + case expressions.Contains(PushableColumn(name), Literal(v: UTF8String, StringType)) => + Some(sources.StringContains(name, v.toString)) case expressions.Literal(true, BooleanType) => Some(sources.AlwaysTrue) @@ -646,3 +646,16 @@ object DataSourceStrategy { } } } + +/** + * Find the column name of an expression that can be pushed down. + */ +object PushableColumn { + def unapply(e: Expression): Option[String] = { + def helper(e: Expression) = e match { + case a: Attribute => Some(a.name) + case _ => None + } + helper(e) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index b76db70..7bd3213 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -22,68 +22,82 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.sources import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} class DataSourceStrategySuite extends PlanTest with SharedSparkSession { + val attrInts = Seq( + 'cint.int + ).zip(Seq( + "cint" + )) - test("translate simple expression") { - val attrInt = 'cint.int - val attrStr = 'cstr.string + val attrStrs = Seq( + 'cstr.string + ).zip(Seq( + "cstr" + )) + + test("translate simple expression") { attrInts.zip(attrStrs) + .foreach { case ((attrInt, intColName), (attrStr, strColName)) => - testTranslateFilter(EqualTo(attrInt, 1), Some(sources.EqualTo("cint", 1))) - testTranslateFilter(EqualTo(1, attrInt), Some(sources.EqualTo("cint", 1))) + testTranslateFilter(EqualTo(attrInt, 1), Some(sources.EqualTo(intColName, 1))) + testTranslateFilter(EqualTo(1, attrInt), Some(sources.EqualTo(intColName, 1))) testTranslateFilter(EqualNullSafe(attrStr, Literal(null)), - Some(sources.EqualNullSafe("cstr", null))) + Some(sources.EqualNullSafe(strColName, null))) testTranslateFilter(EqualNullSafe(Literal(null), attrStr), - Some(sources.EqualNullSafe("cstr", null))) + Some(sources.EqualNullSafe(strColName, null))) - testTranslateFilter(GreaterThan(attrInt, 1), Some(sources.GreaterThan("cint", 1))) - testTranslateFilter(GreaterThan(1, attrInt), Some(sources.LessThan("cint", 1))) + testTranslateFilter(GreaterThan(attrInt, 1), Some(sources.GreaterThan(intColName, 1))) + testTranslateFilter(GreaterThan(1, attrInt), Some(sources.LessThan(intColName, 1))) - testTranslateFilter(LessThan(attrInt, 1), Some(sources.LessThan("cint", 1))) - testTranslateFilter(LessThan(1, attrInt), Some(sources.GreaterThan("cint", 1))) + testTranslateFilter(LessThan(attrInt, 1), Some(sources.LessThan(intColName, 1))) + testTranslateFilter(LessThan(1, attrInt), Some(sources.GreaterThan(intColName, 1))) - testTranslateFilter(GreaterThanOrEqual(attrInt, 1), Some(sources.GreaterThanOrEqual("cint", 1))) - testTranslateFilter(GreaterThanOrEqual(1, attrInt), Some(sources.LessThanOrEqual("cint", 1))) + testTranslateFilter(GreaterThanOrEqual(attrInt, 1), + Some(sources.GreaterThanOrEqual(intColName, 1))) + testTranslateFilter(GreaterThanOrEqual(1, attrInt), + Some(sources.LessThanOrEqual(intColName, 1))) - testTranslateFilter(LessThanOrEqual(attrInt, 1), Some(sources.LessThanOrEqual("cint", 1))) - testTranslateFilter(LessThanOrEqual(1, attrInt), Some(sources.GreaterThanOrEqual("cint", 1))) + testTranslateFilter(LessThanOrEqual(attrInt, 1), + Some(sources.LessThanOrEqual(intColName, 1))) + testTranslateFilter(LessThanOrEqual(1, attrInt), + Some(sources.GreaterThanOrEqual(intColName, 1))) - testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) + testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), Some(sources.In(intColName, Array(1, 2, 3)))) - testTranslateFilter(In(attrInt, Seq(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) + testTranslateFilter(In(attrInt, Seq(1, 2, 3)), Some(sources.In(intColName, Array(1, 2, 3)))) - testTranslateFilter(IsNull(attrInt), Some(sources.IsNull("cint"))) - testTranslateFilter(IsNotNull(attrInt), Some(sources.IsNotNull("cint"))) + testTranslateFilter(IsNull(attrInt), Some(sources.IsNull(intColName))) + testTranslateFilter(IsNotNull(attrInt), Some(sources.IsNotNull(intColName))) // cint > 1 AND cint < 10 testTranslateFilter(And( GreaterThan(attrInt, 1), LessThan(attrInt, 10)), Some(sources.And( - sources.GreaterThan("cint", 1), - sources.LessThan("cint", 10)))) + sources.GreaterThan(intColName, 1), + sources.LessThan(intColName, 10)))) // cint >= 8 OR cint <= 2 testTranslateFilter(Or( GreaterThanOrEqual(attrInt, 8), LessThanOrEqual(attrInt, 2)), Some(sources.Or( - sources.GreaterThanOrEqual("cint", 8), - sources.LessThanOrEqual("cint", 2)))) + sources.GreaterThanOrEqual(intColName, 8), + sources.LessThanOrEqual(intColName, 2)))) testTranslateFilter(Not(GreaterThanOrEqual(attrInt, 8)), - Some(sources.Not(sources.GreaterThanOrEqual("cint", 8)))) + Some(sources.Not(sources.GreaterThanOrEqual(intColName, 8)))) - testTranslateFilter(StartsWith(attrStr, "a"), Some(sources.StringStartsWith("cstr", "a"))) + testTranslateFilter(StartsWith(attrStr, "a"), Some(sources.StringStartsWith(strColName, "a"))) - testTranslateFilter(EndsWith(attrStr, "a"), Some(sources.StringEndsWith("cstr", "a"))) + testTranslateFilter(EndsWith(attrStr, "a"), Some(sources.StringEndsWith(strColName, "a"))) - testTranslateFilter(Contains(attrStr, "a"), Some(sources.StringContains("cstr", "a"))) - } + testTranslateFilter(Contains(attrStr, "a"), Some(sources.StringContains(strColName, "a"))) + }} - test("translate complex expression") { - val attrInt = 'cint.int + test("translate complex expression") { attrInts.foreach { case (attrInt, intColName) => // ABS(cint) - 2 <= 1 testTranslateFilter(LessThanOrEqual( @@ -102,11 +116,11 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { LessThan(attrInt, 100))), Some(sources.Or( sources.And( - sources.GreaterThan("cint", 1), - sources.LessThan("cint", 10)), + sources.GreaterThan(intColName, 1), + sources.LessThan(intColName, 10)), sources.And( - sources.GreaterThan("cint", 50), - sources.LessThan("cint", 100))))) + sources.GreaterThan(intColName, 50), + sources.LessThan(intColName, 100))))) // SPARK-22548 Incorrect nested AND expression pushed down to JDBC data source // (cint > 1 AND ABS(cint) < 10) OR (cint < 50 AND cint > 100) @@ -142,11 +156,11 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { LessThan(attrInt, -10))), Some(sources.Or( sources.Or( - sources.EqualTo("cint", 1), - sources.EqualTo("cint", 10)), + sources.EqualTo(intColName, 1), + sources.EqualTo(intColName, 10)), sources.Or( - sources.GreaterThan("cint", 0), - sources.LessThan("cint", -10))))) + sources.GreaterThan(intColName, 0), + sources.LessThan(intColName, -10))))) // (cint = 1 OR ABS(cint) = 10) OR (cint > 0 OR cint < -10) testTranslateFilter(Or( @@ -173,11 +187,11 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { IsNotNull(attrInt))), Some(sources.And( sources.And( - sources.GreaterThan("cint", 1), - sources.LessThan("cint", 10)), + sources.GreaterThan(intColName, 1), + sources.LessThan(intColName, 10)), sources.And( - sources.EqualTo("cint", 6), - sources.IsNotNull("cint"))))) + sources.EqualTo(intColName, 6), + sources.IsNotNull(intColName))))) // (cint > 1 AND cint < 10) AND (ABS(cint) = 6 AND cint IS NOT NULL) testTranslateFilter(And( @@ -201,11 +215,11 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { IsNotNull(attrInt))), Some(sources.And( sources.Or( - sources.GreaterThan("cint", 1), - sources.LessThan("cint", 10)), + sources.GreaterThan(intColName, 1), + sources.LessThan(intColName, 10)), sources.Or( - sources.EqualTo("cint", 6), - sources.IsNotNull("cint"))))) + sources.EqualTo(intColName, 6), + sources.IsNotNull(intColName))))) // (cint > 1 OR cint < 10) AND (cint = 6 OR cint IS NOT NULL) testTranslateFilter(And( @@ -217,7 +231,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { // Functions such as 'Abs' are not supported EqualTo(Abs(attrInt), 6), IsNotNull(attrInt))), None) - } + }} test("SPARK-26865 DataSourceV2Strategy should push normalized filters") { val attrInt = 'cint.int @@ -226,6 +240,19 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { } } + test("SPARK-31027 test `PushableColumn.unapply` that finds the column name of " + + "an expression that can be pushed down") { + attrInts.foreach { case (attrInt, colName) => + assert(PushableColumn.unapply(attrInt) === Some(colName)) + } + attrStrs.foreach { case (attrStr, colName) => + assert(PushableColumn.unapply(attrStr) === Some(colName)) + } + + // `Abs(col)` can not be pushed down, so it returns `None` + assert(PushableColumn.unapply(Abs('col.int)) === None) + } + /** * Translate the given Catalyst [[Expression]] into data source [[sources.Filter]] * then verify against the given [[sources.Filter]]. --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org