This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6f8c620 [SPARK-35558] Optimizes for multi-quantile retrieval
6f8c620 is described below
commit 6f8c62047cea125d52af5dad7fb5ad3eadb7f7d0
Author: Alkis Polyzotis <[email protected]>
AuthorDate: Sat Jun 5 14:25:33 2021 -0500
[SPARK-35558] Optimizes for multi-quantile retrieval
### What changes were proposed in this pull request?
Optimizes the retrieval of approximate quantiles for an array of
percentiles.
* Adds an overload for QuantileSummaries.query that accepts an array of
percentiles and optimizes the computation to do a single pass over the sketch
and avoid redundant computation.
* Modifies the ApproximatePercentile operator to call into the new method.
All formatting changes are the result of running `./dev/scalafmt`.
### Why are the changes needed?
The existing implementation makes a separate query call for each input
percentile, resulting in redundant computation.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added unit tests for the new method.
Closes #32700 from alkispoly-db/spark_35558_approx_quants_array.
Authored-by: Alkis Polyzotis <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
---
.../aggregate/ApproximatePercentile.scala | 11 +--
.../sql/catalyst/util/QuantileSummaries.scala | 107 +++++++++++++++------
.../sql/catalyst/util/QuantileSummariesSuite.scala | 79 +++++++++++----
.../spark/sql/execution/stat/StatFunctions.scala | 7 +-
.../org/apache/spark/sql/DataFrameStatSuite.scala | 2 +-
5 files changed, 149 insertions(+), 57 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
index 38d8d7d..78e64bf 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -261,19 +261,12 @@ object ApproximatePercentile {
* val Array(p25, median, p75) =
percentileDigest.getPercentiles(Array(0.25, 0.5, 0.75))
* }}}
*/
- def getPercentiles(percentages: Array[Double]): Array[Double] = {
+ def getPercentiles(percentages: Array[Double]): Seq[Double] = {
if (!isCompressed) compress()
if (summaries.count == 0 || percentages.length == 0) {
Array.emptyDoubleArray
} else {
- val result = new Array[Double](percentages.length)
- var i = 0
- while (i < percentages.length) {
- // Since summaries.count != 0, the query here never return None.
- result(i) = summaries.query(percentages(i)).get
- i += 1
- }
- result
+ summaries.query(percentages).get
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala
index addf140..e0cd613 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala
@@ -229,46 +229,99 @@ class QuantileSummaries(
}
/**
- * Runs a query for a given quantile.
+ * Finds the approximate quantile for a percentile, starting at a specific
index in the summary.
+ * This is a helper method that is called as we are making a pass over the
summary and a sorted
+ * sequence of input percentiles.
+ *
+ * @param index The point at which to start scanning the summary for an
approximate value.
+ * @param minRankAtIndex The accumulated minimum rank at the given index.
+ * @param targetError Target error from the summary.
+ * @param percentile The percentile whose value is computed.
+ * @return A tuple (i, r, a) where: i is the updated index for the next
call, r is the updated
+ * rank at i, and a is the approximate quantile.
+ */
+ private def findApproxQuantile(
+ index: Int,
+ minRankAtIndex: Long,
+ targetError: Double,
+ percentile: Double): (Int, Long, Double) = {
+ var curSample = sampled(index)
+ val rank = math.ceil(percentile * count).toLong
+ var i = index
+ var minRank = minRankAtIndex
+ while (i < sampled.length - 1) {
+ val maxRank = minRank + curSample.delta
+ if (maxRank - targetError <= rank && rank <= minRank + targetError) {
+ return (i, minRank, curSample.value)
+ } else {
+ i += 1
+ curSample = sampled(i)
+ minRank += curSample.g
+ }
+ }
+ (sampled.length - 1, 0, sampled.last.value)
+ }
+
+ /**
+ * Runs a query for a given sequence of percentiles.
* The result follows the approximation guarantees detailed above.
* The query can only be run on a compressed summary: you need to call
compress() before using
* it.
*
- * @param quantile the target quantile
- * @return
+ * @param percentiles the target percentiles
+ * @return the corresponding approximate quantiles, in the same order as the
input
*/
- def query(quantile: Double): Option[Double] = {
- require(quantile >= 0 && quantile <= 1.0, "quantile should be in the range
[0.0, 1.0]")
- require(headSampled.isEmpty,
+ def query(percentiles: Seq[Double]): Option[Seq[Double]] = {
+ percentiles.foreach(p =>
+ require(p >= 0 && p <= 1.0, "percentile should be in the range [0.0,
1.0]"))
+ require(
+ headSampled.isEmpty,
"Cannot operate on an uncompressed summary, call compress() first")
if (sampled.isEmpty) return None
- if (quantile <= relativeError) {
- return Some(sampled.head.value)
- }
+ val targetError = sampled.foldLeft(Long.MinValue)((currentMax, stats) =>
+ currentMax.max(stats.delta + stats.g)) / 2
- if (quantile >= 1 - relativeError) {
- return Some(sampled.last.value)
- }
-
- // Target rank
- val rank = math.ceil(quantile * count).toLong
- val targetError = sampled.map(s => s.delta + s.g).max / 2
+ // Index to track the current sample
+ var index = 0
// Minimum rank at current sample
- var minRank = 0L
- var i = 0
- while (i < sampled.length - 1) {
- val curSample = sampled(i)
- minRank += curSample.g
- val maxRank = minRank + curSample.delta
- if (maxRank - targetError <= rank && rank <= minRank + targetError) {
- return Some(curSample.value)
- }
- i += 1
+ var minRank = sampled(0).g
+
+ val sortedPercentiles = percentiles.zipWithIndex.sortBy(_._1)
+ val result = Array.fill(percentiles.length)(0.0)
+ sortedPercentiles.foreach {
+ case (percentile, pos) =>
+ if (percentile <= relativeError) {
+ result(pos) = sampled.head.value
+ } else if (percentile >= 1 - relativeError) {
+ result(pos) = sampled.last.value
+ } else {
+ val (newIndex, newMinRank, approxQuantile) =
+ findApproxQuantile(index, minRank, targetError, percentile)
+ index = newIndex
+ minRank = newMinRank
+ result(pos) = approxQuantile
+ }
}
- Some(sampled.last.value)
+ Some(result)
}
+
+ /**
+ * Runs a query for a given percentile.
+ * The result follows the approximation guarantees detailed above.
+ * The query can only be run on a compressed summary: you need to call
compress() before using
+ * it.
+ *
+ * @param percentile the target percentile
+ * @return the corresponding approximate quantile
+ */
+ def query(percentile: Double): Option[Double] =
+ query(Seq(percentile)) match {
+ case Some(approxSeq) if approxSeq.nonEmpty => Some(approxSeq.head)
+ case _ => None
+ }
+
}
object QuantileSummaries {
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala
index e53d0bb..018db3aed 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.util
+import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import org.apache.spark.SparkFunSuite
@@ -54,25 +55,51 @@ class QuantileSummariesSuite extends SparkFunSuite {
summary
}
- private def checkQuantile(quant: Double, data: Seq[Double], summary:
QuantileSummaries): Unit = {
+ private def validateQuantileApproximation(
+ approx: Double,
+ percentile: Double,
+ data: Seq[Double],
+ summary: QuantileSummaries): Unit = {
+ assert(data.nonEmpty)
+
+ val rankOfValue = data.count(_ <= approx)
+ val rankOfPreValue = data.count(_ < approx)
+ // `rankOfValue` is the last position of the quantile value. If the input
repeats the value
+ // chosen as the quantile, e.g. in (1,2,2,2,2,2,3), the 50% quantile is 2,
then it's
+ // improper to choose the last position as its rank. Instead, we get the
rank by averaging
+ // `rankOfValue` and `rankOfPreValue`.
+ val rank = math.ceil((rankOfValue + rankOfPreValue) / 2.0)
+ val lower = math.floor((percentile - summary.relativeError) * data.size)
+ val upper = math.ceil((percentile + summary.relativeError) * data.size)
+ val msg =
+ s"$rank not in [$lower $upper], requested percentile: $percentile,
approx returned: $approx"
+ assert(rank >= lower, msg)
+ assert(rank <= upper, msg)
+ }
+
+ private def checkQuantile(
+ percentile: Double,
+ data: Seq[Double],
+ summary: QuantileSummaries): Unit = {
if (data.nonEmpty) {
- val approx = summary.query(quant).get
- // Get the rank of the approximation.
- val rankOfValue = data.count(_ <= approx)
- val rankOfPreValue = data.count(_ < approx)
- // `rankOfValue` is the last position of the quantile value. If the
input repeats the value
- // chosen as the quantile, e.g. in (1,2,2,2,2,2,3), the 50% quantile is
2, then it's
- // improper to choose the last position as its rank. Instead, we get the
rank by averaging
- // `rankOfValue` and `rankOfPreValue`.
- val rank = math.ceil((rankOfValue + rankOfPreValue) / 2.0)
- val lower = math.floor((quant - summary.relativeError) * data.size)
- val upper = math.ceil((quant + summary.relativeError) * data.size)
- val msg =
- s"$rank not in [$lower $upper], requested quantile: $quant, approx
returned: $approx"
- assert(rank >= lower, msg)
- assert(rank <= upper, msg)
+ val approx = summary.query(percentile).get
+ validateQuantileApproximation(approx, percentile, data, summary)
+ } else {
+ assert(summary.query(percentile).isEmpty)
+ }
+ }
+
+ private def checkQuantiles(
+ percentiles: Seq[Double],
+ data: Seq[Double],
+ summary: QuantileSummaries): Unit = {
+ if (data.nonEmpty) {
+ val approx = summary.query(percentiles).get
+ for ((q, a) <- percentiles zip approx) {
+ validateQuantileApproximation(a, q, data, summary)
+ }
} else {
- assert(summary.query(quant).isEmpty)
+ assert(summary.query(percentiles).isEmpty)
}
}
@@ -98,6 +125,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
checkQuantile(0.5, data, s)
checkQuantile(0.1, data, s)
checkQuantile(0.001, data, s)
+ checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
+ checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
}
test(s"Some quantile values with epsi=$epsi and seq=$seq_name,
compression=$compression " +
@@ -109,6 +138,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
checkQuantile(0.5, data, s)
checkQuantile(0.1, data, s)
checkQuantile(0.001, data, s)
+ checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
+ checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
}
test(s"Tests on empty data with epsi=$epsi and seq=$seq_name,
compression=$compression") {
@@ -121,6 +152,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
checkQuantile(0.5, emptyData, s)
checkQuantile(0.1, emptyData, s)
checkQuantile(0.001, emptyData, s)
+ checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), emptyData, s)
+ checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), emptyData, s)
}
}
@@ -149,6 +182,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
checkQuantile(0.5, data, s)
checkQuantile(0.1, data, s)
checkQuantile(0.001, data, s)
+ checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
+ checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
}
val (data11, data12) = {
@@ -168,6 +203,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
checkQuantile(0.5, data, s)
checkQuantile(0.1, data, s)
checkQuantile(0.001, data, s)
+ checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
+ checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
}
// length of data21 is 4 * length of data22
@@ -181,10 +218,14 @@ class QuantileSummariesSuite extends SparkFunSuite {
val s2 = buildSummary(data22, epsi, compression)
val s = s1.merge(s2)
// Check all quantiles
+ val percentiles = ArrayBuffer[Double]()
for (queryRank <- 1 to n) {
- val queryQuantile = queryRank.toDouble / n.toDouble
- checkQuantile(queryQuantile, data, s)
+ val percentile = queryRank.toDouble / n.toDouble
+ checkQuantile(percentile, data, s)
+ percentiles += percentile
}
+ checkQuantiles(percentiles.toSeq, data, s)
+ checkQuantiles(percentiles.reverse.toSeq, data, s)
}
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 0a9954e6..5dc0ff0 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -102,7 +102,12 @@ object StatFunctions extends Logging {
}
val summaries = df.select(columns:
_*).rdd.treeAggregate(emptySummaries)(apply, merge)
- summaries.map { summary => probabilities.flatMap(summary.query) }
+ summaries.map {
+ summary => summary.query(probabilities) match {
+ case Some(q) => q
+ case None => Seq()
+ }
+ }
}
/** Calculate the Pearson Correlation Coefficient for the given columns */
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index cdd2568..79ab3cd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -204,7 +204,7 @@ class DataFrameStatSuite extends QueryTest with
SharedSparkSession {
val e = intercept[IllegalArgumentException] {
df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2, -0.1),
epsilons.head)
}
- assert(e.getMessage.contains("quantile should be in the range [0.0, 1.0]"))
+ assert(e.getMessage.contains("percentile should be in the range [0.0,
1.0]"))
// relativeError should be non-negative
val e2 = intercept[IllegalArgumentException] {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]