This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new cf7fc6b1031 [SPARK-39037][SQL] DS V2 aggregate push-down supports 
order by expressions
cf7fc6b1031 is described below

commit cf7fc6b1031383bfc1cbf7c201e9830dad413cf8
Author: Jiaan Geng <[email protected]>
AuthorDate: Thu Apr 28 15:50:21 2022 +0800

    [SPARK-39037][SQL] DS V2 aggregate push-down supports order by expressions
    
    ### What changes were proposed in this pull request?
    Currently, Spark DS V2 aggregate push-down only supports ORDER BY on plain columns.
    But SQL like the example shown below is very useful and common.
    ```
    SELECT
      CASE
        WHEN "SALARY" > 8000.00
          AND "SALARY" < 10000.00
        THEN "SALARY"
        ELSE 0.00
      END AS key,
      dept,
      name
    FROM "test"."employee"
    ORDER BY key
    ```
    
    ### Why are the changes needed?
    Let DS V2 aggregate push-down support ORDER BY expressions, not just plain columns.
    
    ### Does this PR introduce _any_ user-facing change?
    'No'.
    New feature.
    
    ### How was this patch tested?
    New tests in JDBCV2Suite
    
    Closes #36370 from beliefer/SPARK-39037.
    
    Authored-by: Jiaan Geng <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../execution/datasources/DataSourceStrategy.scala |  8 ++---
 .../sql/execution/datasources/jdbc/JDBCRDD.scala   |  7 ++--
 .../execution/datasources/jdbc/JDBCRelation.scala  |  3 +-
 .../datasources/v2/V2ScanRelationPushDown.scala    |  4 ++-
 .../execution/datasources/v2/jdbc/JDBCScan.scala   |  3 +-
 .../datasources/v2/jdbc/JDBCScanBuilder.scala      | 10 ++++--
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    | 40 +++++++++++++++++++---
 7 files changed, 56 insertions(+), 19 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index e35d0932076..04b77013a0c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -776,8 +776,8 @@ object DataSourceStrategy
   }
 
   protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): 
Seq[V2SortOrder] = {
-    def translateOortOrder(sortOrder: SortOrder): Option[V2SortOrder] = 
sortOrder match {
-      case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, 
nullOrderingV1, _) =>
+    def translateSortOrder(sortOrder: SortOrder): Option[V2SortOrder] = 
sortOrder match {
+      case SortOrder(PushableExpression(expr), directionV1, nullOrderingV1, _) 
=>
         val directionV2 = directionV1 match {
           case Ascending => SortDirection.ASCENDING
           case Descending => SortDirection.DESCENDING
@@ -786,11 +786,11 @@ object DataSourceStrategy
           case NullsFirst => NullOrdering.NULLS_FIRST
           case NullsLast => NullOrdering.NULLS_LAST
         }
-        Some(SortValue(FieldReference(name), directionV2, nullOrderingV2))
+        Some(SortValue(expr, directionV2, nullOrderingV2))
       case _ => None
     }
 
-    sortOrders.flatMap(translateOortOrder)
+    sortOrders.flatMap(translateSortOrder)
   }
 
   /**
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index b30b460ac67..13d6156aed1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -25,7 +25,6 @@ import org.apache.spark.{InterruptibleIterator, Partition, 
SparkContext, TaskCon
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.expressions.SortOrder
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
@@ -123,7 +122,7 @@ object JDBCRDD extends Logging {
       groupByColumns: Option[Array[String]] = None,
       sample: Option[TableSampleInfo] = None,
       limit: Int = 0,
-      sortOrders: Array[SortOrder] = Array.empty[SortOrder]): RDD[InternalRow] 
= {
+      sortOrders: Array[String] = Array.empty[String]): RDD[InternalRow] = {
     val url = options.url
     val dialect = JdbcDialects.get(url)
     val quotedColumns = if (groupByColumns.isEmpty) {
@@ -166,7 +165,7 @@ private[jdbc] class JDBCRDD(
     groupByColumns: Option[Array[String]],
     sample: Option[TableSampleInfo],
     limit: Int,
-    sortOrders: Array[SortOrder])
+    sortOrders: Array[String])
   extends RDD[InternalRow](sc, Nil) {
 
   /**
@@ -216,7 +215,7 @@ private[jdbc] class JDBCRDD(
 
   private def getOrderByClause: String = {
     if (sortOrders.nonEmpty) {
-      s" ORDER BY ${sortOrders.map(_.describe()).mkString(", ")}"
+      s" ORDER BY ${sortOrders.mkString(", ")}"
     } else {
       ""
     }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 0f1a1b6dc66..ea841027607 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, 
SparkSession, SQLContext}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, 
TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, 
stringToDate, stringToTimestamp}
-import org.apache.spark.sql.connector.expressions.SortOrder
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -305,7 +304,7 @@ private[sql] case class JDBCRelation(
       groupByColumns: Option[Array[String]],
       tableSample: Option[TableSampleInfo],
       limit: Int,
-      sortOrders: Array[SortOrder]): RDD[Row] = {
+      sortOrders: Array[String]): RDD[Row] = {
     // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
     JDBCRDD.scanTable(
       sparkSession.sparkContext,
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 89398fabdc3..03b6544c772 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -381,7 +381,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] 
with PredicateHelper wit
           order, project, alwaysInline = true) =>
       val aliasMap = getAliasMap(project)
       val newOrder = order.map(replaceAlias(_, 
aliasMap)).asInstanceOf[Seq[SortOrder]]
-      val orders = DataSourceStrategy.translateSortOrders(newOrder)
+      val normalizedOrders = DataSourceStrategy.normalizeExprs(
+        newOrder, sHolder.relation.output).asInstanceOf[Seq[SortOrder]]
+      val orders = DataSourceStrategy.translateSortOrders(normalizedOrders)
       if (orders.length == order.length) {
         val (isPushed, isPartiallyPushed) =
           PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
index f68f78d51fd..5ca23e550aa 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
@@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.sql.connector.expressions.SortOrder
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.V1Scan
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
@@ -34,7 +33,7 @@ case class JDBCScan(
     groupByColumns: Option[Array[String]],
     tableSample: Option[TableSampleInfo],
     pushedLimit: Int,
-    sortOrders: Array[SortOrder]) extends V1Scan {
+    sortOrders: Array[String]) extends V1Scan {
 
   override def readSchema(): StructType = prunedSchema
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index 8b378d2d87c..a09444d2a3e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -53,7 +53,7 @@ case class JDBCScanBuilder(
 
   private var pushedLimit = 0
 
-  private var sortOrders: Array[SortOrder] = Array.empty[SortOrder]
+  private var sortOrders: Array[String] = Array.empty[String]
 
   override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] 
= {
     if (jdbcOptions.pushDownPredicate) {
@@ -140,8 +140,14 @@ case class JDBCScanBuilder(
 
   override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = {
     if (jdbcOptions.pushDownLimit) {
+      val dialect = JdbcDialects.get(jdbcOptions.url)
+      val compiledOrders = orders.flatMap { order =>
+        dialect.compileExpression(order.expression())
+          .map(sortKey => s"$sortKey ${order.direction()} 
${order.nullOrdering()}")
+      }
+      if (orders.length != compiledOrders.length) return false
       pushedLimit = limit
-      sortOrders = orders
+      sortOrders = compiledOrders
       return true
     }
     false
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 74e226acb7a..178a4600125 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -222,7 +222,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
     checkSortRemoved(df1)
     checkLimitRemoved(df1)
     checkPushedInfo(df1,
-      "PushedFilters: [], PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 
1, ")
+      "PushedFilters: [], PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 
1, ")
     checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0, false)))
 
     val df2 = spark.read
@@ -237,7 +237,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
     checkSortRemoved(df2)
     checkLimitRemoved(df2)
     checkPushedInfo(df2, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], " +
-      "PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 1, ")
+      "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1, ")
     checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0, false)))
 
     val df3 = spark.read
@@ -252,7 +252,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
     checkSortRemoved(df3, false)
     checkLimitRemoved(df3, false)
     checkPushedInfo(df3, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " +
-      "PushedTopN: ORDER BY [salary DESC NULLS LAST] LIMIT 1, ")
+      "PushedTopN: ORDER BY [SALARY DESC NULLS LAST] LIMIT 1, ")
     checkAnswer(df3, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
 
     val df4 =
@@ -261,7 +261,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
     checkSortRemoved(df4)
     checkLimitRemoved(df4)
     checkPushedInfo(df4, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " +
-      "PushedTopN: ORDER BY [salary ASC NULLS LAST] LIMIT 1, ")
+      "PushedTopN: ORDER BY [SALARY ASC NULLS LAST] LIMIT 1, ")
     checkAnswer(df4, Seq(Row("david")))
 
     val df5 = spark.read.table("h2.test.employee")
@@ -304,6 +304,38 @@ class JDBCV2Suite extends QueryTest with 
SharedSparkSession with ExplainSuiteHel
     checkLimitRemoved(df8, false)
     checkPushedInfo(df8, "PushedFilters: [], ")
     checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
+
+    val df9 = spark.read
+      .table("h2.test.employee")
+      .select($"DEPT", $"name", $"SALARY",
+        when(($"SALARY" > 8000).and($"SALARY" < 10000), 
$"salary").otherwise(0).as("key"))
+      .sort("key", "dept", "SALARY")
+      .limit(3)
+    checkSortRemoved(df9)
+    checkLimitRemoved(df9)
+    checkPushedInfo(df9, "PushedFilters: [], " +
+      "PushedTopN: ORDER BY [CASE WHEN (SALARY > 8000.00) AND " +
+      "(SALARY < 10000.00) THEN SALARY ELSE 0.00 END ASC NULL..., ")
+    checkAnswer(df9,
+      Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex", 
12000, 0)))
+
+    val df10 = spark.read
+      .option("partitionColumn", "dept")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .select($"DEPT", $"name", $"SALARY",
+        when(($"SALARY" > 8000).and($"SALARY" < 10000), 
$"salary").otherwise(0).as("key"))
+      .orderBy($"key", $"dept", $"SALARY")
+      .limit(3)
+    checkSortRemoved(df10, false)
+    checkLimitRemoved(df10, false)
+    checkPushedInfo(df10, "PushedFilters: [], " +
+      "PushedTopN: ORDER BY [CASE WHEN (SALARY > 8000.00) AND " +
+      "(SALARY < 10000.00) THEN SALARY ELSE 0.00 END ASC NULL..., ")
+    checkAnswer(df10,
+      Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex", 
12000, 0)))
   }
 
   test("simple scan with top N: order by with alias") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to