This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new dad2b763f004 [SPARK-49419][CONNECT][SQL] Create shared DataFrameStatFunctions
dad2b763f004 is described below
commit dad2b763f004a72613276b31738a958e80d02b37
Author: Herman van Hovell <[email protected]>
AuthorDate: Thu Aug 29 11:00:51 2024 -0400
[SPARK-49419][CONNECT][SQL] Create shared DataFrameStatFunctions
### What changes were proposed in this pull request?
This PR creates an interface for DataFrameStatFunctions that is shared
between Classic and Connect.
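In rough terms, the new shared class lives in `sql/api` and both implementations extend it. A condensed sketch of the pattern (simplified from the diff below, not a verbatim copy):

```scala
package org.apache.spark.sql.api

import org.apache.spark.sql.Row
import org.apache.spark.util.ArrayImplicits.SparkArrayOps

abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] {
  // Each implementation (Classic, Connect) supplies its own DataFrame.
  protected def df: DS[Row]

  // Shared defaults funnel into a small set of abstract, engine-specific methods.
  def freqItems(cols: Array[String], support: Double): DS[Row] =
    freqItems(cols.toImmutableArraySeq, support)

  def freqItems(cols: Seq[String], support: Double): DS[Row]
}
```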
### Why are the changes needed?
This is part of the effort to create a Scala SQL interface in `sql/api` that is shared by the Classic and Connect implementations.
### Does this PR introduce _any_ user-facing change?
No
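To illustrate, existing call sites keep compiling and behaving the same on both Classic and Connect. A hedged sketch (the data and column names below are illustrative, not part of this patch; `spark` is assumed to be an active `SparkSession`):

```scala
// Stratified sampling through df.stat works unchanged after the refactor.
val people = spark.range(0, 100).selectExpr("id % 3 as key", "id as value")
val sample = people.stat.sampleBy("key", Map(0L -> 0.1, 1L -> 0.2), seed = 7L)
sample.show()
```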
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47898 from hvanhovell/SPARK-49419.
Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../apache/spark/sql/DataFrameStatFunctions.scala | 588 ++-------------------
.../main/scala/org/apache/spark/sql/Dataset.scala | 2 +-
sql/api/pom.xml | 5 +
.../spark/sql/api}/DataFrameStatFunctions.scala | 106 ++--
.../scala/org/apache/spark/sql/api/Dataset.scala | 12 +
.../apache/spark/sql/DataFrameStatFunctions.scala | 494 +----------------
6 files changed, 110 insertions(+), 1097 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 6365f387afce..9f5ada0d7ec3 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -18,89 +18,23 @@
package org.apache.spark.sql
import java.{lang => jl, util => ju}
-import java.io.ByteArrayInputStream
-
-import scala.jdk.CollectionConverters._
import org.apache.spark.connect.proto.{Relation, StatSampleBy}
import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, PrimitiveDoubleEncoder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, PrimitiveDoubleEncoder}
import org.apache.spark.sql.functions.lit
-import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}
/**
* Statistic functions for `DataFrame`s.
*
* @since 3.4.0
*/
-final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, root: Relation) {
- import sparkSession.RichColumn
-
- /**
- * Calculates the approximate quantiles of a numerical column of a DataFrame.
- *
- * The result of this algorithm has the following deterministic bound: If
the DataFrame has N
- * elements and if we request the quantile at probability `p` up to error
`err`, then the
- * algorithm will return a sample `x` from the DataFrame so that the *exact*
rank of `x` is
- * close to (p * N). More precisely,
- *
- * {{{
- * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N)
- * }}}
- *
- * This method implements a variation of the Greenwald-Khanna algorithm
(with some speed
- * optimizations). The algorithm was first present in <a
- * href="https://doi.org/10.1145/375663.375670"> Space-efficient Online
Computation of Quantile
- * Summaries</a> by Greenwald and Khanna.
- *
- * @param col
- * the name of the numerical column
- * @param probabilities
- * a list of quantile probabilities Each number must belong to [0, 1]. For
example 0 is the
- * minimum, 0.5 is the median, 1 is the maximum.
- * @param relativeError
- * The relative target precision to achieve (greater than or equal to 0).
If set to zero, the
- * exact quantiles are computed, which could be very expensive. Note that
values greater than
- * 1 are accepted but give the same result as 1.
- * @return
- * the approximate quantiles at the given probabilities
- *
- * @note
- * null and NaN values will be removed from the numerical column before
calculation. If the
- * dataframe is empty or the column only contains null or NaN, an empty
array is returned.
- *
- * @since 3.4.0
- */
- def approxQuantile(
- col: String,
- probabilities: Array[Double],
- relativeError: Double): Array[Double] = {
- approxQuantile(Array(col), probabilities, relativeError).head
- }
+final class DataFrameStatFunctions private[sql] (protected val df: DataFrame)
+ extends api.DataFrameStatFunctions[Dataset] {
+ private def root: Relation = df.plan.getRoot
+ private val sparkSession: SparkSession = df.sparkSession
- /**
- * Calculates the approximate quantiles of numerical columns of a DataFrame.
- * @see
- * `approxQuantile(col:Str* approxQuantile)` for detailed description.
- *
- * @param cols
- * the names of the numerical columns
- * @param probabilities
- * a list of quantile probabilities Each number must belong to [0, 1]. For
example 0 is the
- * minimum, 0.5 is the median, 1 is the maximum.
- * @param relativeError
- * The relative target precision to achieve (greater than or equal to 0).
If set to zero, the
- * exact quantiles are computed, which could be very expensive. Note that
values greater than
- * 1 are accepted but give the same result as 1.
- * @return
- * the approximate quantiles at the given probabilities of each column
- *
- * @note
- * null and NaN values will be ignored in numerical columns before
calculation. For columns
- * only containing null or NaN values, an empty array is returned.
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
def approxQuantile(
cols: Array[String],
probabilities: Array[Double],
@@ -120,24 +54,7 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo
.head()
}
- /**
- * Calculate the sample covariance of two numerical columns of a DataFrame.
- * @param col1
- * the name of the first column
- * @param col2
- * the name of the second column
- * @return
- * the covariance of the two columns.
- *
- * {{{
- * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1",
rand(seed=10))
- * .withColumn("rand2", rand(seed=27))
- * df.stat.cov("rand1", "rand2")
- * res1: Double = 0.065...
- * }}}
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
def cov(col1: String, col2: String): Double = {
sparkSession
.newDataset(PrimitiveDoubleEncoder) { builder =>
@@ -146,27 +63,7 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo
.head()
}
- /**
- * Calculates the correlation of two columns of a DataFrame. Currently only
supports the Pearson
- * Correlation Coefficient. For Spearman Correlation, consider using RDD
methods found in
- * MLlib's Statistics.
- *
- * @param col1
- * the name of the column
- * @param col2
- * the name of the column to calculate the correlation against
- * @return
- * The Pearson Correlation Coefficient as a Double.
- *
- * {{{
- * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1",
rand(seed=10))
- * .withColumn("rand2", rand(seed=27))
- * df.stat.corr("rand1", "rand2")
- * res1: Double = 0.613...
- * }}}
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
def corr(col1: String, col2: String, method: String): Double = {
require(
method == "pearson",
@@ -179,289 +76,48 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo
.head()
}
- /**
- * Calculates the Pearson Correlation Coefficient of two columns of a
DataFrame.
- *
- * @param col1
- * the name of the column
- * @param col2
- * the name of the column to calculate the correlation against
- * @return
- * The Pearson Correlation Coefficient as a Double.
- *
- * {{{
- * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1",
rand(seed=10))
- * .withColumn("rand2", rand(seed=27))
- * df.stat.corr("rand1", "rand2", "pearson")
- * res1: Double = 0.613...
- * }}}
- *
- * @since 3.4.0
- */
- def corr(col1: String, col2: String): Double = {
- corr(col1, col2, "pearson")
- }
-
- /**
- * Computes a pair-wise frequency table of the given columns. Also known as
a contingency table.
- * The first column of each row will be the distinct values of `col1` and
the column names will
- * be the distinct values of `col2`. The name of the first column will be
`col1_col2`. Counts
- * will be returned as `Long`s. Pairs that have no occurrences will have
zero as their counts.
- * Null elements will be replaced by "null", and back ticks will be dropped
from elements if
- * they exist.
- *
- * @param col1
- * The name of the first column. Distinct items will make the first item
of each row.
- * @param col2
- * The name of the second column. Distinct items will make the column
names of the DataFrame.
- * @return
- * A DataFrame containing for the contingency table.
- *
- * {{{
- * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2,
3), (3, 2), (3, 3)))
- * .toDF("key", "value")
- * val ct = df.stat.crosstab("key", "value")
- * ct.show()
- * +---------+---+---+---+
- * |key_value| 1| 2| 3|
- * +---------+---+---+---+
- * | 2| 2| 0| 1|
- * | 1| 1| 1| 0|
- * | 3| 0| 1| 1|
- * +---------+---+---+---+
- * }}}
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
def crosstab(col1: String, col2: String): DataFrame = {
sparkSession.newDataFrame { builder =>
builder.getCrosstabBuilder.setInput(root).setCol1(col1).setCol2(col2)
}
}
- /**
- * Finding frequent items for columns, possibly with false positives. Using
the frequent element
- * count algorithm described in <a
href="https://doi.org/10.1145/762471.762473">here</a>,
- * proposed by Karp, Schenker, and Papadimitriou. The `support` should be
greater than 1e-4.
- *
- * This function is meant for exploratory data analysis, as we make no
guarantee about the
- * backward compatibility of the schema of the resulting `DataFrame`.
- *
- * @param cols
- * the names of the columns to search frequent items in.
- * @param support
- * The minimum frequency for an item to be considered `frequent`. Should
be greater than 1e-4.
- * @return
- * A Local DataFrame with the Array of frequent items for each column.
- *
- * {{{
- * val rows = Seq.tabulate(100) { i =>
- * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0)
- * }
- * val df = spark.createDataFrame(rows).toDF("a", "b")
- * // find the items with a frequency greater than 0.4 (observed 40% of
the time) for columns
- * // "a" and "b"
- * val freqSingles = df.stat.freqItems(Array("a", "b"), 0.4)
- * freqSingles.show()
- * +-----------+-------------+
- * |a_freqItems| b_freqItems|
- * +-----------+-------------+
- * | [1, 99]|[-1.0, -99.0]|
- * +-----------+-------------+
- * // find the pair of items with a frequency greater than 0.1 in columns
"a" and "b"
- * val pairDf = df.select(struct("a", "b").as("a-b"))
- * val freqPairs = pairDf.stat.freqItems(Array("a-b"), 0.1)
- * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show()
- * +----------+
- * | freq_ab|
- * +----------+
- * | [1,-1.0]|
- * | ... |
- * +----------+
- * }}}
- *
- * @since 3.4.0
- */
- def freqItems(cols: Array[String], support: Double): DataFrame = {
- sparkSession.newDataFrame { builder =>
- val freqItemsBuilder = builder.getFreqItemsBuilder.setInput(root).setSupport(support)
- cols.foreach(freqItemsBuilder.addCols)
- }
- }
+ /** @inheritdoc */
+ override def freqItems(cols: Array[String], support: Double): DataFrame =
+ super.freqItems(cols, support)
- /**
- * Finding frequent items for columns, possibly with false positives. Using
the frequent element
- * count algorithm described in <a
href="https://doi.org/10.1145/762471.762473">here</a>,
- * proposed by Karp, Schenker, and Papadimitriou. Uses a `default` support
of 1%.
- *
- * This function is meant for exploratory data analysis, as we make no
guarantee about the
- * backward compatibility of the schema of the resulting `DataFrame`.
- *
- * @param cols
- * the names of the columns to search frequent items in.
- * @return
- * A Local DataFrame with the Array of frequent items for each column.
- *
- * @since 3.4.0
- */
- def freqItems(cols: Array[String]): DataFrame = {
- freqItems(cols, 0.01)
- }
+ /** @inheritdoc */
+ override def freqItems(cols: Array[String]): DataFrame = super.freqItems(cols)
+
+ /** @inheritdoc */
+ override def freqItems(cols: Seq[String]): DataFrame = super.freqItems(cols)
- /**
- * (Scala-specific) Finding frequent items for columns, possibly with false
positives. Using the
- * frequent element count algorithm described in <a
- * href="https://doi.org/10.1145/762471.762473">here</a>, proposed by Karp,
Schenker, and
- * Papadimitriou.
- *
- * This function is meant for exploratory data analysis, as we make no
guarantee about the
- * backward compatibility of the schema of the resulting `DataFrame`.
- *
- * @param cols
- * the names of the columns to search frequent items in.
- * @return
- * A Local DataFrame with the Array of frequent items for each column.
- *
- * {{{
- * val rows = Seq.tabulate(100) { i =>
- * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0)
- * }
- * val df = spark.createDataFrame(rows).toDF("a", "b")
- * // find the items with a frequency greater than 0.4 (observed 40% of
the time) for columns
- * // "a" and "b"
- * val freqSingles = df.stat.freqItems(Seq("a", "b"), 0.4)
- * freqSingles.show()
- * +-----------+-------------+
- * |a_freqItems| b_freqItems|
- * +-----------+-------------+
- * | [1, 99]|[-1.0, -99.0]|
- * +-----------+-------------+
- * // find the pair of items with a frequency greater than 0.1 in columns
"a" and "b"
- * val pairDf = df.select(struct("a", "b").as("a-b"))
- * val freqPairs = pairDf.stat.freqItems(Seq("a-b"), 0.1)
- * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show()
- * +----------+
- * | freq_ab|
- * +----------+
- * | [1,-1.0]|
- * | ... |
- * +----------+
- * }}}
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
def freqItems(cols: Seq[String], support: Double): DataFrame = {
- freqItems(cols.toArray, support)
+ df.sparkSession.newDataFrame { builder =>
+ val freqItemsBuilder = builder.getFreqItemsBuilder
+ .setInput(df.plan.getRoot)
+ .setSupport(support)
+ cols.foreach(freqItemsBuilder.addCols)
+ }
}
- /**
- * (Scala-specific) Finding frequent items for columns, possibly with false
positives. Using the
- * frequent element count algorithm described in <a
- * href="https://doi.org/10.1145/762471.762473">here</a>, proposed by Karp,
Schenker, and
- * Papadimitriou. Uses a `default` support of 1%.
- *
- * This function is meant for exploratory data analysis, as we make no
guarantee about the
- * backward compatibility of the schema of the resulting `DataFrame`.
- *
- * @param cols
- * the names of the columns to search frequent items in.
- * @return
- * A Local DataFrame with the Array of frequent items for each column.
- *
- * @since 3.4.0
- */
- def freqItems(cols: Seq[String]): DataFrame = {
- freqItems(cols.toArray, 0.01)
- }
+ /** @inheritdoc */
+ override def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame =
+ super.sampleBy(col, fractions, seed)
- /**
- * Returns a stratified sample without replacement based on the fraction
given on each stratum.
- * @param col
- * column that defines strata
- * @param fractions
- * sampling fraction for each stratum. If a stratum is not specified, we
treat its fraction as
- * zero.
- * @param seed
- * random seed
- * @tparam T
- * stratum type
- * @return
- * a new `DataFrame` that represents the stratified sample
- *
- * {{{
- * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2,
3), (3, 2),
- * (3, 3))).toDF("key", "value")
- * val fractions = Map(1 -> 1.0, 3 -> 0.5)
- * df.stat.sampleBy("key", fractions, 36L).show()
- * +---+-----+
- * |key|value|
- * +---+-----+
- * | 1| 1|
- * | 1| 2|
- * | 3| 2|
- * +---+-----+
- * }}}
- *
- * @since 3.4.0
- */
- def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = {
- sampleBy(Column(col), fractions, seed)
- }
+ /** @inheritdoc */
+ override def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame =
+ super.sampleBy(col, fractions, seed)
- /**
- * Returns a stratified sample without replacement based on the fraction
given on each stratum.
- * @param col
- * column that defines strata
- * @param fractions
- * sampling fraction for each stratum. If a stratum is not specified, we
treat its fraction as
- * zero.
- * @param seed
- * random seed
- * @tparam T
- * stratum type
- * @return
- * a new `DataFrame` that represents the stratified sample
- *
- * @since 3.4.0
- */
- def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
- sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
- }
+ /** @inheritdoc */
+ override def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame =
+ super.sampleBy(col, fractions, seed)
- /**
- * Returns a stratified sample without replacement based on the fraction
given on each stratum.
- * @param col
- * column that defines strata
- * @param fractions
- * sampling fraction for each stratum. If a stratum is not specified, we
treat its fraction as
- * zero.
- * @param seed
- * random seed
- * @tparam T
- * stratum type
- * @return
- * a new `DataFrame` that represents the stratified sample
- *
- * The stratified sample can be performed over multiple columns:
- * {{{
- * import org.apache.spark.sql.Row
- * import org.apache.spark.sql.functions.struct
- *
- * val df = spark.createDataFrame(Seq(("Bob", 17), ("Alice", 10),
("Nico", 8), ("Bob", 17),
- * ("Alice", 10))).toDF("name", "age")
- * val fractions = Map(Row("Alice", 10) -> 0.3, Row("Nico", 8) -> 1.0)
- * df.stat.sampleBy(struct($"name", $"age"), fractions, 36L).show()
- * +-----+---+
- * | name|age|
- * +-----+---+
- * | Nico| 8|
- * |Alice| 10|
- * +-----+---+
- * }}}
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = {
+ import sparkSession.RichColumn
require(
fractions.values.forall(p => p >= 0.0 && p <= 1.0),
s"Fractions must be in [0, 1], but got $fractions.")
@@ -479,180 +135,6 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo
}
}
}
-
- /**
- * (Java-specific) Returns a stratified sample without replacement based on
the fraction given
- * on each stratum.
- * @param col
- * column that defines strata
- * @param fractions
- * sampling fraction for each stratum. If a stratum is not specified, we
treat its fraction as
- * zero.
- * @param seed
- * random seed
- * @tparam T
- * stratum type
- * @return
- * a new `DataFrame` that represents the stratified sample
- *
- * @since 3.4.0
- */
- def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
- sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
- }
-
- /**
- * Builds a Count-min Sketch over a specified column.
- *
- * @param colName
- * name of the column over which the sketch is built
- * @param depth
- * depth of the sketch
- * @param width
- * width of the sketch
- * @param seed
- * random seed
- * @return
- * a `CountMinSketch` over column `colName`
- * @since 3.4.0
- */
- def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = {
- countMinSketch(Column(colName), depth, width, seed)
- }
-
- /**
- * Builds a Count-min Sketch over a specified column.
- *
- * @param colName
- * name of the column over which the sketch is built
- * @param eps
- * relative error of the sketch
- * @param confidence
- * confidence of the sketch
- * @param seed
- * random seed
- * @return
- * a `CountMinSketch` over column `colName`
- * @since 3.4.0
- */
- def countMinSketch(
- colName: String,
- eps: Double,
- confidence: Double,
- seed: Int): CountMinSketch = {
- countMinSketch(Column(colName), eps, confidence, seed)
- }
-
- /**
- * Builds a Count-min Sketch over a specified column.
- *
- * @param col
- * the column over which the sketch is built
- * @param depth
- * depth of the sketch
- * @param width
- * width of the sketch
- * @param seed
- * random seed
- * @return
- * a `CountMinSketch` over column `colName`
- * @since 3.4.0
- */
- def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = {
- countMinSketch(col, eps = 2.0 / width, confidence = 1 - 1 / Math.pow(2, depth), seed)
- }
-
- /**
- * Builds a Count-min Sketch over a specified column.
- *
- * @param col
- * the column over which the sketch is built
- * @param eps
- * relative error of the sketch
- * @param confidence
- * confidence of the sketch
- * @param seed
- * random seed
- * @return
- * a `CountMinSketch` over column `colName`
- * @since 3.4.0
- */
- def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = {
- val agg = Column.fn("count_min_sketch", col, lit(eps), lit(confidence), lit(seed))
- val ds = sparkSession.newDataset(BinaryEncoder) { builder =>
- builder.getProjectBuilder
- .setInput(root)
- .addExpressions(agg.expr)
- }
- CountMinSketch.readFrom(ds.head())
- }
-
- /**
- * Builds a Bloom filter over a specified column.
- *
- * @param colName
- * name of the column over which the filter is built
- * @param expectedNumItems
- * expected number of items which will be put into the filter.
- * @param fpp
- * expected false positive probability of the filter.
- * @since 3.5.0
- */
- def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = {
- bloomFilter(Column(colName), expectedNumItems, fpp)
- }
-
- /**
- * Builds a Bloom filter over a specified column.
- *
- * @param col
- * the column over which the filter is built
- * @param expectedNumItems
- * expected number of items which will be put into the filter.
- * @param fpp
- * expected false positive probability of the filter.
- * @since 3.5.0
- */
- def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = {
- val numBits = BloomFilter.optimalNumOfBits(expectedNumItems, fpp)
- bloomFilter(col, expectedNumItems, numBits)
- }
-
- /**
- * Builds a Bloom filter over a specified column.
- *
- * @param colName
- * name of the column over which the filter is built
- * @param expectedNumItems
- * expected number of items which will be put into the filter.
- * @param numBits
- * expected number of bits of the filter.
- * @since 3.5.0
- */
- def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = {
- bloomFilter(Column(colName), expectedNumItems, numBits)
- }
-
- /**
- * Builds a Bloom filter over a specified column.
- *
- * @param col
- * the column over which the filter is built
- * @param expectedNumItems
- * expected number of items which will be put into the filter.
- * @param numBits
- * expected number of bits of the filter.
- * @since 3.5.0
- */
- def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = {
- val agg = Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits))
- val ds = sparkSession.newDataset(BinaryEncoder) { builder =>
- builder.getProjectBuilder
- .setInput(root)
- .addExpressions(agg.expr)
- }
- BloomFilter.readFrom(new ByteArrayInputStream(ds.head()))
- }
}
private object DataFrameStatFunctions {
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index a4d1c804685f..37a182675b6c 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -300,7 +300,7 @@ class Dataset[T] private[sql] (
* @group untypedrel
* @since 3.4.0
*/
- def stat: DataFrameStatFunctions = new DataFrameStatFunctions(sparkSession, plan.getRoot)
+ def stat: DataFrameStatFunctions = new DataFrameStatFunctions(toDF())
private def buildJoin(right: Dataset[_])(f: proto.Join.Builder => Unit): DataFrame = {
checkSameSparkSession(right)
diff --git a/sql/api/pom.xml b/sql/api/pom.xml
index 9a63f73ab191..54cdc96fc40a 100644
--- a/sql/api/pom.xml
+++ b/sql/api/pom.xml
@@ -53,6 +53,11 @@
<artifactId>spark-unsafe_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sketch_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
<dependency>
<groupId>org.json4s</groupId>
<artifactId>json4s-jackson_${scala.binary.version}</artifactId>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala
similarity index 87%
copy from sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
copy to sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala
index 9346739cbbd9..c3ecc7b90d5b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala
@@ -14,19 +14,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
-package org.apache.spark.sql
-
-import java.{lang => jl, util => ju}
+package org.apache.spark.sql.api
import scala.jdk.CollectionConverters._
+import _root_.java.{lang => jl, util => ju}
+
import org.apache.spark.annotation.Stable
-import org.apache.spark.sql.Encoders.BINARY
+import org.apache.spark.sql.{Column, Row}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.BinaryEncoder
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
-import org.apache.spark.sql.execution.stat._
-import org.apache.spark.sql.functions.{col, count_min_sketch, lit}
-import org.apache.spark.util.ArrayImplicits._
+import org.apache.spark.sql.functions.{count_min_sketch, lit}
+import org.apache.spark.util.ArrayImplicits.SparkArrayOps
import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}
/**
@@ -35,7 +34,8 @@ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}
* @since 1.4.0
*/
@Stable
-final class DataFrameStatFunctions private[sql](df: DataFrame) {
+abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] {
+ protected def df: DS[Row]
/**
* Calculates the approximate quantiles of a numerical column of a DataFrame.
@@ -97,28 +97,11 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
def approxQuantile(
cols: Array[String],
probabilities: Array[Double],
- relativeError: Double): Array[Array[Double]] = withOrigin {
- StatFunctions.multipleApproxQuantiles(
- df.select(cols.map(col).toImmutableArraySeq: _*),
- cols.toImmutableArraySeq,
- probabilities.toImmutableArraySeq,
- relativeError).map(_.toArray).toArray
- }
-
-
- /**
- * Python-friendly version of [[approxQuantile()]]
- */
- private[spark] def approxQuantile(
- cols: List[String],
- probabilities: List[Double],
- relativeError: Double): java.util.List[java.util.List[Double]] = {
- approxQuantile(cols.toArray, probabilities.toArray, relativeError)
- .map(_.toList.asJava).toList.asJava
- }
+ relativeError: Double): Array[Array[Double]]
/**
* Calculate the sample covariance of two numerical columns of a DataFrame.
+ *
* @param col1 the name of the first column
* @param col2 the name of the second column
* @return the covariance of the two columns.
@@ -129,12 +112,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* df.stat.cov("rand1", "rand2")
* res1: Double = 0.065...
* }}}
- *
* @since 1.4.0
*/
- def cov(col1: String, col2: String): Double = withOrigin {
- StatFunctions.calculateCov(df, Seq(col1, col2))
- }
+ def cov(col1: String, col2: String): Double
/**
* Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
@@ -151,14 +131,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* df.stat.corr("rand1", "rand2")
* res1: Double = 0.613...
* }}}
- *
* @since 1.4.0
*/
- def corr(col1: String, col2: String, method: String): Double = withOrigin {
- require(method == "pearson", "Currently only the calculation of the
Pearson Correlation " +
- "coefficient is supported.")
- StatFunctions.pearsonCorrelation(df, Seq(col1, col2))
- }
+ def corr(col1: String, col2: String, method: String): Double
/**
* Calculates the Pearson Correlation Coefficient of two columns of a
DataFrame.
@@ -173,7 +148,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* df.stat.corr("rand1", "rand2", "pearson")
* res1: Double = 0.613...
* }}}
- *
* @since 1.4.0
*/
def corr(col1: String, col2: String): Double = {
@@ -210,9 +184,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
*
* @since 1.4.0
*/
- def crosstab(col1: String, col2: String): DataFrame = withOrigin {
- StatFunctions.crossTabulate(df, col1, col2)
- }
+ def crosstab(col1: String, col2: String): DS[Row]
/**
* Finding frequent items for columns, possibly with false positives. Using the
@@ -224,7 +196,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* This function is meant for exploratory data analysis, as we make no guarantee about the
* backward compatibility of the schema of the resulting `DataFrame`.
*
- * @param cols the names of the columns to search frequent items in.
+ * @param cols the names of the columns to search frequent items in.
* @param support The minimum frequency for an item to be considered `frequent`. Should be greater
* than 1e-4.
* @return A Local DataFrame with the Array of frequent items for each column.
@@ -254,12 +226,10 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* | ... |
* +----------+
* }}}
- *
* @since 1.4.0
*/
- def freqItems(cols: Array[String], support: Double): DataFrame = withOrigin {
- FrequentItems.singlePassFreqItems(df, cols.toImmutableArraySeq, support)
- }
+ def freqItems(cols: Array[String], support: Double): DS[Row] =
+ freqItems(cols.toImmutableArraySeq, support)
/**
* Finding frequent items for columns, possibly with false positives. Using the
@@ -273,12 +243,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
*
* @param cols the names of the columns to search frequent items in.
* @return A Local DataFrame with the Array of frequent items for each column.
- *
* @since 1.4.0
*/
- def freqItems(cols: Array[String]): DataFrame = withOrigin {
- FrequentItems.singlePassFreqItems(df, cols.toImmutableArraySeq, 0.01)
- }
+ def freqItems(cols: Array[String]): DS[Row] = freqItems(cols, 0.01)
/**
* (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the
@@ -320,9 +287,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
*
* @since 1.4.0
*/
- def freqItems(cols: Seq[String], support: Double): DataFrame = withOrigin {
- FrequentItems.singlePassFreqItems(df, cols, support)
- }
+ def freqItems(cols: Seq[String], support: Double): DS[Row]
/**
* (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the
@@ -336,12 +301,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
*
* @param cols the names of the columns to search frequent items in.
* @return A Local DataFrame with the Array of frequent items for each column.
- *
* @since 1.4.0
*/
- def freqItems(cols: Seq[String]): DataFrame = withOrigin {
- FrequentItems.singlePassFreqItems(df, cols, 0.01)
- }
+ def freqItems(cols: Seq[String]): DS[Row] = freqItems(cols, 0.01)
/**
* Returns a stratified sample without replacement based on the fraction given on each stratum.
@@ -368,7 +330,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
*
* @since 1.5.0
*/
- def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = {
+ def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DS[Row] = {
sampleBy(Column(col), fractions, seed)
}
@@ -383,7 +345,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
*
* @since 1.5.0
*/
- def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
+ def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DS[Row] = {
sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
}
@@ -415,30 +377,22 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
*
* @since 3.0.0
*/
- def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = withOrigin {
- require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
- s"Fractions must be in [0, 1], but got $fractions.")
- import org.apache.spark.sql.functions.{rand, udf}
- val r = rand(seed)
- val f = udf { (stratum: Any, x: Double) =>
- x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0)
- }
- df.filter(f(col, r))
- }
+ def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DS[Row]
+
/**
* (Java-specific) Returns a stratified sample without replacement based on the fraction given
* on each stratum.
- * @param col column that defines strata
+ *
+ * @param col column that defines strata
* @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
* its fraction as zero.
- * @param seed random seed
+ * @param seed random seed
* @tparam T stratum type
* @return a new `DataFrame` that represents the stratified sample
- *
* @since 3.0.0
*/
- def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
+ def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DS[Row] = {
sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
}
@@ -503,7 +457,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
confidence: Double,
seed: Int): CountMinSketch = withOrigin {
val cms = count_min_sketch(col, lit(eps), lit(confidence), lit(seed))
- val bytes = df.select(cms).as(BINARY).head()
+ val bytes: Array[Byte] = df.select(cms).as(BinaryEncoder).head()
CountMinSketch.readFrom(bytes)
}
@@ -554,7 +508,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
*/
def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = withOrigin {
val bf = Column.internalFn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits))
- val bytes = df.select(bf).as(BINARY).head()
+ val bytes: Array[Byte] = df.select(bf).as(BinaryEncoder).head()
BloomFilter.readFrom(bytes)
}
}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
index 5b4ebed12c17..16f15205cabe 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
@@ -538,6 +538,18 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
// scalastyle:off println
def show(numRows: Int, truncate: Int, vertical: Boolean): Unit
+ /**
+ * Returns a [[DataFrameStatFunctions]] for working statistic functions support.
+ * {{{
+ * // Finding frequent items in column with name 'a'.
+ * ds.stat.freqItems(Seq("a"))
+ * }}}
+ *
+ * @group untypedrel
+ * @since 1.6.0
+ */
+ def stat: DataFrameStatFunctions[DS]
+
/**
* Join with another `DataFrame`.
*
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 9346739cbbd9..a5ab237bb704 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -22,12 +22,10 @@ import java.{lang => jl, util => ju}
import scala.jdk.CollectionConverters._
import org.apache.spark.annotation.Stable
-import org.apache.spark.sql.Encoders.BINARY
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.execution.stat._
-import org.apache.spark.sql.functions.{col, count_min_sketch, lit}
+import org.apache.spark.sql.functions.col
import org.apache.spark.util.ArrayImplicits._
-import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}
/**
* Statistic functions for `DataFrame`s.
@@ -35,65 +33,10 @@ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}
* @since 1.4.0
*/
@Stable
-final class DataFrameStatFunctions private[sql](df: DataFrame) {
+final class DataFrameStatFunctions private[sql](protected val df: DataFrame)
+ extends api.DataFrameStatFunctions[Dataset] {
- /**
- * Calculates the approximate quantiles of a numerical column of a DataFrame.
- *
- * The result of this algorithm has the following deterministic bound:
- * If the DataFrame has N elements and if we request the quantile at
probability `p` up to error
- * `err`, then the algorithm will return a sample `x` from the DataFrame so
that the *exact* rank
- * of `x` is close to (p * N).
- * More precisely,
- *
- * {{{
- * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N)
- * }}}
- *
- * This method implements a variation of the Greenwald-Khanna algorithm
(with some speed
- * optimizations).
- * The algorithm was first present in <a
href="https://doi.org/10.1145/375663.375670">
- * Space-efficient Online Computation of Quantile Summaries</a> by Greenwald
and Khanna.
- *
- * @param col the name of the numerical column
- * @param probabilities a list of quantile probabilities
- * Each number must belong to [0, 1].
- * For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
- * @param relativeError The relative target precision to achieve (greater
than or equal to 0).
- * If set to zero, the exact quantiles are computed, which could be very
expensive.
- * Note that values greater than 1 are accepted but give the same result
as 1.
- * @return the approximate quantiles at the given probabilities
- *
- * @note null and NaN values will be removed from the numerical column
before calculation. If
- * the dataframe is empty or the column only contains null or NaN, an
empty array is returned.
- *
- * @since 2.0.0
- */
- def approxQuantile(
- col: String,
- probabilities: Array[Double],
- relativeError: Double): Array[Double] = withOrigin {
- approxQuantile(Array(col), probabilities, relativeError).head
- }
-
- /**
- * Calculates the approximate quantiles of numerical columns of a DataFrame.
- * @see `approxQuantile(col:Str* approxQuantile)` for detailed description.
- *
- * @param cols the names of the numerical columns
- * @param probabilities a list of quantile probabilities
- * Each number must belong to [0, 1].
- * For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
- * @param relativeError The relative target precision to achieve (greater
than or equal to 0).
- * If set to zero, the exact quantiles are computed, which could be very
expensive.
- * Note that values greater than 1 are accepted but give the same result
as 1.
- * @return the approximate quantiles at the given probabilities of each
column
- *
- * @note null and NaN values will be ignored in numerical columns before
calculation. For
- * columns only containing null or NaN values, an empty array is returned.
- *
- * @since 2.2.0
- */
+ /** @inheritdoc */
def approxQuantile(
cols: Array[String],
probabilities: Array[Double],
@@ -105,7 +48,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
relativeError).map(_.toArray).toArray
}
-
/**
* Python-friendly version of [[approxQuantile()]]
*/
@@ -117,304 +59,49 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
.map(_.toList.asJava).toList.asJava
}
- /**
- * Calculate the sample covariance of two numerical columns of a DataFrame.
- * @param col1 the name of the first column
- * @param col2 the name of the second column
- * @return the covariance of the two columns.
- *
- * {{{
- * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1",
rand(seed=10))
- * .withColumn("rand2", rand(seed=27))
- * df.stat.cov("rand1", "rand2")
- * res1: Double = 0.065...
- * }}}
- *
- * @since 1.4.0
- */
+ /** @inheritdoc */
def cov(col1: String, col2: String): Double = withOrigin {
StatFunctions.calculateCov(df, Seq(col1, col2))
}
- /**
- * Calculates the correlation of two columns of a DataFrame. Currently only
supports the Pearson
- * Correlation Coefficient. For Spearman Correlation, consider using RDD
methods found in
- * MLlib's Statistics.
- *
- * @param col1 the name of the column
- * @param col2 the name of the column to calculate the correlation against
- * @return The Pearson Correlation Coefficient as a Double.
- *
- * {{{
- * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1",
rand(seed=10))
- * .withColumn("rand2", rand(seed=27))
- * df.stat.corr("rand1", "rand2")
- * res1: Double = 0.613...
- * }}}
- *
- * @since 1.4.0
- */
+ /** @inheritdoc */
def corr(col1: String, col2: String, method: String): Double = withOrigin {
require(method == "pearson", "Currently only the calculation of the
Pearson Correlation " +
"coefficient is supported.")
StatFunctions.pearsonCorrelation(df, Seq(col1, col2))
}
- /**
- * Calculates the Pearson Correlation Coefficient of two columns of a
DataFrame.
- *
- * @param col1 the name of the column
- * @param col2 the name of the column to calculate the correlation against
- * @return The Pearson Correlation Coefficient as a Double.
- *
- * {{{
- * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1",
rand(seed=10))
- * .withColumn("rand2", rand(seed=27))
- * df.stat.corr("rand1", "rand2", "pearson")
- * res1: Double = 0.613...
- * }}}
- *
- * @since 1.4.0
- */
- def corr(col1: String, col2: String): Double = {
- corr(col1, col2, "pearson")
- }
-
- /**
- * Computes a pair-wise frequency table of the given columns. Also known as
a contingency table.
- * The first column of each row will be the distinct values of `col1` and
the column names will
- * be the distinct values of `col2`. The name of the first column will be
`col1_col2`. Counts
- * will be returned as `Long`s. Pairs that have no occurrences will have
zero as their counts.
- * Null elements will be replaced by "null", and back ticks will be dropped
from elements if they
- * exist.
- *
- * @param col1 The name of the first column. Distinct items will make the
first item of
- * each row.
- * @param col2 The name of the second column. Distinct items will make the
column names
- * of the DataFrame.
- * @return A DataFrame containing for the contingency table.
- *
- * {{{
- * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2,
3), (3, 2), (3, 3)))
- * .toDF("key", "value")
- * val ct = df.stat.crosstab("key", "value")
- * ct.show()
- * +---------+---+---+---+
- * |key_value| 1| 2| 3|
- * +---------+---+---+---+
- * | 2| 2| 0| 1|
- * | 1| 1| 1| 0|
- * | 3| 0| 1| 1|
- * +---------+---+---+---+
- * }}}
- *
- * @since 1.4.0
- */
+ /** @inheritdoc */
def crosstab(col1: String, col2: String): DataFrame = withOrigin {
StatFunctions.crossTabulate(df, col1, col2)
}
- /**
- * Finding frequent items for columns, possibly with false positives. Using
the
- * frequent element count algorithm described in
- * <a href="https://doi.org/10.1145/762471.762473">here</a>, proposed by
Karp,
- * Schenker, and Papadimitriou.
- * The `support` should be greater than 1e-4.
- *
- * This function is meant for exploratory data analysis, as we make no
guarantee about the
- * backward compatibility of the schema of the resulting `DataFrame`.
- *
- * @param cols the names of the columns to search frequent items in.
- * @param support The minimum frequency for an item to be considered
`frequent`. Should be greater
- * than 1e-4.
- * @return A Local DataFrame with the Array of frequent items for each
column.
- *
- * {{{
- * val rows = Seq.tabulate(100) { i =>
- * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0)
- * }
- * val df = spark.createDataFrame(rows).toDF("a", "b")
- * // find the items with a frequency greater than 0.4 (observed 40% of
the time) for columns
- * // "a" and "b"
- * val freqSingles = df.stat.freqItems(Array("a", "b"), 0.4)
- * freqSingles.show()
- * +-----------+-------------+
- * |a_freqItems| b_freqItems|
- * +-----------+-------------+
- * | [1, 99]|[-1.0, -99.0]|
- * +-----------+-------------+
- * // find the pair of items with a frequency greater than 0.1 in columns
"a" and "b"
- * val pairDf = df.select(struct("a", "b").as("a-b"))
- * val freqPairs = pairDf.stat.freqItems(Array("a-b"), 0.1)
- * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show()
- * +----------+
- * | freq_ab|
- * +----------+
- * | [1,-1.0]|
- * | ... |
- * +----------+
- * }}}
- *
- * @since 1.4.0
- */
- def freqItems(cols: Array[String], support: Double): DataFrame = withOrigin {
- FrequentItems.singlePassFreqItems(df, cols.toImmutableArraySeq, support)
- }
-
- /**
- * Finding frequent items for columns, possibly with false positives. Using
the
- * frequent element count algorithm described in
- * <a href="https://doi.org/10.1145/762471.762473">here</a>, proposed by
Karp,
- * Schenker, and Papadimitriou.
- * Uses a `default` support of 1%.
- *
- * This function is meant for exploratory data analysis, as we make no
guarantee about the
- * backward compatibility of the schema of the resulting `DataFrame`.
- *
- * @param cols the names of the columns to search frequent items in.
- * @return A Local DataFrame with the Array of frequent items for each
column.
- *
- * @since 1.4.0
- */
- def freqItems(cols: Array[String]): DataFrame = withOrigin {
- FrequentItems.singlePassFreqItems(df, cols.toImmutableArraySeq, 0.01)
- }
-
- /**
- * (Scala-specific) Finding frequent items for columns, possibly with false
positives. Using the
- * frequent element count algorithm described in
- * <a href="https://doi.org/10.1145/762471.762473">here</a>, proposed by
Karp, Schenker,
- * and Papadimitriou.
- *
- * This function is meant for exploratory data analysis, as we make no
guarantee about the
- * backward compatibility of the schema of the resulting `DataFrame`.
- *
- * @param cols the names of the columns to search frequent items in.
- * @return A Local DataFrame with the Array of frequent items for each
column.
- *
- * {{{
- * val rows = Seq.tabulate(100) { i =>
- * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0)
- * }
- * val df = spark.createDataFrame(rows).toDF("a", "b")
- * // find the items with a frequency greater than 0.4 (observed 40% of
the time) for columns
- * // "a" and "b"
- * val freqSingles = df.stat.freqItems(Seq("a", "b"), 0.4)
- * freqSingles.show()
- * +-----------+-------------+
- * |a_freqItems| b_freqItems|
- * +-----------+-------------+
- * | [1, 99]|[-1.0, -99.0]|
- * +-----------+-------------+
- * // find the pair of items with a frequency greater than 0.1 in columns
"a" and "b"
- * val pairDf = df.select(struct("a", "b").as("a-b"))
- * val freqPairs = pairDf.stat.freqItems(Seq("a-b"), 0.1)
- * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show()
- * +----------+
- * | freq_ab|
- * +----------+
- * | [1,-1.0]|
- * | ... |
- * +----------+
- * }}}
- *
- * @since 1.4.0
- */
+ /** @inheritdoc */
def freqItems(cols: Seq[String], support: Double): DataFrame = withOrigin {
FrequentItems.singlePassFreqItems(df, cols, support)
}
- /**
- * (Scala-specific) Finding frequent items for columns, possibly with false
positives. Using the
- * frequent element count algorithm described in
- * <a href="https://doi.org/10.1145/762471.762473">here</a>, proposed by
Karp, Schenker,
- * and Papadimitriou.
- * Uses a `default` support of 1%.
- *
- * This function is meant for exploratory data analysis, as we make no
guarantee about the
- * backward compatibility of the schema of the resulting `DataFrame`.
- *
- * @param cols the names of the columns to search frequent items in.
- * @return A Local DataFrame with the Array of frequent items for each
column.
- *
- * @since 1.4.0
- */
- def freqItems(cols: Seq[String]): DataFrame = withOrigin {
- FrequentItems.singlePassFreqItems(df, cols, 0.01)
- }
+ /** @inheritdoc */
+ override def freqItems(cols: Array[String], support: Double): DataFrame =
+ super.freqItems(cols, support)
- /**
- * Returns a stratified sample without replacement based on the fraction
given on each stratum.
- * @param col column that defines strata
- * @param fractions sampling fraction for each stratum. If a stratum is not
specified, we treat
- * its fraction as zero.
- * @param seed random seed
- * @tparam T stratum type
- * @return a new `DataFrame` that represents the stratified sample
- *
- * {{{
- * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2,
3), (3, 2),
- * (3, 3))).toDF("key", "value")
- * val fractions = Map(1 -> 1.0, 3 -> 0.5)
- * df.stat.sampleBy("key", fractions, 36L).show()
- * +---+-----+
- * |key|value|
- * +---+-----+
- * | 1| 1|
- * | 1| 2|
- * | 3| 2|
- * +---+-----+
- * }}}
- *
- * @since 1.5.0
- */
- def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = {
- sampleBy(Column(col), fractions, seed)
+ /** @inheritdoc */
+ override def freqItems(cols: Array[String]): DataFrame = super.freqItems(cols)
+
+ /** @inheritdoc */
+ override def freqItems(cols: Seq[String]): DataFrame = super.freqItems(cols)
+
+ /** @inheritdoc */
+ override def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = {
+ super.sampleBy(col, fractions, seed)
}
- /**
- * Returns a stratified sample without replacement based on the fraction
given on each stratum.
- * @param col column that defines strata
- * @param fractions sampling fraction for each stratum. If a stratum is not
specified, we treat
- * its fraction as zero.
- * @param seed random seed
- * @tparam T stratum type
- * @return a new `DataFrame` that represents the stratified sample
- *
- * @since 1.5.0
- */
- def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
- sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
+ /** @inheritdoc */
+ override def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
+ super.sampleBy(col, fractions, seed)
}
- /**
- * Returns a stratified sample without replacement based on the fraction
given on each stratum.
- * @param col column that defines strata
- * @param fractions sampling fraction for each stratum. If a stratum is not
specified, we treat
- * its fraction as zero.
- * @param seed random seed
- * @tparam T stratum type
- * @return a new `DataFrame` that represents the stratified sample
- *
- * The stratified sample can be performed over multiple columns:
- * {{{
- * import org.apache.spark.sql.Row
- * import org.apache.spark.sql.functions.struct
- *
- * val df = spark.createDataFrame(Seq(("Bob", 17), ("Alice", 10),
("Nico", 8), ("Bob", 17),
- * ("Alice", 10))).toDF("name", "age")
- * val fractions = Map(Row("Alice", 10) -> 0.3, Row("Nico", 8) -> 1.0)
- * df.stat.sampleBy(struct($"name", $"age"), fractions, 36L).show()
- * +-----+---+
- * | name|age|
- * +-----+---+
- * | Nico| 8|
- * |Alice| 10|
- * +-----+---+
- * }}}
- *
- * @since 3.0.0
- */
+ /** @inheritdoc */
def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = withOrigin {
require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
s"Fractions must be in [0, 1], but got $fractions.")
@@ -426,135 +113,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
df.filter(f(col, r))
}
- /**
- * (Java-specific) Returns a stratified sample without replacement based on
the fraction given
- * on each stratum.
- * @param col column that defines strata
- * @param fractions sampling fraction for each stratum. If a stratum is not
specified, we treat
- * its fraction as zero.
- * @param seed random seed
- * @tparam T stratum type
- * @return a new `DataFrame` that represents the stratified sample
- *
- * @since 3.0.0
- */
- def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
- sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
- }
-
- /**
- * Builds a Count-min Sketch over a specified column.
- *
- * @param colName name of the column over which the sketch is built
- * @param depth depth of the sketch
- * @param width width of the sketch
- * @param seed random seed
- * @return a `CountMinSketch` over column `colName`
- * @since 2.0.0
- */
- def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = {
- countMinSketch(Column(colName), depth, width, seed)
- }
-
- /**
- * Builds a Count-min Sketch over a specified column.
- *
- * @param colName name of the column over which the sketch is built
- * @param eps relative error of the sketch
- * @param confidence confidence of the sketch
- * @param seed random seed
- * @return a `CountMinSketch` over column `colName`
- * @since 2.0.0
- */
- def countMinSketch(
- colName: String, eps: Double, confidence: Double, seed: Int): CountMinSketch = {
- countMinSketch(Column(colName), eps, confidence, seed)
- }
-
- /**
- * Builds a Count-min Sketch over a specified column.
- *
- * @param col the column over which the sketch is built
- * @param depth depth of the sketch
- * @param width width of the sketch
- * @param seed random seed
- * @return a `CountMinSketch` over column `colName`
- * @since 2.0.0
- */
- def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = {
- val eps = 2.0 / width
- val confidence = 1 - 1 / Math.pow(2, depth)
- countMinSketch(col, eps, confidence, seed)
- }
-
- /**
- * Builds a Count-min Sketch over a specified column.
- *
- * @param col the column over which the sketch is built
- * @param eps relative error of the sketch
- * @param confidence confidence of the sketch
- * @param seed random seed
- * @return a `CountMinSketch` over column `colName`
- * @since 2.0.0
- */
- def countMinSketch(
- col: Column,
- eps: Double,
- confidence: Double,
- seed: Int): CountMinSketch = withOrigin {
- val cms = count_min_sketch(col, lit(eps), lit(confidence), lit(seed))
- val bytes = df.select(cms).as(BINARY).head()
- CountMinSketch.readFrom(bytes)
- }
-
- /**
- * Builds a Bloom filter over a specified column.
- *
- * @param colName name of the column over which the filter is built
- * @param expectedNumItems expected number of items which will be put into
the filter.
- * @param fpp expected false positive probability of the filter.
- * @since 2.0.0
- */
- def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = {
- bloomFilter(Column(colName), expectedNumItems, fpp)
- }
-
- /**
- * Builds a Bloom filter over a specified column.
- *
- * @param col the column over which the filter is built
- * @param expectedNumItems expected number of items which will be put into
the filter.
- * @param fpp expected false positive probability of the filter.
- * @since 2.0.0
- */
- def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = {
- val numBits = BloomFilter.optimalNumOfBits(expectedNumItems, fpp)
- bloomFilter(col, expectedNumItems, numBits)
- }
-
- /**
- * Builds a Bloom filter over a specified column.
- *
- * @param colName name of the column over which the filter is built
- * @param expectedNumItems expected number of items which will be put into
the filter.
- * @param numBits expected number of bits of the filter.
- * @since 2.0.0
- */
- def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = {
- bloomFilter(Column(colName), expectedNumItems, numBits)
- }
-
- /**
- * Builds a Bloom filter over a specified column.
- *
- * @param col the column over which the filter is built
- * @param expectedNumItems expected number of items which will be put into
the filter.
- * @param numBits expected number of bits of the filter.
- * @since 2.0.0
- */
- def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = withOrigin {
- val bf = Column.internalFn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits))
- val bytes = df.select(bf).as(BINARY).head()
- BloomFilter.readFrom(bytes)
+ /** @inheritdoc */
+ override def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
+ super.sampleBy(col, fractions, seed)
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]