This is an automated email from the ASF dual-hosted git repository.

joshrosen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 1178bcecc83 [SPARK-40211][CORE][SQL] Allow customize initial partitions number in take() behavior
1178bcecc83 is described below

commit 1178bcecc83925674cad4364537a39eba03e423e
Author: Ziqi Liu <ziqi....@databricks.com>
AuthorDate: Fri Aug 26 17:18:09 2022 -0700

    [SPARK-40211][CORE][SQL] Allow customize initial partitions number in take() behavior

    ### What changes were proposed in this pull request?

    [SPARK-40211](https://issues.apache.org/jira/browse/SPARK-40211) adds an `initialNumPartitions` config parameter that allows customizing the initial number of partitions to try in `take()`.

    ### Why are the changes needed?

    Currently, the initial number of partitions to try is hardcoded to `1`, which might cause unnecessary overhead. By setting this new configuration to a high value, we can effectively mitigate the "run multiple jobs" overhead in take behavior. We can also set it to a higher-than-1-but-still-small value (say, 10) to achieve a middle-ground trade-off.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Unit test.

    Closes #37661 from liuzqt/SPARK-40211.

    Authored-by: Ziqi Liu <ziqi....@databricks.com>
    Signed-off-by: Josh Rosen <joshro...@databricks.com>
---
 .../org/apache/spark/internal/config/package.scala |  7 ++++
 .../org/apache/spark/rdd/AsyncRDDActions.scala     | 17 +++++----
 core/src/main/scala/org/apache/spark/rdd/RDD.scala |  8 ++---
 .../test/scala/org/apache/spark/rdd/RDDSuite.scala | 39 +++++++++++++++++++-
 .../org/apache/spark/sql/internal/SQLConf.scala    | 12 +++++++
 .../org/apache/spark/sql/execution/SparkPlan.scala | 12 +++----
 .../org/apache/spark/sql/ConfigBehaviorSuite.scala | 41 ++++++++++++++++++++++
 7 files changed, 118 insertions(+), 18 deletions(-)
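For context, a minimal usage sketch of the two new knobs, assuming Spark 3.4.0 or later (the version in which this commit introduces them); the master URL, config values, and dataset sizes below are illustrative only and are not taken from the patch:

    import org.apache.spark.sql.SparkSession

    // Hedged sketch: raise the initial partition count so that take()/tail()
    // can often finish in a single job instead of starting from one partition.
    val spark = SparkSession.builder()
      .master("local[4]")                                    // illustrative deployment
      .config("spark.rdd.limit.initialNumPartitions", "10")  // RDD take()/takeAsync()
      .config("spark.sql.limit.initialNumPartitions", "10")  // Dataset take()/tail()
      .getOrCreate()

    // With the default of 1, a take() that needs rows from many partitions may
    // launch several jobs, growing the scanned range by the scale-up factor each time.
    val sqlRows = spark.range(0, 1000, 1, 100).take(50)
    val rddRows = spark.sparkContext.parallelize(0 until 1000, 100).take(50)

Starting from a larger value trades possibly wasted partition scans for fewer job launches, which is the middle-ground trade-off described above. The patch itself follows below.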
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 9d1a56843ca..07d3d3e0778 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1956,6 +1956,13 @@ package object config {
       .intConf
       .createWithDefault(10)
 
+  private[spark] val RDD_LIMIT_INITIAL_NUM_PARTITIONS =
+    ConfigBuilder("spark.rdd.limit.initialNumPartitions")
+      .version("3.4.0")
+      .intConf
+      .checkValue(_ > 0, "value should be positive")
+      .createWithDefault(1)
+
   private[spark] val RDD_LIMIT_SCALE_UP_FACTOR =
     ConfigBuilder("spark.rdd.limit.scaleUpFactor")
       .version("2.1.0")
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index d6379156ccf..9f89c82db31 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -25,6 +25,7 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.{ComplexFutureAction, FutureAction, JobSubmitter}
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.{RDD_LIMIT_INITIAL_NUM_PARTITIONS, RDD_LIMIT_SCALE_UP_FACTOR}
 import org.apache.spark.util.ThreadUtils
 
 /**
@@ -72,6 +73,8 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
     val results = new ArrayBuffer[T]
     val totalParts = self.partitions.length
 
+    val scaleUpFactor = Math.max(self.conf.get(RDD_LIMIT_SCALE_UP_FACTOR), 2)
+
     /*
       Recursively triggers jobs to scan partitions until either the requested
       number of elements are retrieved, or the partitions to scan are exhausted.
@@ -84,18 +87,18 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
         } else {
           // The number of partitions to try in this iteration. It is ok for this number to be
           // greater than totalParts because we actually cap it at totalParts in runJob.
-          var numPartsToTry = 1L
+          var numPartsToTry = self.conf.get(RDD_LIMIT_INITIAL_NUM_PARTITIONS)
           if (partsScanned > 0) {
-            // If we didn't find any rows after the previous iteration, quadruple and retry.
-            // Otherwise, interpolate the number of partitions we need to try, but overestimate it
-            // by 50%. We also cap the estimation in the end.
-            if (results.size == 0) {
-              numPartsToTry = partsScanned * 4L
+            // If we didn't find any rows after the previous iteration, multiply by
+            // limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need
+            // to try, but overestimate it by 50%. We also cap the estimation in the end.
+            if (results.isEmpty) {
+              numPartsToTry = partsScanned * scaleUpFactor
             } else {
               // the left side of max is >=1 whenever partsScanned >= 2
               numPartsToTry = Math.max(1, (1.5 * num * partsScanned / results.size).toInt - partsScanned)
-              numPartsToTry = Math.min(numPartsToTry, partsScanned * 4L)
+              numPartsToTry = Math.min(numPartsToTry, partsScanned * scaleUpFactor)
             }
           }
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index b7284d25122..ab175595c19 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1445,12 +1445,12 @@ abstract class RDD[T: ClassTag](
     while (buf.size < num && partsScanned < totalParts) {
       // The number of partitions to try in this iteration. It is ok for this number to be
       // greater than totalParts because we actually cap it at totalParts in runJob.
-      var numPartsToTry = 1L
+      var numPartsToTry = conf.get(RDD_LIMIT_INITIAL_NUM_PARTITIONS)
       val left = num - buf.size
       if (partsScanned > 0) {
-        // If we didn't find any rows after the previous iteration, quadruple and retry.
-        // Otherwise, interpolate the number of partitions we need to try, but overestimate
-        // it by 50%. We also cap the estimation in the end.
+        // If we didn't find any rows after the previous iteration, multiply by
+        // limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need
+        // to try, but overestimate it by 50%. We also cap the estimation in the end.
         if (buf.isEmpty) {
           numPartsToTry = partsScanned * scaleUpFactor
         } else {
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index ccef00c8e9d..c64573f7a0a 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.rdd
 
 import java.io.{File, IOException, ObjectInputStream, ObjectOutputStream}
 import java.lang.management.ManagementFactory
+import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, HashMap}
@@ -32,8 +33,9 @@ import org.scalatest.concurrent.Eventually
 
 import org.apache.spark._
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
-import org.apache.spark.internal.config.RDD_PARALLEL_LISTING_THRESHOLD
+import org.apache.spark.internal.config.{RDD_LIMIT_INITIAL_NUM_PARTITIONS, RDD_PARALLEL_LISTING_THRESHOLD}
 import org.apache.spark.rdd.RDDSuiteUtils._
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.util.{ThreadUtils, Utils}
 
 class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually {
@@ -1255,6 +1257,41 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually {
     assert(numPartsPerLocation(locations(1)) > 0.4 * numCoalescedPartitions)
   }
 
+  test("SPARK-40211: customize initialNumPartitions for take") {
+    val totalElements = 100
+    val numToTake = 50
+    val rdd = sc.parallelize(0 to totalElements, totalElements)
+    import scala.language.reflectiveCalls
+    val jobCountListener = new SparkListener {
+      private var count: AtomicInteger = new AtomicInteger(0)
+      def getCount: Int = count.get
+      def reset(): Unit = count.set(0)
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        count.incrementAndGet()
+      }
+    }
+    sc.addSparkListener(jobCountListener)
+    // with default RDD_LIMIT_INITIAL_NUM_PARTITIONS = 1, expecting multiple jobs
+    rdd.take(numToTake)
+    sc.listenerBus.waitUntilEmpty()
+    assert(jobCountListener.getCount > 1)
+    jobCountListener.reset()
+    rdd.takeAsync(numToTake).get()
+    sc.listenerBus.waitUntilEmpty()
+    assert(jobCountListener.getCount > 1)
+
+    // setting RDD_LIMIT_INITIAL_NUM_PARTITIONS to large number(1000), expecting only 1 job
+    sc.conf.set(RDD_LIMIT_INITIAL_NUM_PARTITIONS, 1000)
+    jobCountListener.reset()
+    rdd.take(numToTake)
+    sc.listenerBus.waitUntilEmpty()
+    assert(jobCountListener.getCount == 1)
+    jobCountListener.reset()
+    rdd.takeAsync(numToTake).get()
+    sc.listenerBus.waitUntilEmpty()
+    assert(jobCountListener.getCount == 1)
+  }
+
   // NOTE
   // Below tests calling sc.stop() have to be the last tests in this suite. If there are tests
   // running after them and if they access sc those tests will fail as sc is already closed, because
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index eb7a6a9105e..de25c19a26e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -526,6 +526,16 @@
     .checkValue(_ >= 1, "The shuffle hash join factor cannot be negative.")
     .createWithDefault(3)
 
+  val LIMIT_INITIAL_NUM_PARTITIONS = buildConf("spark.sql.limit.initialNumPartitions")
+    .internal()
+    .doc("Initial number of partitions to try when executing a take on a query. Higher values " +
+      "lead to more partitions read. Lower values might lead to longer execution times as more" +
+      "jobs will be run")
+    .version("3.4.0")
+    .intConf
+    .checkValue(_ > 0, "value should be positive")
+    .createWithDefault(1)
+
   val LIMIT_SCALE_UP_FACTOR = buildConf("spark.sql.limit.scaleUpFactor")
     .internal()
     .doc("Minimal increase rate in number of partitions between attempts when executing a take " +
@@ -4316,6 +4326,8 @@ class SQLConf extends Serializable with Logging {
 
   def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
 
+  def limitInitialNumPartitions: Int = getConf(LIMIT_INITIAL_NUM_PARTITIONS)
+
   def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR)
 
   def advancedPartitionPredicatePushdownEnabled: Boolean =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 50a309d443a..a56732fdc12 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -469,7 +469,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     if (n == 0) {
       return new Array[InternalRow](0)
     }
-
+    val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2)
+
     // TODO: refactor and reuse the code from RDD's take()
     val childRDD = getByteArrayRdd(n, takeFromEnd)
 
     val buf = if (takeFromEnd) new ListBuffer[InternalRow] else new ArrayBuffer[InternalRow]
@@ -478,12 +479,11 @@
     while (buf.length < n && partsScanned < totalParts) {
       // The number of partitions to try in this iteration. It is ok for this number to be
       // greater than totalParts because we actually cap it at totalParts in runJob.
-      var numPartsToTry = 1L
+      var numPartsToTry = conf.limitInitialNumPartitions
      if (partsScanned > 0) {
-        // If we didn't find any rows after the previous iteration, quadruple and retry.
-        // Otherwise, interpolate the number of partitions we need to try, but overestimate
-        // it by 50%. We also cap the estimation in the end.
-        val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2)
+        // If we didn't find any rows after the previous iteration, multiply by
+        // limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need
+        // to try, but overestimate it by 50%. We also cap the estimation in the end.
         if (buf.isEmpty) {
           numPartsToTry = partsScanned * limitScaleUpFactor
         } else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala
index 36989efbe87..9c442456ce8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala
@@ -17,8 +17,11 @@
 
 package org.apache.spark.sql
 
+import java.util.concurrent.atomic.AtomicInteger
+
 import org.apache.commons.math3.stat.inference.ChiSquareTest
 
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -68,4 +71,42 @@ class ConfigBehaviorSuite extends QueryTest with SharedSparkSession {
     }
   }
 
+  test("SPARK-40211: customize initialNumPartitions for take") {
+    val totalElements = 100
+    val numToTake = 50
+    import scala.language.reflectiveCalls
+    val jobCountListener = new SparkListener {
+      private var count: AtomicInteger = new AtomicInteger(0)
+      def getCount: Int = count.get
+      def reset(): Unit = count.set(0)
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        count.incrementAndGet()
+      }
+    }
+    spark.sparkContext.addSparkListener(jobCountListener)
+    val df = spark.range(0, totalElements, 1, totalElements)
+
+    // with default LIMIT_INITIAL_NUM_PARTITIONS = 1, expecting multiple jobs
+    df.take(numToTake)
+    spark.sparkContext.listenerBus.waitUntilEmpty()
+    assert(jobCountListener.getCount > 1)
+    jobCountListener.reset()
+    df.tail(numToTake)
+    spark.sparkContext.listenerBus.waitUntilEmpty()
+    assert(jobCountListener.getCount > 1)
+
+    // setting LIMIT_INITIAL_NUM_PARTITIONS to large number(1000), expecting only 1 job
+
+    withSQLConf(SQLConf.LIMIT_INITIAL_NUM_PARTITIONS.key -> "1000") {
+      jobCountListener.reset()
+      df.take(numToTake)
+      spark.sparkContext.listenerBus.waitUntilEmpty()
+      assert(jobCountListener.getCount == 1)
+      jobCountListener.reset()
+      df.tail(numToTake)
+      spark.sparkContext.listenerBus.waitUntilEmpty()
+      assert(jobCountListener.getCount == 1)
+    }
+  }
+
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org