This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new cf7fc6b1031 [SPARK-39037][SQL] DS V2 aggregate push-down supports order by expressions
cf7fc6b1031 is described below
commit cf7fc6b1031383bfc1cbf7c201e9830dad413cf8
Author: Jiaan Geng <[email protected]>
AuthorDate: Thu Apr 28 15:50:21 2022 +0800
[SPARK-39037][SQL] DS V2 aggregate push-down supports order by expressions
### What changes were proposed in this pull request?
Currently, Spark DS V2 aggregate push-down only supports ORDER BY on columns. But SQL like the example below, which orders by an expression, is very useful and common.
```
SELECT
  CASE
    WHEN "SALARY" > 8000.00
      AND "SALARY" < 10000.00
    THEN "SALARY"
    ELSE 0.00
  END AS key,
  dept,
  name
FROM "test"."employee"
ORDER BY key
```
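As a reference, here is a minimal DataFrame-level sketch of the same kind of query (assumptions: the `h2.test.employee` table and implicits from `JDBCV2Suite`; it mirrors the new `df9` test below). With this change, the CASE WHEN sort key and the LIMIT can be pushed down to the source as a Top-N:
```
import org.apache.spark.sql.functions.when

val df = spark.read
  .table("h2.test.employee")
  .select($"DEPT", $"name", $"SALARY",
    when(($"SALARY" > 8000).and($"SALARY" < 10000), $"SALARY").otherwise(0).as("key"))
  .orderBy("key")
  .limit(3)
// The plan now reports the pushed sort expression, e.g.
// PushedTopN: ORDER BY [CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00)
//   THEN SALARY ELSE 0.00 END ASC NULLS FIRST] LIMIT 3
df.show()
```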
### Why are the changes needed?
Let DS V2 aggregate push-down support order by expressions.
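For a JDBC-style source, the approach (a minimal sketch under assumed names; the actual change lives in `JDBCScanBuilder.pushTopN` below, which delegates to the dialect's `compileExpression`) is to compile every sort expression into a SQL fragment and give up on the push-down if any expression cannot be compiled:
```
import org.apache.spark.sql.connector.expressions.{Expression, SortOrder}

// Hypothetical stand-in for a dialect's expression compiler
// (JdbcDialect.compileExpression in the real code); returns None when the
// expression cannot be expressed in the source's SQL.
def compileExpr(expr: Expression): Option[String] = ???

def compileOrders(orders: Array[SortOrder]): Option[Array[String]] = {
  val compiled = orders.flatMap { order =>
    compileExpr(order.expression())
      .map(sql => s"$sql ${order.direction()} ${order.nullOrdering()}")
  }
  // Push the Top-N only if every sort expression compiled; otherwise Spark
  // keeps the Sort and Limit on its side.
  if (compiled.length == orders.length) Some(compiled) else None
}
```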
### Does this PR introduce _any_ user-facing change?
'No'.
New feature.
### How was this patch tested?
New tests
Closes #36370 from beliefer/SPARK-39037.
Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../execution/datasources/DataSourceStrategy.scala | 8 ++---
.../sql/execution/datasources/jdbc/JDBCRDD.scala | 7 ++--
.../execution/datasources/jdbc/JDBCRelation.scala | 3 +-
.../datasources/v2/V2ScanRelationPushDown.scala | 4 ++-
.../execution/datasources/v2/jdbc/JDBCScan.scala | 3 +-
.../datasources/v2/jdbc/JDBCScanBuilder.scala | 10 ++++--
.../org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 40 +++++++++++++++++++---
7 files changed, 56 insertions(+), 19 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index e35d0932076..04b77013a0c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -776,8 +776,8 @@ object DataSourceStrategy
}
protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[V2SortOrder] = {
- def translateOortOrder(sortOrder: SortOrder): Option[V2SortOrder] = sortOrder match {
- case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) =>
+ def translateSortOrder(sortOrder: SortOrder): Option[V2SortOrder] = sortOrder match {
+ case SortOrder(PushableExpression(expr), directionV1, nullOrderingV1, _) =>
val directionV2 = directionV1 match {
case Ascending => SortDirection.ASCENDING
case Descending => SortDirection.DESCENDING
@@ -786,11 +786,11 @@ object DataSourceStrategy
case NullsFirst => NullOrdering.NULLS_FIRST
case NullsLast => NullOrdering.NULLS_LAST
}
- Some(SortValue(FieldReference(name), directionV2, nullOrderingV2))
+ Some(SortValue(expr, directionV2, nullOrderingV2))
case _ => None
}
- sortOrders.flatMap(translateOortOrder)
+ sortOrders.flatMap(translateSortOrder)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index b30b460ac67..13d6156aed1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -25,7 +25,6 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
@@ -123,7 +122,7 @@ object JDBCRDD extends Logging {
groupByColumns: Option[Array[String]] = None,
sample: Option[TableSampleInfo] = None,
limit: Int = 0,
- sortOrders: Array[SortOrder] = Array.empty[SortOrder]): RDD[InternalRow] = {
+ sortOrders: Array[String] = Array.empty[String]): RDD[InternalRow] = {
val url = options.url
val dialect = JdbcDialects.get(url)
val quotedColumns = if (groupByColumns.isEmpty) {
@@ -166,7 +165,7 @@ private[jdbc] class JDBCRDD(
groupByColumns: Option[Array[String]],
sample: Option[TableSampleInfo],
limit: Int,
- sortOrders: Array[SortOrder])
+ sortOrders: Array[String])
extends RDD[InternalRow](sc, Nil) {
/**
@@ -216,7 +215,7 @@ private[jdbc] class JDBCRDD(
private def getOrderByClause: String = {
if (sortOrders.nonEmpty) {
- s" ORDER BY ${sortOrders.map(_.describe()).mkString(", ")}"
+ s" ORDER BY ${sortOrders.mkString(", ")}"
} else {
""
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 0f1a1b6dc66..ea841027607 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
-import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -305,7 +304,7 @@ private[sql] case class JDBCRelation(
groupByColumns: Option[Array[String]],
tableSample: Option[TableSampleInfo],
limit: Int,
- sortOrders: Array[SortOrder]): RDD[Row] = {
+ sortOrders: Array[String]): RDD[Row] = {
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JDBCRDD.scanTable(
sparkSession.sparkContext,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 89398fabdc3..03b6544c772 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -381,7 +381,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
order, project, alwaysInline = true) =>
val aliasMap = getAliasMap(project)
val newOrder = order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]]
- val orders = DataSourceStrategy.translateSortOrders(newOrder)
+ val normalizedOrders = DataSourceStrategy.normalizeExprs(
+ newOrder, sHolder.relation.output).asInstanceOf[Seq[SortOrder]]
+ val orders = DataSourceStrategy.translateSortOrders(normalizedOrders)
if (orders.length == order.length) {
val (isPushed, isPartiallyPushed) =
PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
index f68f78d51fd..5ca23e550aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
@@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.V1Scan
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
@@ -34,7 +33,7 @@ case class JDBCScan(
groupByColumns: Option[Array[String]],
tableSample: Option[TableSampleInfo],
pushedLimit: Int,
- sortOrders: Array[SortOrder]) extends V1Scan {
+ sortOrders: Array[String]) extends V1Scan {
override def readSchema(): StructType = prunedSchema
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index 8b378d2d87c..a09444d2a3e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -53,7 +53,7 @@ case class JDBCScanBuilder(
private var pushedLimit = 0
- private var sortOrders: Array[SortOrder] = Array.empty[SortOrder]
+ private var sortOrders: Array[String] = Array.empty[String]
override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
if (jdbcOptions.pushDownPredicate) {
@@ -140,8 +140,14 @@ case class JDBCScanBuilder(
override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = {
if (jdbcOptions.pushDownLimit) {
+ val dialect = JdbcDialects.get(jdbcOptions.url)
+ val compiledOrders = orders.flatMap { order =>
+ dialect.compileExpression(order.expression())
+ .map(sortKey => s"$sortKey ${order.direction()} ${order.nullOrdering()}")
+ }
+ if (orders.length != compiledOrders.length) return false
pushedLimit = limit
- sortOrders = orders
+ sortOrders = compiledOrders
return true
}
false
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 74e226acb7a..178a4600125 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -222,7 +222,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkSortRemoved(df1)
checkLimitRemoved(df1)
checkPushedInfo(df1,
- "PushedFilters: [], PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT
1, ")
+ "PushedFilters: [], PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT
1, ")
checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0, false)))
val df2 = spark.read
@@ -237,7 +237,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkSortRemoved(df2)
checkLimitRemoved(df2)
checkPushedInfo(df2, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], " +
- "PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 1, ")
+ "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1, ")
checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0, false)))
val df3 = spark.read
@@ -252,7 +252,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkSortRemoved(df3, false)
checkLimitRemoved(df3, false)
checkPushedInfo(df3, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " +
- "PushedTopN: ORDER BY [salary DESC NULLS LAST] LIMIT 1, ")
+ "PushedTopN: ORDER BY [SALARY DESC NULLS LAST] LIMIT 1, ")
checkAnswer(df3, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
val df4 =
@@ -261,7 +261,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkSortRemoved(df4)
checkLimitRemoved(df4)
checkPushedInfo(df4, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " +
- "PushedTopN: ORDER BY [salary ASC NULLS LAST] LIMIT 1, ")
+ "PushedTopN: ORDER BY [SALARY ASC NULLS LAST] LIMIT 1, ")
checkAnswer(df4, Seq(Row("david")))
val df5 = spark.read.table("h2.test.employee")
@@ -304,6 +304,38 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkLimitRemoved(df8, false)
checkPushedInfo(df8, "PushedFilters: [], ")
checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
+
+ val df9 = spark.read
+ .table("h2.test.employee")
+ .select($"DEPT", $"name", $"SALARY",
+ when(($"SALARY" > 8000).and($"SALARY" < 10000),
$"salary").otherwise(0).as("key"))
+ .sort("key", "dept", "SALARY")
+ .limit(3)
+ checkSortRemoved(df9)
+ checkLimitRemoved(df9)
+ checkPushedInfo(df9, "PushedFilters: [], " +
+ "PushedTopN: ORDER BY [CASE WHEN (SALARY > 8000.00) AND " +
+ "(SALARY < 10000.00) THEN SALARY ELSE 0.00 END ASC NULL..., ")
+ checkAnswer(df9,
+ Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex",
12000, 0)))
+
+ val df10 = spark.read
+ .option("partitionColumn", "dept")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "2")
+ .table("h2.test.employee")
+ .select($"DEPT", $"name", $"SALARY",
+ when(($"SALARY" > 8000).and($"SALARY" < 10000),
$"salary").otherwise(0).as("key"))
+ .orderBy($"key", $"dept", $"SALARY")
+ .limit(3)
+ checkSortRemoved(df10, false)
+ checkLimitRemoved(df10, false)
+ checkPushedInfo(df10, "PushedFilters: [], " +
+ "PushedTopN: ORDER BY [CASE WHEN (SALARY > 8000.00) AND " +
+ "(SALARY < 10000.00) THEN SALARY ELSE 0.00 END ASC NULL..., ")
+ checkAnswer(df10,
+ Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex",
12000, 0)))
}
test("simple scan with top N: order by with alias") {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]