This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 4ad7386eefe [SPARK-38978][SQL] DS V2 supports push down OFFSET operator
4ad7386eefe is described below

commit 4ad7386eefe0856e500d1a11e2bb992a045ff217
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Fri Jun 24 17:33:07 2022 +0800

    [SPARK-38978][SQL] DS V2 supports push down OFFSET operator

    ### What changes were proposed in this pull request?
    Currently, DS V2 push-down supports `LIMIT` but not `OFFSET`. If we can push down `OFFSET` to the JDBC data source, we can get better performance. (An illustrative connector-side sketch of the new mix-in appears after the diff below.)

    ### Why are the changes needed?
    Pushing down `OFFSET` could improve performance.

    ### Does this PR introduce _any_ user-facing change?
    'No'. New feature.

    ### How was this patch tested?
    New tests.

    Closes #36295 from beliefer/SPARK-38978.

    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/connector/read/ScanBuilder.java | 3 +-
 .../sql/connector/read/SupportsPushDownLimit.java | 4 +-
 ...canBuilder.java => SupportsPushDownOffset.java} | 17 +-
 .../sql/connector/read/SupportsPushDownTopN.java | 23 +-
 .../main/scala/org/apache/spark/sql/Dataset.scala | 2 +-
 .../spark/sql/execution/DataSourceScanExec.scala | 9 +-
 .../execution/datasources/DataSourceStrategy.scala | 6 +-
 .../execution/datasources/jdbc/JDBCOptions.scala | 5 +
 .../sql/execution/datasources/jdbc/JDBCRDD.scala | 12 +-
 .../execution/datasources/jdbc/JDBCRelation.scala | 6 +-
 .../execution/datasources/v2/PushDownUtils.scala | 15 +-
 .../datasources/v2/PushedDownOperators.scala | 1 +
 .../datasources/v2/V2ScanRelationPushDown.scala | 75 ++++-
 .../execution/datasources/v2/jdbc/JDBCScan.scala | 5 +-
 .../datasources/v2/jdbc/JDBCScanBuilder.scala | 20 +-
 .../org/apache/spark/sql/jdbc/JdbcDialects.scala | 7 +
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 352 ++++++++++++++++++++-
 17 files changed, 514 insertions(+), 48 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
index 27ee534d804..f5ce604148b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
@@ -23,7 +23,8 @@ import org.apache.spark.annotation.Evolving;
  * An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ
  * interfaces to do operator push down, and keep the operator push down result in the returned
  * {@link Scan}. When pushing down operators, the push down order is:
- * sample -> filter -> aggregate -> limit -> column pruning.
+ * sample -> filter -> aggregate -> limit/top-n(sort + limit) -> offset ->
+ * column pruning.
  *
  * @since 3.0.0
  */
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java
index 035154d0845..8a725cd7ed7 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java
@@ -21,8 +21,8 @@ import org.apache.spark.annotation.Evolving;
 /**
  * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
- * push down LIMIT. Please note that the combination of LIMIT with other operations
- * such as AGGREGATE, GROUP BY, SORT BY, CLUSTER BY, DISTRIBUTE BY, etc.
is NOT pushed down. + * push down LIMIT. We can push down LIMIT with many other operations if they follow the + * operator order we defined in {@link ScanBuilder}'s class doc. * * @since 3.3.0 */ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownOffset.java similarity index 68% copy from sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java copy to sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownOffset.java index 27ee534d804..ffa2cad3715 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownOffset.java @@ -20,14 +20,17 @@ package org.apache.spark.sql.connector.read; import org.apache.spark.annotation.Evolving; /** - * An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ - * interfaces to do operator push down, and keep the operator push down result in the returned - * {@link Scan}. When pushing down operators, the push down order is: - * sample -> filter -> aggregate -> limit -> column pruning. + * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to + * push down OFFSET. We can push down OFFSET with many other operations if they follow the + * operator order we defined in {@link ScanBuilder}'s class doc. * - * @since 3.0.0 + * @since 3.4.0 */ @Evolving -public interface ScanBuilder { - Scan build(); +public interface SupportsPushDownOffset extends ScanBuilder { + + /** + * Pushes down OFFSET to the data source. + */ + boolean pushOffset(int offset); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java index cba1592c4fa..83d15ba2296 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java @@ -22,23 +22,22 @@ import org.apache.spark.sql.connector.expressions.SortOrder; /** * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to - * push down top N(query with ORDER BY ... LIMIT n). Please note that the combination of top N - * with other operations such as AGGREGATE, GROUP BY, CLUSTER BY, DISTRIBUTE BY, etc. - * is NOT pushed down. + * push down top N(query with ORDER BY ... LIMIT n). We can push down top N with many other + * operations if they follow the operator order we defined in {@link ScanBuilder}'s class doc. * * @since 3.3.0 */ @Evolving public interface SupportsPushDownTopN extends ScanBuilder { - /** - * Pushes down top N to the data source. - */ - boolean pushTopN(SortOrder[] orders, int limit); + /** + * Pushes down top N to the data source. + */ + boolean pushTopN(SortOrder[] orders, int limit); - /** - * Whether the top N is partially pushed or not. If it returns true, then Spark will do top N - * again. This method will only be called when {@link #pushTopN} returns true. - */ - default boolean isPartiallyPushed() { return true; } + /** + * Whether the top N is partially pushed or not. If it returns true, then Spark will do top N + * again. This method will only be called when {@link #pushTopN} returns true. 
+ */ + default boolean isPartiallyPushed() { return true; } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d79d50ced2b..39d33d80261 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2103,7 +2103,7 @@ class Dataset[T] private[sql]( } /** - * Returns a new Dataset by skipping the first `m` rows. + * Returns a new Dataset by skipping the first `n` rows. * * @group typedrel * @since 3.4.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index f5d349d975f..9e316cc88cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -148,9 +148,11 @@ case class RowDataSourceScanExec( s"ORDER BY ${seqToString(pushedDownOperators.sortValues.map(_.describe()))}" + s" LIMIT ${pushedDownOperators.limit.get}" Some("PushedTopN" -> pushedTopN) - } else { - pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") - } + } else { + pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") + } + + val offsetInfo = pushedDownOperators.offset.map(value => "PushedOffset" -> s"OFFSET $value") val pushedFilters = if (pushedDownOperators.pushedPredicates.nonEmpty) { seqToString(pushedDownOperators.pushedPredicates.map(_.describe())) @@ -164,6 +166,7 @@ case class RowDataSourceScanExec( Map("PushedAggregates" -> seqToString(v.aggregateExpressions.map(_.describe())), "PushedGroupByExpressions" -> seqToString(v.groupByExpressions.map(_.describe())))} ++ topNOrLimitInfo ++ + offsetInfo ++ pushedDownOperators.sample.map(v => "PushedSample" -> s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})" ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 294889ec449..564668ab392 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -346,7 +346,7 @@ object DataSourceStrategy l.output.toStructType, Set.empty, Set.empty, - PushedDownOperators(None, None, None, Seq.empty, Seq.empty), + PushedDownOperators(None, None, None, None, Seq.empty, Seq.empty), toCatalystRDD(l, baseRelation.buildScan()), baseRelation, None) :: Nil @@ -420,7 +420,7 @@ object DataSourceStrategy requestedColumns.toStructType, pushedFilters.toSet, handledFilters, - PushedDownOperators(None, None, None, Seq.empty, Seq.empty), + PushedDownOperators(None, None, None, None, Seq.empty, Seq.empty), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) @@ -443,7 +443,7 @@ object DataSourceStrategy requestedColumns.toStructType, pushedFilters.toSet, handledFilters, - PushedDownOperators(None, None, None, Seq.empty, Seq.empty), + PushedDownOperators(None, None, None, None, Seq.empty, Seq.empty), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 80675c7dc47..e725de95335 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -196,6 +196,10 @@ class JDBCOptions( // This only applies to Data Source V2 JDBC val pushDownLimit = parameters.getOrElse(JDBC_PUSHDOWN_LIMIT, "false").toBoolean + // An option to allow/disallow pushing down OFFSET into V2 JDBC data source + // This only applies to Data Source V2 JDBC + val pushDownOffset = parameters.getOrElse(JDBC_PUSHDOWN_OFFSET, "false").toBoolean + // An option to allow/disallow pushing down TABLESAMPLE into JDBC data source // This only applies to Data Source V2 JDBC val pushDownTableSample = parameters.getOrElse(JDBC_PUSHDOWN_TABLESAMPLE, "false").toBoolean @@ -283,6 +287,7 @@ object JDBCOptions { val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate") val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate") val JDBC_PUSHDOWN_LIMIT = newOption("pushDownLimit") + val JDBC_PUSHDOWN_OFFSET = newOption("pushDownOffset") val JDBC_PUSHDOWN_TABLESAMPLE = newOption("pushDownTableSample") val JDBC_KEYTAB = newOption("keytab") val JDBC_PRINCIPAL = newOption("principal") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index e95fe280c76..e23fe05a8a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -124,7 +124,8 @@ object JDBCRDD extends Logging { groupByColumns: Option[Array[String]] = None, sample: Option[TableSampleInfo] = None, limit: Int = 0, - sortOrders: Array[String] = Array.empty[String]): RDD[InternalRow] = { + sortOrders: Array[String] = Array.empty[String], + offset: Int = 0): RDD[InternalRow] = { val url = options.url val dialect = JdbcDialects.get(url) val quotedColumns = if (groupByColumns.isEmpty) { @@ -145,7 +146,8 @@ object JDBCRDD extends Logging { groupByColumns, sample, limit, - sortOrders) + sortOrders, + offset) } // scalastyle:on argcount } @@ -167,7 +169,8 @@ private[jdbc] class JDBCRDD( groupByColumns: Option[Array[String]], sample: Option[TableSampleInfo], limit: Int, - sortOrders: Array[String]) + sortOrders: Array[String], + offset: Int) extends RDD[InternalRow](sc, Nil) { /** @@ -305,10 +308,11 @@ private[jdbc] class JDBCRDD( } val myLimitClause: String = dialect.getLimitClause(limit) + val myOffsetClause: String = dialect.getOffsetClause(offset) val sqlText = options.prepareQuery + s"SELECT $columnList FROM ${options.tableOrQuery} $myTableSampleClause" + - s" $myWhereClause $getGroupByClause $getOrderByClause $myLimitClause" + s" $myWhereClause $getGroupByClause $getOrderByClause $myLimitClause $myOffsetClause" stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) stmt.setFetchSize(options.fetchSize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 427a494eb67..4f19d3df40b 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -304,7 +304,8 @@ private[sql] case class JDBCRelation( groupByColumns: Option[Array[String]], tableSample: Option[TableSampleInfo], limit: Int, - sortOrders: Array[String]): RDD[Row] = { + sortOrders: Array[String], + offset: Int): RDD[Row] = { // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( sparkSession.sparkContext, @@ -317,7 +318,8 @@ private[sql] case class JDBCRelation( groupByColumns, tableSample, limit, - sortOrders).asInstanceOf[RDD[Row]] + sortOrders, + offset).asInstanceOf[RDD[Row]] } override def insert(data: DataFrame, overwrite: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 60371d6bf43..5fb16aa7323 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeS import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.filter.Predicate -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownOffset, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources @@ -130,6 +130,19 @@ object PushDownUtils extends PredicateHelper { } } + /** + * Pushes down OFFSET to the data source Scan. + * + * @return the Boolean value represents whether to push down. + */ + def pushOffset(scanBuilder: ScanBuilder, offset: Int): Boolean = { + scanBuilder match { + case s: SupportsPushDownOffset => + s.pushOffset(offset) + case _ => false + } + } + /** * Pushes down top N to the data source Scan. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala index a95b4593fc3..49044c6e24d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala @@ -28,6 +28,7 @@ case class PushedDownOperators( aggregation: Option[Aggregation], sample: Option[TableSampleInfo], limit: Option[Int], + offset: Option[Int], sortValues: Seq[SortOrder], pushedPredicates: Seq[Predicate]) { assert((limit.isEmpty && sortValues.isEmpty) || limit.isDefined) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index ccdba26aab3..b55aeefca0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, AliasHelper, And, Attribute, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, AliasHelper, And, Attribute, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder} import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum} @@ -31,7 +31,7 @@ import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources -import org.apache.spark.sql.types.{DataType, LongType, StructType} +import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StructType} import org.apache.spark.sql.util.SchemaUtils._ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper { @@ -43,7 +43,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit pushDownSample, pushDownFilters, pushDownAggregates, - pushDownLimits, + pushDownLimitAndOffset, pruneColumns) pushdownRules.foldLeft(plan) { (newPlan, pushDownRule) => @@ -407,7 +407,60 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit 
case other => (other, false) } - def pushDownLimits(plan: LogicalPlan): LogicalPlan = plan.transform { + private def pushDownOffset( + plan: LogicalPlan, + offset: Int): Boolean = plan match { + case sHolder: ScanBuilderHolder => + val isPushed = PushDownUtils.pushOffset(sHolder.builder, offset) + if (isPushed) { + sHolder.pushedOffset = Some(offset) + } + isPushed + case Project(projectList, child) if projectList.forall(_.deterministic) => + pushDownOffset(child, offset) + case _ => false + } + + def pushDownLimitAndOffset(plan: LogicalPlan): LogicalPlan = plan.transform { + case offset @ LimitAndOffset(limit, offsetValue, child) => + val (newChild, canRemoveLimit) = pushDownLimit(child, limit) + if (canRemoveLimit) { + // Try to push down OFFSET only if the LIMIT operator has been pushed and can be removed. + val isPushed = pushDownOffset(newChild, offsetValue) + if (isPushed) { + newChild + } else { + // Keep the OFFSET operator if we failed to push down OFFSET to the data source. + offset.withNewChildren(Seq(newChild)) + } + } else { + // Keep the OFFSET operator if we can't remove LIMIT operator. + offset + } + case globalLimit @ OffsetAndLimit(offset, limit, child) => + // For `df.offset(n).limit(m)`, we can push down `limit(m + n)` first. + val (newChild, canRemoveLimit) = pushDownLimit(child, limit + offset) + if (canRemoveLimit) { + // Try to push down OFFSET only if the LIMIT operator has been pushed and can be removed. + val isPushed = pushDownOffset(newChild, offset) + if (isPushed) { + newChild + } else { + // Still keep the OFFSET operator if we can't push it down. + Offset(Literal(offset), newChild) + } + } else { + // For `df.offset(n).limit(m)`, since we can't push down `limit(m + n)`, + // try to push down `offset(n)` here. + val isPushed = pushDownOffset(child, offset) + if (isPushed) { + // Keep the LIMIT operator if we can't push it down. + Limit(Literal(limit, IntegerType), child) + } else { + // Keep the origin plan if we can't push OFFSET operator and LIMIT operator. 
+ globalLimit + } + } case globalLimit @ Limit(IntegerLiteral(limitValue), child) => val (newChild, canRemoveLimit) = pushDownLimit(child, limitValue) if (canRemoveLimit) { @@ -417,6 +470,13 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit globalLimit.child.asInstanceOf[LocalLimit].withNewChildren(Seq(newChild)) globalLimit.withNewChildren(Seq(newLocalLimit)) } + case offset @ Offset(IntegerLiteral(n), child) => + val isPushed = pushDownOffset(child, n) + if (isPushed) { + child + } else { + offset + } } private def getWrappedScan( @@ -431,7 +491,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit case _ => Array.empty[sources.Filter] } val pushedDownOperators = PushedDownOperators(aggregation, sHolder.pushedSample, - sHolder.pushedLimit, sHolder.sortOrders, sHolder.pushedPredicates) + sHolder.pushedLimit, sHolder.pushedOffset, sHolder.sortOrders, sHolder.pushedPredicates) V1ScanWrapper(v1, pushedFilters, pushedDownOperators) case _ => scan } @@ -444,6 +504,8 @@ case class ScanBuilderHolder( builder: ScanBuilder) extends LeafNode { var pushedLimit: Option[Int] = None + var pushedOffset: Option[Int] = None + var sortOrders: Seq[V2SortOrder] = Seq.empty[V2SortOrder] var pushedSample: Option[TableSampleInfo] = None @@ -451,7 +513,6 @@ case class ScanBuilderHolder( var pushedPredicates: Seq[Predicate] = Seq.empty[Predicate] } - // A wrapper for v1 scan to carry the translated filters and the handled ones, along with // other pushed down operators. This is required by the physical v1 scan node. case class V1ScanWrapper( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala index 5ca23e550aa..ea642a3a5e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala @@ -33,7 +33,8 @@ case class JDBCScan( groupByColumns: Option[Array[String]], tableSample: Option[TableSampleInfo], pushedLimit: Int, - sortOrders: Array[String]) extends V1Scan { + sortOrders: Array[String], + pushedOffset: Int) extends V1Scan { override def readSchema(): StructType = prunedSchema @@ -49,7 +50,7 @@ case class JDBCScan( pushedAggregateColumn } relation.buildScan(columnList, prunedSchema, pushedPredicates, groupByColumns, tableSample, - pushedLimit, sortOrders) + pushedLimit, sortOrders, pushedOffset) } }.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index 3681154a1bc..2dc2015d195 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder} import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.expressions.filter.Predicate -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} +import 
org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownLimit, SupportsPushDownOffset, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation} import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo @@ -39,6 +39,7 @@ case class JDBCScanBuilder( with SupportsPushDownRequiredColumns with SupportsPushDownAggregates with SupportsPushDownLimit + with SupportsPushDownOffset with SupportsPushDownTableSample with SupportsPushDownTopN with Logging { @@ -53,6 +54,8 @@ case class JDBCScanBuilder( private var pushedLimit = 0 + private var pushedOffset = 0 + private var sortOrders: Array[String] = Array.empty[String] override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = { @@ -139,6 +142,19 @@ case class JDBCScanBuilder( false } + override def pushOffset(offset: Int): Boolean = { + if (jdbcOptions.pushDownOffset && !isPartiallyPushed) { + // Spark pushes down LIMIT first, then OFFSET. In SQL statements, OFFSET is applied before + // LIMIT. Here we need to adjust the LIMIT value to match SQL statements. + if (pushedLimit > 0) { + pushedLimit = pushedLimit - offset + } + pushedOffset = offset + return true + } + false + } + override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = { if (jdbcOptions.pushDownLimit) { val dialect = JdbcDialects.get(jdbcOptions.url) @@ -181,6 +197,6 @@ case class JDBCScanBuilder( // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't // be used in sql string. JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedPredicate, - pushedAggregateList, pushedGroupBys, tableSample, pushedLimit, sortOrders) + pushedAggregateList, pushedGroupBys, tableSample, pushedLimit, sortOrders, pushedOffset) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index c54ac84c735..fa2a45be185 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -549,6 +549,13 @@ abstract class JdbcDialect extends Serializable with Logging{ if (limit > 0 ) s"LIMIT $limit" else "" } + /** + * returns the OFFSET clause for the SELECT statement + */ + def getOffsetClause(offset: Integer): String = { + if (offset > 0 ) s"OFFSET $offset" else "" + } + def supportsTableSample: Boolean = false def getTableSample(sample: TableSampleInfo): String = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index a6073566813..07c4261ae51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -23,7 +23,7 @@ import java.util.Properties import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.{AnalysisException, DataFrame, ExplainSuiteHelper, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Sort} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Offset, Sort} import 
org.apache.spark.sql.connector.IntegralAverage import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog @@ -45,6 +45,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .set("spark.sql.catalog.h2.driver", "org.h2.Driver") .set("spark.sql.catalog.h2.pushDownAggregate", "true") .set("spark.sql.catalog.h2.pushDownLimit", "true") + .set("spark.sql.catalog.h2.pushDownOffset", "true") private def withConnection[T](f: Connection => T): T = { val conn = DriverManager.getConnection(url, new Properties()) @@ -207,6 +208,355 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df5, Seq(Row(10000.00, 1000.0, "amy"))) } + private def checkOffsetRemoved(df: DataFrame, removed: Boolean = true): Unit = { + val offsets = df.queryExecution.optimizedPlan.collect { + case offset: Offset => offset + } + if (removed) { + assert(offsets.isEmpty) + } else { + assert(offsets.nonEmpty) + } + } + + test("simple scan with OFFSET") { + val df1 = spark.read + .table("h2.test.employee") + .where($"dept" === 1) + .offset(1) + checkOffsetRemoved(df1) + checkPushedInfo(df1, + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], PushedOffset: OFFSET 1,") + checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) + + val df2 = spark.read + .option("pushDownOffset", "false") + .table("h2.test.employee") + .where($"dept" === 1) + .offset(1) + checkOffsetRemoved(df2, false) + checkPushedInfo(df2, + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], ReadSchema:") + checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) + + val df3 = spark.read + .table("h2.test.employee") + .where($"dept" === 1) + .sort($"salary") + .offset(1) + checkOffsetRemoved(df3, false) + checkPushedInfo(df3, + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], ReadSchema:") + checkAnswer(df3, Seq(Row(1, "amy", 10000.00, 1000.0, true))) + + val df4 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .filter($"dept" > 1) + .offset(1) + checkOffsetRemoved(df4, false) + checkPushedInfo(df4, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], ReadSchema:") + checkAnswer(df4, Seq(Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true))) + + val df5 = spark.read + .table("h2.test.employee") + .groupBy("DEPT").sum("SALARY") + .offset(1) + checkOffsetRemoved(df5, false) + checkPushedInfo(df5, + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ") + checkAnswer(df5, Seq(Row(2, 22000.00), Row(6, 12000.00))) + + val name = udf { (x: String) => x.matches("cat|dav|amy") } + val sub = udf { (x: String) => x.substring(0, 3) } + val df6 = spark.read + .table("h2.test.employee") + .select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) + .filter(name($"shortName")) + .offset(1) + checkOffsetRemoved(df6, false) + // OFFSET is pushed down only if all the filters are pushed down + checkPushedInfo(df6, "PushedFilters: [], ") + checkAnswer(df6, Seq(Row(10000.00, 1300.0, "dav"), Row(9000.00, 1200.0, "cat"))) + } + + test("simple scan with LIMIT and OFFSET") { + val df1 = spark.read + .table("h2.test.employee") + .where($"dept" === 1) + .limit(2) + .offset(1) + checkLimitRemoved(df1) + checkOffsetRemoved(df1) + checkPushedInfo(df1, + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], PushedLimit: LIMIT 2, PushedOffset: OFFSET 1,") + 
checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) + + val df2 = spark.read + .option("pushDownLimit", "false") + .table("h2.test.employee") + .where($"dept" === 1) + .limit(2) + .offset(1) + checkLimitRemoved(df2, false) + checkOffsetRemoved(df2, false) + checkPushedInfo(df2, + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], ReadSchema:") + checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) + + val df3 = spark.read + .option("pushDownOffset", "false") + .table("h2.test.employee") + .where($"dept" === 1) + .limit(2) + .offset(1) + checkLimitRemoved(df3) + checkOffsetRemoved(df3, false) + checkPushedInfo(df3, + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], PushedLimit: LIMIT 2, ReadSchema:") + checkAnswer(df3, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) + + val df4 = spark.read + .option("pushDownLimit", "false") + .option("pushDownOffset", "false") + .table("h2.test.employee") + .where($"dept" === 1) + .limit(2) + .offset(1) + checkLimitRemoved(df4, false) + checkOffsetRemoved(df4, false) + checkPushedInfo(df4, + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], ReadSchema:") + checkAnswer(df4, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) + + val df5 = spark.read + .table("h2.test.employee") + .where($"dept" === 1) + .sort($"salary") + .limit(2) + .offset(1) + checkLimitRemoved(df5) + checkOffsetRemoved(df5) + checkPushedInfo(df5, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], " + + "PushedOffset: OFFSET 1, PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 2, ReadSchema:") + checkAnswer(df5, Seq(Row(1, "amy", 10000.00, 1000.0, true))) + + val df6 = spark.read + .option("pushDownLimit", "false") + .table("h2.test.employee") + .where($"dept" === 1) + .sort($"salary") + .limit(2) + .offset(1) + checkLimitRemoved(df6, false) + checkOffsetRemoved(df6, false) + checkPushedInfo(df6, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], ReadSchema:") + checkAnswer(df6, Seq(Row(1, "amy", 10000.00, 1000.0, true))) + + val df7 = spark.read + .option("pushDownOffset", "false") + .table("h2.test.employee") + .where($"dept" === 1) + .sort($"salary") + .limit(2) + .offset(1) + checkLimitRemoved(df7) + checkOffsetRemoved(df7, false) + checkPushedInfo(df7, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1]," + + " PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 2, ReadSchema:") + checkAnswer(df7, Seq(Row(1, "amy", 10000.00, 1000.0, true))) + + val df8 = spark.read + .option("pushDownLimit", "false") + .option("pushDownOffset", "false") + .table("h2.test.employee") + .where($"dept" === 1) + .sort($"salary") + .limit(2) + .offset(1) + checkLimitRemoved(df8, false) + checkOffsetRemoved(df8, false) + checkPushedInfo(df8, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], ReadSchema:") + checkAnswer(df8, Seq(Row(1, "amy", 10000.00, 1000.0, true))) + + val df9 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .filter($"dept" > 1) + .limit(2) + .offset(1) + checkLimitRemoved(df9, false) + checkOffsetRemoved(df9, false) + checkPushedInfo(df9, + "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], PushedLimit: LIMIT 2, ReadSchema:") + checkAnswer(df9, Seq(Row(2, "david", 10000.00, 1300.0, true))) + + val df10 = spark.read + .table("h2.test.employee") + .groupBy("DEPT").sum("SALARY") + .limit(2) + .offset(1) + checkLimitRemoved(df10, false) + checkOffsetRemoved(df10, false) + checkPushedInfo(df10, + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ") + 
checkAnswer(df10, Seq(Row(2, 22000.00))) + + val name = udf { (x: String) => x.matches("cat|dav|amy") } + val sub = udf { (x: String) => x.substring(0, 3) } + val df11 = spark.read + .table("h2.test.employee") + .select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) + .filter(name($"shortName")) + .limit(2) + .offset(1) + checkLimitRemoved(df11, false) + checkOffsetRemoved(df11, false) + checkPushedInfo(df11, "PushedFilters: [], ") + checkAnswer(df11, Seq(Row(9000.00, 1200.0, "cat"))) + } + + test("simple scan with OFFSET and LIMIT") { + val df1 = spark.read + .table("h2.test.employee") + .where($"dept" === 1) + .offset(1) + .limit(1) + checkLimitRemoved(df1) + checkOffsetRemoved(df1) + checkPushedInfo(df1, + "[DEPT IS NOT NULL, DEPT = 1], PushedLimit: LIMIT 2, PushedOffset: OFFSET 1,") + checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) + + val df2 = spark.read + .option("pushDownOffset", "false") + .table("h2.test.employee") + .where($"dept" === 1) + .offset(1) + .limit(1) + checkLimitRemoved(df2) + checkOffsetRemoved(df2, false) + checkPushedInfo(df2, + "[DEPT IS NOT NULL, DEPT = 1], PushedLimit: LIMIT 2, ReadSchema:") + checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) + + val df3 = spark.read + .option("pushDownLimit", "false") + .table("h2.test.employee") + .where($"dept" === 1) + .offset(1) + .limit(1) + checkLimitRemoved(df3, false) + checkOffsetRemoved(df3) + checkPushedInfo(df3, + "[DEPT IS NOT NULL, DEPT = 1], PushedOffset: OFFSET 1, ReadSchema:") + checkAnswer(df3, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) + + val df4 = spark.read + .option("pushDownOffset", "false") + .option("pushDownLimit", "false") + .table("h2.test.employee") + .where($"dept" === 1) + .offset(1) + .limit(1) + checkLimitRemoved(df4, false) + checkOffsetRemoved(df4, false) + checkPushedInfo(df4, + "[DEPT IS NOT NULL, DEPT = 1], ReadSchema:") + checkAnswer(df4, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) + + val df5 = spark.read + .table("h2.test.employee") + .where($"dept" === 1) + .sort($"salary") + .offset(1) + .limit(1) + checkLimitRemoved(df5) + checkOffsetRemoved(df5) + checkPushedInfo(df5, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], " + + "PushedOffset: OFFSET 1, PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 2, ReadSchema:") + checkAnswer(df5, Seq(Row(1, "amy", 10000.00, 1000.0, true))) + + val df6 = spark.read + .option("pushDownOffset", "false") + .table("h2.test.employee") + .where($"dept" === 1) + .sort($"salary") + .offset(1) + .limit(1) + checkLimitRemoved(df6) + checkOffsetRemoved(df6, false) + checkPushedInfo(df6, "[DEPT IS NOT NULL, DEPT = 1]," + + " PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 2, ReadSchema:") + checkAnswer(df6, Seq(Row(1, "amy", 10000.00, 1000.0, true))) + + val df7 = spark.read + .option("pushDownLimit", "false") + .table("h2.test.employee") + .where($"dept" === 1) + .sort($"salary") + .offset(1) + .limit(1) + checkLimitRemoved(df7, false) + checkOffsetRemoved(df7, false) + checkPushedInfo(df7, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], ReadSchema:") + checkAnswer(df7, Seq(Row(1, "amy", 10000.00, 1000.0, true))) + + val df8 = spark.read + .option("pushDownOffset", "false") + .option("pushDownLimit", "false") + .table("h2.test.employee") + .where($"dept" === 1) + .sort($"salary") + .offset(1) + .limit(1) + checkLimitRemoved(df8, false) + checkOffsetRemoved(df8, false) + checkPushedInfo(df8, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], ReadSchema:") + checkAnswer(df8, Seq(Row(1, "amy", 10000.00, 1000.0, true))) + + val 
df9 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .filter($"dept" > 1) + .offset(1) + .limit(1) + checkLimitRemoved(df9, false) + checkOffsetRemoved(df9, false) + checkPushedInfo(df9, + "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], PushedLimit: LIMIT 2, ReadSchema:") + checkAnswer(df9, Seq(Row(2, "david", 10000.00, 1300.0, true))) + + val df10 = sql("SELECT dept, sum(salary) FROM h2.test.employee group by dept LIMIT 1 OFFSET 1") + checkLimitRemoved(df10, false) + checkOffsetRemoved(df10, false) + checkPushedInfo(df10, + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ") + checkAnswer(df10, Seq(Row(2, 22000.00))) + + val name = udf { (x: String) => x.matches("cat|dav|amy") } + val sub = udf { (x: String) => x.substring(0, 3) } + val df11 = spark.read + .table("h2.test.employee") + .select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) + .filter(name($"shortName")) + .offset(1) + .limit(1) + checkLimitRemoved(df11, false) + checkOffsetRemoved(df11, false) + checkPushedInfo(df11, "PushedFilters: [], ") + checkAnswer(df11, Seq(Row(9000.00, 1200.0, "cat"))) + } + private def checkSortRemoved(df: DataFrame, removed: Boolean = true): Unit = { val sorts = df.queryExecution.optimizedPlan.collect { case s: Sort => s --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org
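
A minimal, hypothetical sketch of how a third-party DS V2 connector might adopt the new SupportsPushDownOffset mix-in alongside the existing SupportsPushDownLimit. The ExampleScanBuilder/ExampleScan names and the Option-based bookkeeping are invented for illustration; only the interfaces and method signatures come from the Spark connector API shown in the diff above.

```scala
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownLimit, SupportsPushDownOffset}
import org.apache.spark.sql.types.StructType

// Hypothetical builder: records the pushed LIMIT/OFFSET so the Scan it builds can apply them.
class ExampleScanBuilder(schema: StructType) extends ScanBuilder
    with SupportsPushDownLimit
    with SupportsPushDownOffset {

  private var pushedLimit: Option[Int] = None
  private var pushedOffset: Option[Int] = None

  // Called before pushOffset, following the order documented in ScanBuilder's class doc:
  // sample -> filter -> aggregate -> limit/top-n(sort + limit) -> offset -> column pruning.
  override def pushLimit(limit: Int): Boolean = {
    pushedLimit = Some(limit)
    true
  }

  // New in this commit: returning true lets Spark remove its own Offset operator,
  // while returning false keeps OFFSET in the Spark plan and the source ignores it.
  override def pushOffset(offset: Int): Boolean = {
    pushedOffset = Some(offset)
    true
  }

  override def build(): Scan = new Scan {
    // A real connector would honor pushedLimit and pushedOffset when producing rows.
    override def readSchema(): StructType = schema
  }
}
```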
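On the JDBC side, the interaction between the two push-downs can be seen end to end. The snippet below is a usage sketch modeled on JDBCV2Suite's H2 catalog configuration: the URL, credentials, and table are placeholders, the exact generated SQL is inferred from JDBCRDD's clause ordering and JDBCScanBuilder's limit adjustment, and only the option/conf names (`pushDownLimit`, `pushDownOffset`) come from this commit.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical session mirroring JDBCV2Suite's H2 catalog; URL and driver are placeholders.
val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.sql.catalog.h2", "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
  .config("spark.sql.catalog.h2.url", "jdbc:h2:mem:testdb0")
  .config("spark.sql.catalog.h2.driver", "org.h2.Driver")
  .config("spark.sql.catalog.h2.pushDownLimit", "true")
  .config("spark.sql.catalog.h2.pushDownOffset", "true") // new option added by this commit
  .getOrCreate()
import spark.implicits._

// Spark pushes LIMIT before OFFSET, while SQL applies OFFSET before LIMIT, so
// JDBCScanBuilder rewrites the pushed limit to limit - offset (here 2 - 1 = 1) and the
// query sent to the database should end with something like "LIMIT 1 OFFSET 1".
val df = spark.read
  .table("h2.test.employee")
  .where($"dept" === 1)
  .limit(2)
  .offset(1)
df.explain() // the scan node should report PushedLimit: LIMIT 2, PushedOffset: OFFSET 1

// For offset-then-limit, V2ScanRelationPushDown pushes LIMIT (offset + limit) plus the OFFSET,
// so .offset(1).limit(1) also surfaces as PushedLimit: LIMIT 2, PushedOffset: OFFSET 1.
val df2 = spark.read.table("h2.test.employee").where($"dept" === 1).offset(1).limit(1)
df2.explain()
```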