Repository: spark
Updated Branches:
  refs/heads/master bcceab649 -> f5f8e84d9
[SPARK-22614] Dataset API: repartitionByRange(...)

## What changes were proposed in this pull request?

This PR introduces a way to explicitly range-partition a Dataset. So far, only round-robin and hash partitioning were possible via `df.repartition(...)`. Sometimes range partitioning is desirable, e.g., when writing to disk, where it gives better compression without the cost of a global sort.

The current implementation piggybacks on the existing `RepartitionByExpression` `LogicalPlan` and simply adds the following logic:
- If all of its expressions are of type `SortOrder`, it does `RangePartitioning`.
- If none of them are, it does `HashPartitioning`.
- Mixing `SortOrder` and non-`SortOrder` expressions is rejected with an `IllegalArgumentException`.

This was by far the least intrusive solution I could come up with. To keep the semantics of the existing API simple, `Dataset.repartition(...)` now rejects `SortOrder` arguments with an error message that points users to the new `repartitionByRange(...)` method.

## How was this patch tested?

- A unit test for the `RepartitionByExpression` changes.
- A test to ensure we're not changing the behavior of the existing `.repartition()`.
- A few end-to-end tests in `DataFrameSuite`.

Author: Adrian Ionescu <adr...@databricks.com>

Closes #19828 from adrian-ionescu/repartitionByRange.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f5f8e84d Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f5f8e84d Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f5f8e84d Branch: refs/heads/master Commit: f5f8e84d9d35751dad51490b6ae22931aa88db7b Parents: bcceab6 Author: Adrian Ionescu <adr...@databricks.com> Authored: Thu Nov 30 15:41:34 2017 -0800 Committer: gatorsmile <gatorsm...@gmail.com> Committed: Thu Nov 30 15:41:34 2017 -0800 ---------------------------------------------------------------------- .../plans/logical/basicLogicalOperators.scala | 20 +++++++ .../sql/catalyst/analysis/AnalysisSuite.scala | 26 +++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 57 ++++++++++++++++++-- .../spark/sql/execution/SparkStrategies.scala | 5 +- .../org/apache/spark/sql/DataFrameSuite.scala | 57 ++++++++++++++++++++ 5 files changed, 157 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/f5f8e84d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index c2750c3..93de7c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._ +import 
org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.RandomSampler @@ -838,6 +839,25 @@ case class RepartitionByExpression( require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.") + val partitioning: Partitioning = { + val (sortOrder, nonSortOrder) = partitionExpressions.partition(_.isInstanceOf[SortOrder]) + + require(sortOrder.isEmpty || nonSortOrder.isEmpty, + s"${getClass.getSimpleName} expects that either all its `partitionExpressions` are of type " + + "`SortOrder`, which means `RangePartitioning`, or none of them are `SortOrder`, which " + + "means `HashPartitioning`. In this case we have:" + + s""" + |SortOrder: ${sortOrder} + |NonSortOrder: ${nonSortOrder} + """.stripMargin) + + if (sortOrder.nonEmpty) { + RangePartitioning(sortOrder.map(_.asInstanceOf[SortOrder]), numPartitions) + } else { + HashPartitioning(nonSortOrder, numPartitions) + } + } + override def maxRows: Option[Long] = child.maxRows override def shuffle: Boolean = true } http://git-wip-us.apache.org/repos/asf/spark/blob/f5f8e84d/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index e56a5d6..0e2e706 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{Cross, Inner} import org.apache.spark.sql.catalyst.plans.logical._ +import 
org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning} import org.apache.spark.sql.types._ @@ -514,4 +515,29 @@ class AnalysisSuite extends AnalysisTest with Matchers { Seq("Number of column aliases does not match number of columns. " + "Number of column aliases: 5; number of columns: 4.")) } + + test("SPARK-22614 RepartitionByExpression partitioning") { + def checkPartitioning[T <: Partitioning](numPartitions: Int, exprs: Expression*): Unit = { + val partitioning = RepartitionByExpression(exprs, testRelation2, numPartitions).partitioning + assert(partitioning.isInstanceOf[T]) + } + + checkPartitioning[HashPartitioning](numPartitions = 10, exprs = Literal(20)) + checkPartitioning[HashPartitioning](numPartitions = 10, exprs = 'a.attr, 'b.attr) + + checkPartitioning[RangePartitioning](numPartitions = 10, + exprs = SortOrder(Literal(10), Ascending)) + checkPartitioning[RangePartitioning](numPartitions = 10, + exprs = SortOrder('a.attr, Ascending), SortOrder('b.attr, Descending)) + + intercept[IllegalArgumentException] { + checkPartitioning(numPartitions = 0, exprs = Literal(20)) + } + intercept[IllegalArgumentException] { + checkPartitioning(numPartitions = -1, exprs = Literal(20)) + } + intercept[IllegalArgumentException] { + checkPartitioning(numPartitions = 10, exprs = SortOrder('a.attr, Ascending), 'b.attr) + } + } } http://git-wip-us.apache.org/repos/asf/spark/blob/f5f8e84d/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1620ab3..167c9d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2732,8 +2732,18 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def repartition(numPartitions: Int, 
partitionExprs: Column*): Dataset[T] = withTypedPlan { - RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions) + def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = { + // The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments. + // However, we don't want to complicate the semantics of this API method. + // Instead, let's give users a friendly error message, pointing them to the new method. + val sortOrders = partitionExprs.filter(_.expr.isInstanceOf[SortOrder]) + if (sortOrders.nonEmpty) throw new IllegalArgumentException( + s"""Invalid partitionExprs specified: $sortOrders + |For range partitioning use repartitionByRange(...) instead. + """.stripMargin) + withTypedPlan { + RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions) + } } /** @@ -2747,9 +2757,46 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan { - RepartitionByExpression( - partitionExprs.map(_.expr), logicalPlan, sparkSession.sessionState.conf.numShufflePartitions) + def repartition(partitionExprs: Column*): Dataset[T] = { + repartition(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*) + } + + /** + * Returns a new Dataset partitioned by the given partitioning expressions into + * `numPartitions`. The resulting Dataset is range partitioned. + * + * At least one partition-by expression must be specified. + * When no explicit sort order is specified, "ascending nulls first" is assumed. 
+ * + * @group typedrel + * @since 2.3.0 + */ + @scala.annotation.varargs + def repartitionByRange(numPartitions: Int, partitionExprs: Column*): Dataset[T] = { + require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.") + val sortOrder: Seq[SortOrder] = partitionExprs.map(_.expr match { + case expr: SortOrder => expr + case expr: Expression => SortOrder(expr, Ascending) + }) + withTypedPlan { + RepartitionByExpression(sortOrder, logicalPlan, numPartitions) + } + } + + /** + * Returns a new Dataset partitioned by the given partitioning expressions, using + * `spark.sql.shuffle.partitions` as number of partitions. + * The resulting Dataset is range partitioned. + * + * At least one partition-by expression must be specified. + * When no explicit sort order is specified, "ascending nulls first" is assumed. + * + * @group typedrel + * @since 2.3.0 + */ + @scala.annotation.varargs + def repartitionByRange(partitionExprs: Column*): Dataset[T] = { + repartitionByRange(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/f5f8e84d/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 1fe3cb1..9e713cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -482,9 +482,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil case r: logical.Range => execution.RangeExec(r) :: Nil - case logical.RepartitionByExpression(expressions, child, numPartitions) => - exchange.ShuffleExchangeExec(HashPartitioning( - 
expressions, numPartitions), planLater(child)) :: Nil + case r: logical.RepartitionByExpression => + exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child)) :: Nil case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil case r: LogicalRDD => RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil http://git-wip-us.apache.org/repos/asf/spark/blob/f5f8e84d/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 72a5cc9..5e4c1a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -358,6 +358,63 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData.select('key).collect().toSeq) } + test("repartition with SortOrder") { + // passing SortOrder expressions to .repartition() should result in an informative error + + def checkSortOrderErrorMsg[T](data: => Dataset[T]): Unit = { + val ex = intercept[IllegalArgumentException](data) + assert(ex.getMessage.contains("repartitionByRange")) + } + + checkSortOrderErrorMsg { + Seq(0).toDF("a").repartition(2, $"a".asc) + } + + checkSortOrderErrorMsg { + Seq((0, 0)).toDF("a", "b").repartition(2, $"a".asc, $"b") + } + } + + test("repartitionByRange") { + val data1d = Random.shuffle(0.to(9)) + val data2d = data1d.map(i => (i, data1d.size - i)) + + checkAnswer( + data1d.toDF("val").repartitionByRange(data1d.size, $"val".asc) + .select(spark_partition_id().as("id"), $"val"), + data1d.map(i => Row(i, i))) + + checkAnswer( + data1d.toDF("val").repartitionByRange(data1d.size, $"val".desc) + .select(spark_partition_id().as("id"), $"val"), + data1d.map(i => Row(i, data1d.size - 1 - i))) + + checkAnswer( + 
data1d.toDF("val").repartitionByRange(data1d.size, lit(42)) + .select(spark_partition_id().as("id"), $"val"), + data1d.map(i => Row(0, i))) + + checkAnswer( + data1d.toDF("val").repartitionByRange(data1d.size, lit(null), $"val".asc, rand()) + .select(spark_partition_id().as("id"), $"val"), + data1d.map(i => Row(i, i))) + + // .repartitionByRange() assumes .asc by default if no explicit sort order is specified + checkAnswer( + data2d.toDF("a", "b").repartitionByRange(data2d.size, $"a".desc, $"b") + .select(spark_partition_id().as("id"), $"a", $"b"), + data2d.toDF("a", "b").repartitionByRange(data2d.size, $"a".desc, $"b".asc) + .select(spark_partition_id().as("id"), $"a", $"b")) + + // at least one partition-by expression must be specified + intercept[IllegalArgumentException] { + data1d.toDF("val").repartitionByRange(data1d.size) + } + intercept[IllegalArgumentException] { + data1d.toDF("val").repartitionByRange(data1d.size, Seq.empty: _*) + } + } + test("coalesce") { intercept[IllegalArgumentException] { testData.select('key).coalesce(0) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org