Repository: spark
Updated Branches:
refs/heads/master 88446b6ad -> a86f84102
[SPARK-25381][SQL] Stratified sampling by Column argument
## What changes were proposed in this pull request?
In this PR, I propose adding an overloaded `sampleBy` method whose first argument
is of type `Column`. This allows sampling by arbitrary complex columns as well as
by multiple columns. For example:
```Scala
spark.createDataFrame(Seq(("Bob", 17), ("Alice", 10), ("Nico", 8), ("Bob", 17),
("Alice", 10))).toDF("name", "age")
.stat
.sampleBy(struct($"name", $"age"), Map(Row("Alice", 10) -> 0.3, Row("Nico",
8) -> 1.0), 36L)
.show()
+-----+---+
| name|age|
+-----+---+
| Nico| 8|
|Alice| 10|
+-----+---+
```
## How was this patch tested?
Added new Scala tests for sampling by a single `Column` and by multiple columns
(combined via `struct`), plus Java and Python tests checking that `sampleBy`
accepts a `Column`-type argument.
Closes #22365 from MaxGekk/sample-by-column.
Authored-by: Maxim Gekk <[email protected]>
Signed-off-by: hyukjinkwon <[email protected]>
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a86f8410
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a86f8410
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a86f8410
Branch: refs/heads/master
Commit: a86f84102e10a6ca6325c604bc76d81b0f53eba3
Parents: 88446b6
Author: Maxim Gekk <[email protected]>
Authored: Fri Sep 21 01:11:40 2018 +0800
Committer: hyukjinkwon <[email protected]>
Committed: Fri Sep 21 01:11:40 2018 +0800
----------------------------------------------------------------------
python/pyspark/sql/dataframe.py | 11 +++-
.../spark/sql/DataFrameStatFunctions.scala | 57 ++++++++++++++++++--
.../apache/spark/sql/JavaDataFrameSuite.java | 11 ++++
.../apache/spark/sql/DataFrameStatSuite.scala | 20 ++++++-
4 files changed, 91 insertions(+), 8 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/a86f8410/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 1affc9b..21bc69b 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -880,16 +880,23 @@ class DataFrame(object):
| 0| 5|
| 1| 9|
+---+-----+
+ >>> dataset.sampleBy(col("key"), fractions={2: 1.0}, seed=0).count()
+ 33
+ .. versionchanged:: 2.5
+ Added sampling by a column of :class:`Column`
"""
- if not isinstance(col, basestring):
- raise ValueError("col must be a string, but got %r" % type(col))
+ if isinstance(col, basestring):
+ col = Column(col)
+ elif not isinstance(col, Column):
+ raise ValueError("col must be a string or a column, but got %r" %
type(col))
if not isinstance(fractions, dict):
raise ValueError("fractions must be a dict but got %r" %
type(fractions))
for k, v in fractions.items():
if not isinstance(k, (float, int, long, basestring)):
raise ValueError("key must be float, int, long, or string, but
got %r" % type(k))
fractions[k] = float(v)
+ col = col._jc
seed = seed if seed is not None else random.randint(0, sys.maxsize)
return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions),
seed), self.sql_ctx)
http://git-wip-us.apache.org/repos/asf/spark/blob/a86f8410/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index a417530..75b8477 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -370,19 +370,66 @@ final class DataFrameStatFunctions private[sql](df:
DataFrame) {
* @since 1.5.0
*/
def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long):
DataFrame = {
+ sampleBy(Column(col), fractions, seed)
+ }
+
+ /**
+ * Returns a stratified sample without replacement based on the fraction
given on each stratum.
+ * @param col column that defines strata
+ * @param fractions sampling fraction for each stratum. If a stratum is not
specified, we treat
+ * its fraction as zero.
+ * @param seed random seed
+ * @tparam T stratum type
+ * @return a new `DataFrame` that represents the stratified sample
+ *
+ * @since 1.5.0
+ */
+ def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long):
DataFrame = {
+ sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
+ }
+
+ /**
+ * Returns a stratified sample without replacement based on the fraction
given on each stratum.
+ * @param col column that defines strata
+ * @param fractions sampling fraction for each stratum. If a stratum is not
specified, we treat
+ * its fraction as zero.
+ * @param seed random seed
+ * @tparam T stratum type
+ * @return a new `DataFrame` that represents the stratified sample
+ *
+ * The stratified sample can be performed over multiple columns:
+ * {{{
+ * import org.apache.spark.sql.Row
+ * import org.apache.spark.sql.functions.struct
+ *
+ * val df = spark.createDataFrame(Seq(("Bob", 17), ("Alice", 10),
("Nico", 8), ("Bob", 17),
+ * ("Alice", 10))).toDF("name", "age")
+ * val fractions = Map(Row("Alice", 10) -> 0.3, Row("Nico", 8) -> 1.0)
+ * df.stat.sampleBy(struct($"name", $"age"), fractions, 36L).show()
+ * +-----+---+
+ * | name|age|
+ * +-----+---+
+ * | Nico| 8|
+ * |Alice| 10|
+ * +-----+---+
+ * }}}
+ *
+ * @since 2.5.0
+ */
+ def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long):
DataFrame = {
require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
s"Fractions must be in [0, 1], but got $fractions.")
import org.apache.spark.sql.functions.{rand, udf}
- val c = Column(col)
val r = rand(seed)
val f = udf { (stratum: Any, x: Double) =>
x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0)
}
- df.filter(f(c, r))
+ df.filter(f(col, r))
}
/**
- * Returns a stratified sample without replacement based on the fraction
given on each stratum.
+ * (Java-specific) Returns a stratified sample without replacement based on
the fraction given
+ * on each stratum.
* @param col column that defines strata
* @param fractions sampling fraction for each stratum. If a stratum is not
specified, we treat
* its fraction as zero.
@@ -390,9 +437,9 @@ final class DataFrameStatFunctions private[sql](df:
DataFrame) {
* @tparam T stratum type
* @return a new `DataFrame` that represents the stratified sample
*
- * @since 1.5.0
+ * @since 2.5.0
*/
- def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long):
DataFrame = {
+ def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long):
DataFrame = {
sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/a86f8410/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 69a2904..3f37e58 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -291,6 +291,17 @@ public class JavaDataFrameSuite {
}
@Test
+ public void testSampleByColumn() {
+ Dataset<Row> df = spark.range(0, 100, 1,
2).select(col("id").mod(3).as("key"));
+ Dataset<Row> sampled = df.stat().sampleBy(col("key"), ImmutableMap.of(0,
0.1, 1, 0.2), 0L);
+ List<Row> actual =
sampled.groupBy("key").count().orderBy("key").collectAsList();
+ Assert.assertEquals(0, actual.get(0).getLong(0));
+ Assert.assertTrue(0 <= actual.get(0).getLong(1) &&
actual.get(0).getLong(1) <= 8);
+ Assert.assertEquals(1, actual.get(1).getLong(0));
+ Assert.assertTrue(2 <= actual.get(1).getLong(1) &&
actual.get(1).getLong(1) <= 13);
+ }
+
+ @Test
public void pivot() {
Dataset<Row> df = spark.table("courseSales");
List<Row> actual = df.groupBy("year")
http://git-wip-us.apache.org/repos/asf/spark/blob/a86f8410/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 8eae353..589873b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.Matchers._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.stat.StatFunctions
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions.{col, lit, struct}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
@@ -374,6 +374,24 @@ class DataFrameStatSuite extends QueryTest with
SharedSQLContext {
Seq(Row(0, 6), Row(1, 11)))
}
+ test("sampleBy one column") {
+ val df = spark.range(0, 100).select((col("id") % 3).as("key"))
+ val sampled = df.stat.sampleBy($"key", Map(0 -> 0.1, 1 -> 0.2), 0L)
+ checkAnswer(
+ sampled.groupBy("key").count().orderBy("key"),
+ Seq(Row(0, 6), Row(1, 11)))
+ }
+
+ test("sampleBy multiple columns") {
+ val df = spark.range(0, 100)
+ .select(lit("Foo").as("name"), (col("id") % 3).as("key"))
+ val sampled = df.stat.sampleBy(
+ struct($"name", $"key"), Map(Row("Foo", 0) -> 0.1, Row("Foo", 1) ->
0.2), 0L)
+ checkAnswer(
+ sampled.groupBy("key").count().orderBy("key"),
+ Seq(Row(0, 6), Row(1, 11)))
+ }
+
// This test case only verifies that `DataFrame.countMinSketch()` methods do
return
// `CountMinSketch`es that meet required specs. Test cases for
`CountMinSketch` can be found in
// `CountMinSketchSuite` in project spark-sketch.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]