This is an automated email from the ASF dual-hosted git repository.

dbtsai pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 70c86e6  [SPARK-31027][SQL] Refactor DataSourceStrategy to be more extendable
70c86e6 is described below

commit 70c86e6b166434ce6d3420f122a33d933728fa91
Author: DB Tsai <d_t...@apple.com>
AuthorDate: Wed Mar 4 23:41:49 2020 +0900

    [SPARK-31027][SQL] Refactor DataSourceStrategy to be more extendable
    
    ### What changes were proposed in this pull request?
    Refactor `DataSourceStrategy.scala` and `DataSourceStrategySuite.scala` so that they
    are easier to extend for implementing nested predicate pushdown.
    
    ### Why are the changes needed?
    To support nested predicate pushdown.
    
    ### Does this PR introduce any user-facing change?
    No.
    
    ### How was this patch tested?
    Existing tests and new tests.
    
    Closes #27778 from dbtsai/SPARK-31027.
    
    Authored-by: DB Tsai <d_t...@apple.com>
    Signed-off-by: HyukjinKwon <gurwls...@apache.org>
---
 .../execution/datasources/DataSourceStrategy.scala | 105 ++++++++++--------
 .../datasources/DataSourceStrategySuite.scala      | 121 +++++++++++++--------
 2 files changed, 133 insertions(+), 93 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2d902b5..1641b66 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -449,60 +449,60 @@ object DataSourceStrategy {
   }
 
   private def translateLeafNodeFilter(predicate: Expression): Option[Filter] = 
predicate match {
-    case expressions.EqualTo(a: Attribute, Literal(v, t)) =>
-      Some(sources.EqualTo(a.name, convertToScala(v, t)))
-    case expressions.EqualTo(Literal(v, t), a: Attribute) =>
-      Some(sources.EqualTo(a.name, convertToScala(v, t)))
-
-    case expressions.EqualNullSafe(a: Attribute, Literal(v, t)) =>
-      Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))
-    case expressions.EqualNullSafe(Literal(v, t), a: Attribute) =>
-      Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))
-
-    case expressions.GreaterThan(a: Attribute, Literal(v, t)) =>
-      Some(sources.GreaterThan(a.name, convertToScala(v, t)))
-    case expressions.GreaterThan(Literal(v, t), a: Attribute) =>
-      Some(sources.LessThan(a.name, convertToScala(v, t)))
-
-    case expressions.LessThan(a: Attribute, Literal(v, t)) =>
-      Some(sources.LessThan(a.name, convertToScala(v, t)))
-    case expressions.LessThan(Literal(v, t), a: Attribute) =>
-      Some(sources.GreaterThan(a.name, convertToScala(v, t)))
-
-    case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, t)) =>
-      Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))
-    case expressions.GreaterThanOrEqual(Literal(v, t), a: Attribute) =>
-      Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))
-
-    case expressions.LessThanOrEqual(a: Attribute, Literal(v, t)) =>
-      Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))
-    case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) =>
-      Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))
-
-    case expressions.InSet(a: Attribute, set) =>
-      val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
-      Some(sources.In(a.name, set.toArray.map(toScala)))
+    case expressions.EqualTo(PushableColumn(name), Literal(v, t)) =>
+      Some(sources.EqualTo(name, convertToScala(v, t)))
+    case expressions.EqualTo(Literal(v, t), PushableColumn(name)) =>
+      Some(sources.EqualTo(name, convertToScala(v, t)))
+
+    case expressions.EqualNullSafe(PushableColumn(name), Literal(v, t)) =>
+      Some(sources.EqualNullSafe(name, convertToScala(v, t)))
+    case expressions.EqualNullSafe(Literal(v, t), PushableColumn(name)) =>
+      Some(sources.EqualNullSafe(name, convertToScala(v, t)))
+
+    case expressions.GreaterThan(PushableColumn(name), Literal(v, t)) =>
+      Some(sources.GreaterThan(name, convertToScala(v, t)))
+    case expressions.GreaterThan(Literal(v, t), PushableColumn(name)) =>
+      Some(sources.LessThan(name, convertToScala(v, t)))
+
+    case expressions.LessThan(PushableColumn(name), Literal(v, t)) =>
+      Some(sources.LessThan(name, convertToScala(v, t)))
+    case expressions.LessThan(Literal(v, t), PushableColumn(name)) =>
+      Some(sources.GreaterThan(name, convertToScala(v, t)))
+
+    case expressions.GreaterThanOrEqual(PushableColumn(name), Literal(v, t)) =>
+      Some(sources.GreaterThanOrEqual(name, convertToScala(v, t)))
+    case expressions.GreaterThanOrEqual(Literal(v, t), PushableColumn(name)) =>
+      Some(sources.LessThanOrEqual(name, convertToScala(v, t)))
+
+    case expressions.LessThanOrEqual(PushableColumn(name), Literal(v, t)) =>
+      Some(sources.LessThanOrEqual(name, convertToScala(v, t)))
+    case expressions.LessThanOrEqual(Literal(v, t), PushableColumn(name)) =>
+      Some(sources.GreaterThanOrEqual(name, convertToScala(v, t)))
+
+    case expressions.InSet(e @ PushableColumn(name), set) =>
+      val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
+      Some(sources.In(name, set.toArray.map(toScala)))
 
     // Because we only convert In to InSet in Optimizer when there are more 
than certain
     // items. So it is possible we still get an In expression here that needs 
to be pushed
     // down.
-    case expressions.In(a: Attribute, list) if 
list.forall(_.isInstanceOf[Literal]) =>
+    case expressions.In(e @ PushableColumn(name), list) if 
list.forall(_.isInstanceOf[Literal]) =>
       val hSet = list.map(_.eval(EmptyRow))
-      val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
-      Some(sources.In(a.name, hSet.toArray.map(toScala)))
+      val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
+      Some(sources.In(name, hSet.toArray.map(toScala)))
 
-    case expressions.IsNull(a: Attribute) =>
-      Some(sources.IsNull(a.name))
-    case expressions.IsNotNull(a: Attribute) =>
-      Some(sources.IsNotNull(a.name))
-    case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, 
StringType)) =>
-      Some(sources.StringStartsWith(a.name, v.toString))
+    case expressions.IsNull(PushableColumn(name)) =>
+      Some(sources.IsNull(name))
+    case expressions.IsNotNull(PushableColumn(name)) =>
+      Some(sources.IsNotNull(name))
+    case expressions.StartsWith(PushableColumn(name), Literal(v: UTF8String, 
StringType)) =>
+      Some(sources.StringStartsWith(name, v.toString))
 
-    case expressions.EndsWith(a: Attribute, Literal(v: UTF8String, 
StringType)) =>
-      Some(sources.StringEndsWith(a.name, v.toString))
+    case expressions.EndsWith(PushableColumn(name), Literal(v: UTF8String, 
StringType)) =>
+      Some(sources.StringEndsWith(name, v.toString))
 
-    case expressions.Contains(a: Attribute, Literal(v: UTF8String, 
StringType)) =>
-      Some(sources.StringContains(a.name, v.toString))
+    case expressions.Contains(PushableColumn(name), Literal(v: UTF8String, 
StringType)) =>
+      Some(sources.StringContains(name, v.toString))
 
     case expressions.Literal(true, BooleanType) =>
       Some(sources.AlwaysTrue)
@@ -646,3 +646,21 @@ object DataSourceStrategy {
     }
   }
 }
+
+/**
+ * Extractor that yields the name of the column an expression refers to, when that
+ * expression can be pushed down to a data source as a plain column reference.
+ *
+ * Only a top-level [[Attribute]] qualifies at present; any other expression (for
+ * example `Abs(col)`) yields `None`. The lookup is done by a local function so that
+ * nested-column support can later be added without changing `unapply`'s signature.
+ */
+object PushableColumn {
+  def unapply(e: Expression): Option[String] = {
+    def columnName(expr: Expression): Option[String] = expr match {
+      case attr: Attribute => Some(attr.name)
+      case _ => None
+    }
+    columnName(e)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
index b76db70..7bd3213 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
@@ -22,68 +22,82 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.sources
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, 
StructType}
 
 class DataSourceStrategySuite extends PlanTest with SharedSparkSession {
+  val attrInts = Seq(
+    'cint.int
+  ).zip(Seq(
+    "cint"
+  ))
 
-  test("translate simple expression") {
-    val attrInt = 'cint.int
-    val attrStr = 'cstr.string
+  val attrStrs = Seq(
+    'cstr.string
+  ).zip(Seq(
+    "cstr"
+  ))
+
+  test("translate simple expression") { attrInts.zip(attrStrs)
+    .foreach { case ((attrInt, intColName), (attrStr, strColName)) =>
 
-    testTranslateFilter(EqualTo(attrInt, 1), Some(sources.EqualTo("cint", 1)))
-    testTranslateFilter(EqualTo(1, attrInt), Some(sources.EqualTo("cint", 1)))
+    testTranslateFilter(EqualTo(attrInt, 1), Some(sources.EqualTo(intColName, 
1)))
+    testTranslateFilter(EqualTo(1, attrInt), Some(sources.EqualTo(intColName, 
1)))
 
     testTranslateFilter(EqualNullSafe(attrStr, Literal(null)),
-      Some(sources.EqualNullSafe("cstr", null)))
+      Some(sources.EqualNullSafe(strColName, null)))
     testTranslateFilter(EqualNullSafe(Literal(null), attrStr),
-      Some(sources.EqualNullSafe("cstr", null)))
+      Some(sources.EqualNullSafe(strColName, null)))
 
-    testTranslateFilter(GreaterThan(attrInt, 1), 
Some(sources.GreaterThan("cint", 1)))
-    testTranslateFilter(GreaterThan(1, attrInt), Some(sources.LessThan("cint", 
1)))
+    testTranslateFilter(GreaterThan(attrInt, 1), 
Some(sources.GreaterThan(intColName, 1)))
+    testTranslateFilter(GreaterThan(1, attrInt), 
Some(sources.LessThan(intColName, 1)))
 
-    testTranslateFilter(LessThan(attrInt, 1), Some(sources.LessThan("cint", 
1)))
-    testTranslateFilter(LessThan(1, attrInt), Some(sources.GreaterThan("cint", 
1)))
+    testTranslateFilter(LessThan(attrInt, 1), 
Some(sources.LessThan(intColName, 1)))
+    testTranslateFilter(LessThan(1, attrInt), 
Some(sources.GreaterThan(intColName, 1)))
 
-    testTranslateFilter(GreaterThanOrEqual(attrInt, 1), 
Some(sources.GreaterThanOrEqual("cint", 1)))
-    testTranslateFilter(GreaterThanOrEqual(1, attrInt), 
Some(sources.LessThanOrEqual("cint", 1)))
+    testTranslateFilter(GreaterThanOrEqual(attrInt, 1),
+      Some(sources.GreaterThanOrEqual(intColName, 1)))
+    testTranslateFilter(GreaterThanOrEqual(1, attrInt),
+      Some(sources.LessThanOrEqual(intColName, 1)))
 
-    testTranslateFilter(LessThanOrEqual(attrInt, 1), 
Some(sources.LessThanOrEqual("cint", 1)))
-    testTranslateFilter(LessThanOrEqual(1, attrInt), 
Some(sources.GreaterThanOrEqual("cint", 1)))
+    testTranslateFilter(LessThanOrEqual(attrInt, 1),
+      Some(sources.LessThanOrEqual(intColName, 1)))
+    testTranslateFilter(LessThanOrEqual(1, attrInt),
+      Some(sources.GreaterThanOrEqual(intColName, 1)))
 
-    testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), Some(sources.In("cint", 
Array(1, 2, 3))))
+    testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), 
Some(sources.In(intColName, Array(1, 2, 3))))
 
-    testTranslateFilter(In(attrInt, Seq(1, 2, 3)), Some(sources.In("cint", 
Array(1, 2, 3))))
+    testTranslateFilter(In(attrInt, Seq(1, 2, 3)), Some(sources.In(intColName, 
Array(1, 2, 3))))
 
-    testTranslateFilter(IsNull(attrInt), Some(sources.IsNull("cint")))
-    testTranslateFilter(IsNotNull(attrInt), Some(sources.IsNotNull("cint")))
+    testTranslateFilter(IsNull(attrInt), Some(sources.IsNull(intColName)))
+    testTranslateFilter(IsNotNull(attrInt), 
Some(sources.IsNotNull(intColName)))
 
     // cint > 1 AND cint < 10
     testTranslateFilter(And(
       GreaterThan(attrInt, 1),
       LessThan(attrInt, 10)),
       Some(sources.And(
-        sources.GreaterThan("cint", 1),
-        sources.LessThan("cint", 10))))
+        sources.GreaterThan(intColName, 1),
+        sources.LessThan(intColName, 10))))
 
     // cint >= 8 OR cint <= 2
     testTranslateFilter(Or(
       GreaterThanOrEqual(attrInt, 8),
       LessThanOrEqual(attrInt, 2)),
       Some(sources.Or(
-        sources.GreaterThanOrEqual("cint", 8),
-        sources.LessThanOrEqual("cint", 2))))
+        sources.GreaterThanOrEqual(intColName, 8),
+        sources.LessThanOrEqual(intColName, 2))))
 
     testTranslateFilter(Not(GreaterThanOrEqual(attrInt, 8)),
-      Some(sources.Not(sources.GreaterThanOrEqual("cint", 8))))
+      Some(sources.Not(sources.GreaterThanOrEqual(intColName, 8))))
 
-    testTranslateFilter(StartsWith(attrStr, "a"), 
Some(sources.StringStartsWith("cstr", "a")))
+    testTranslateFilter(StartsWith(attrStr, "a"), 
Some(sources.StringStartsWith(strColName, "a")))
 
-    testTranslateFilter(EndsWith(attrStr, "a"), 
Some(sources.StringEndsWith("cstr", "a")))
+    testTranslateFilter(EndsWith(attrStr, "a"), 
Some(sources.StringEndsWith(strColName, "a")))
 
-    testTranslateFilter(Contains(attrStr, "a"), 
Some(sources.StringContains("cstr", "a")))
-  }
+    testTranslateFilter(Contains(attrStr, "a"), 
Some(sources.StringContains(strColName, "a")))
+  }}
 
-  test("translate complex expression") {
-    val attrInt = 'cint.int
+  test("translate complex expression") { attrInts.foreach { case (attrInt, 
intColName) =>
 
     // ABS(cint) - 2 <= 1
     testTranslateFilter(LessThanOrEqual(
@@ -102,11 +116,11 @@ class DataSourceStrategySuite extends PlanTest with 
SharedSparkSession {
         LessThan(attrInt, 100))),
       Some(sources.Or(
         sources.And(
-          sources.GreaterThan("cint", 1),
-          sources.LessThan("cint", 10)),
+          sources.GreaterThan(intColName, 1),
+          sources.LessThan(intColName, 10)),
         sources.And(
-          sources.GreaterThan("cint", 50),
-          sources.LessThan("cint", 100)))))
+          sources.GreaterThan(intColName, 50),
+          sources.LessThan(intColName, 100)))))
 
     // SPARK-22548 Incorrect nested AND expression pushed down to JDBC data 
source
     // (cint > 1 AND ABS(cint) < 10) OR (cint < 50 AND cint > 100)
@@ -142,11 +156,11 @@ class DataSourceStrategySuite extends PlanTest with 
SharedSparkSession {
         LessThan(attrInt, -10))),
       Some(sources.Or(
         sources.Or(
-          sources.EqualTo("cint", 1),
-          sources.EqualTo("cint", 10)),
+          sources.EqualTo(intColName, 1),
+          sources.EqualTo(intColName, 10)),
         sources.Or(
-          sources.GreaterThan("cint", 0),
-          sources.LessThan("cint", -10)))))
+          sources.GreaterThan(intColName, 0),
+          sources.LessThan(intColName, -10)))))
 
     // (cint = 1 OR ABS(cint) = 10) OR (cint > 0 OR cint < -10)
     testTranslateFilter(Or(
@@ -173,11 +187,11 @@ class DataSourceStrategySuite extends PlanTest with 
SharedSparkSession {
         IsNotNull(attrInt))),
       Some(sources.And(
         sources.And(
-          sources.GreaterThan("cint", 1),
-          sources.LessThan("cint", 10)),
+          sources.GreaterThan(intColName, 1),
+          sources.LessThan(intColName, 10)),
         sources.And(
-          sources.EqualTo("cint", 6),
-          sources.IsNotNull("cint")))))
+          sources.EqualTo(intColName, 6),
+          sources.IsNotNull(intColName)))))
 
     // (cint > 1 AND cint < 10) AND (ABS(cint) = 6 AND cint IS NOT NULL)
     testTranslateFilter(And(
@@ -201,11 +215,11 @@ class DataSourceStrategySuite extends PlanTest with 
SharedSparkSession {
         IsNotNull(attrInt))),
       Some(sources.And(
         sources.Or(
-          sources.GreaterThan("cint", 1),
-          sources.LessThan("cint", 10)),
+          sources.GreaterThan(intColName, 1),
+          sources.LessThan(intColName, 10)),
         sources.Or(
-          sources.EqualTo("cint", 6),
-          sources.IsNotNull("cint")))))
+          sources.EqualTo(intColName, 6),
+          sources.IsNotNull(intColName)))))
 
     // (cint > 1 OR cint < 10) AND (cint = 6 OR cint IS NOT NULL)
     testTranslateFilter(And(
@@ -217,7 +231,7 @@ class DataSourceStrategySuite extends PlanTest with 
SharedSparkSession {
         // Functions such as 'Abs' are not supported
         EqualTo(Abs(attrInt), 6),
         IsNotNull(attrInt))), None)
-  }
+  }}
 
   test("SPARK-26865 DataSourceV2Strategy should push normalized filters") {
     val attrInt = 'cint.int
@@ -226,6 +240,19 @@ class DataSourceStrategySuite extends PlanTest with 
SharedSparkSession {
     }
   }
 
+  test("SPARK-31027 test `PushableColumn.unapply` that finds the column name 
of " +
+    "an expression that can be pushed down") {
+    attrInts.foreach { case (attrInt, colName) =>
+      assert(PushableColumn.unapply(attrInt) === Some(colName))
+    }
+    attrStrs.foreach { case (attrStr, colName) =>
+      assert(PushableColumn.unapply(attrStr) === Some(colName))
+    }
+
+    // `Abs(col)` can not be pushed down, so it returns `None`
+    assert(PushableColumn.unapply(Abs('col.int)) === None)
+  }
+
   /**
    * Translate the given Catalyst [[Expression]] into data source 
[[sources.Filter]]
    * then verify against the given [[sources.Filter]].


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to