Repository: spark
Updated Branches:
  refs/heads/master 88446b6ad -> a86f84102


[SPARK-25381][SQL] Stratified sampling by Column argument

## What changes were proposed in this pull request?

In this PR, I propose adding an overloaded `sampleBy` method that accepts 
a `Column` as its first argument. This allows sampling by any 
complex column as well as by multiple columns. For example:

```Scala
spark.createDataFrame(Seq(("Bob", 17), ("Alice", 10), ("Nico", 8), ("Bob", 17),
  ("Alice", 10))).toDF("name", "age")
  .stat
  .sampleBy(struct($"name", $"age"),
    Map(Row("Alice", 10) -> 0.3, Row("Nico", 8) -> 1.0), 36L)
  .show()

+-----+---+
| name|age|
+-----+---+
| Nico|  8|
|Alice| 10|
+-----+---+
```

## How was this patch tested?

Added a new Scala test for sampling by multiple columns, and tests for Java and 
Python to check that `sampleBy` is able to sample by a `Column` type argument.

Closes #22365 from MaxGekk/sample-by-column.

Authored-by: Maxim Gekk <[email protected]>
Signed-off-by: hyukjinkwon <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a86f8410
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a86f8410
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a86f8410

Branch: refs/heads/master
Commit: a86f84102e10a6ca6325c604bc76d81b0f53eba3
Parents: 88446b6
Author: Maxim Gekk <[email protected]>
Authored: Fri Sep 21 01:11:40 2018 +0800
Committer: hyukjinkwon <[email protected]>
Committed: Fri Sep 21 01:11:40 2018 +0800

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 | 11 +++-
 .../spark/sql/DataFrameStatFunctions.scala      | 57 ++++++++++++++++++--
 .../apache/spark/sql/JavaDataFrameSuite.java    | 11 ++++
 .../apache/spark/sql/DataFrameStatSuite.scala   | 20 ++++++-
 4 files changed, 91 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a86f8410/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 1affc9b..21bc69b 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -880,16 +880,23 @@ class DataFrame(object):
         |  0|    5|
         |  1|    9|
         +---+-----+
+        >>> dataset.sampleBy(col("key"), fractions={2: 1.0}, seed=0).count()
+        33
 
+        .. versionchanged:: 2.5
+           Added sampling by a column of :class:`Column`
         """
-        if not isinstance(col, basestring):
-            raise ValueError("col must be a string, but got %r" % type(col))
+        if isinstance(col, basestring):
+            col = Column(col)
+        elif not isinstance(col, Column):
+            raise ValueError("col must be a string or a column, but got %r" % 
type(col))
         if not isinstance(fractions, dict):
             raise ValueError("fractions must be a dict but got %r" % 
type(fractions))
         for k, v in fractions.items():
             if not isinstance(k, (float, int, long, basestring)):
                 raise ValueError("key must be float, int, long, or string, but 
got %r" % type(k))
             fractions[k] = float(v)
+        col = col._jc
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
         return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), 
seed), self.sql_ctx)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a86f8410/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index a417530..75b8477 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -370,19 +370,66 @@ final class DataFrameStatFunctions private[sql](df: 
DataFrame) {
    * @since 1.5.0
    */
   def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): 
DataFrame = {
+    sampleBy(Column(col), fractions, seed)
+  }
+
+  /**
+   * Returns a stratified sample without replacement based on the fraction 
given on each stratum.
+   * @param col column that defines strata
+   * @param fractions sampling fraction for each stratum. If a stratum is not 
specified, we treat
+   *                  its fraction as zero.
+   * @param seed random seed
+   * @tparam T stratum type
+   * @return a new `DataFrame` that represents the stratified sample
+   *
+   * @since 1.5.0
+   */
+  def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): 
DataFrame = {
+    sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
+  }
+
+  /**
+   * Returns a stratified sample without replacement based on the fraction 
given on each stratum.
+   * @param col column that defines strata
+   * @param fractions sampling fraction for each stratum. If a stratum is not 
specified, we treat
+   *                  its fraction as zero.
+   * @param seed random seed
+   * @tparam T stratum type
+   * @return a new `DataFrame` that represents the stratified sample
+   *
+   * The stratified sample can be performed over multiple columns:
+   * {{{
+   *    import org.apache.spark.sql.Row
+   *    import org.apache.spark.sql.functions.struct
+   *
+   *    val df = spark.createDataFrame(Seq(("Bob", 17), ("Alice", 10), 
("Nico", 8), ("Bob", 17),
+   *      ("Alice", 10))).toDF("name", "age")
+   *    val fractions = Map(Row("Alice", 10) -> 0.3, Row("Nico", 8) -> 1.0)
+   *    df.stat.sampleBy(struct($"name", $"age"), fractions, 36L).show()
+   *    +-----+---+
+   *    | name|age|
+   *    +-----+---+
+   *    | Nico|  8|
+   *    |Alice| 10|
+   *    +-----+---+
+   * }}}
+   *
+   * @since 2.5.0
+   */
+  def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): 
DataFrame = {
     require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
       s"Fractions must be in [0, 1], but got $fractions.")
     import org.apache.spark.sql.functions.{rand, udf}
-    val c = Column(col)
     val r = rand(seed)
     val f = udf { (stratum: Any, x: Double) =>
       x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0)
     }
-    df.filter(f(c, r))
+    df.filter(f(col, r))
   }
 
   /**
-   * Returns a stratified sample without replacement based on the fraction 
given on each stratum.
+   * (Java-specific) Returns a stratified sample without replacement based on 
the fraction given
+   * on each stratum.
    * @param col column that defines strata
    * @param fractions sampling fraction for each stratum. If a stratum is not 
specified, we treat
    *                  its fraction as zero.
@@ -390,9 +437,9 @@ final class DataFrameStatFunctions private[sql](df: 
DataFrame) {
    * @tparam T stratum type
    * @return a new `DataFrame` that represents the stratified sample
    *
-   * @since 1.5.0
+   * @since 2.5.0
    */
-  def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): 
DataFrame = {
+  def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): 
DataFrame = {
     sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a86f8410/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 69a2904..3f37e58 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -291,6 +291,17 @@ public class JavaDataFrameSuite {
   }
 
   @Test
+  public void testSampleByColumn() {
+    Dataset<Row> df = spark.range(0, 100, 1, 
2).select(col("id").mod(3).as("key"));
+    Dataset<Row> sampled = df.stat().sampleBy(col("key"), ImmutableMap.of(0, 
0.1, 1, 0.2), 0L);
+    List<Row> actual = 
sampled.groupBy("key").count().orderBy("key").collectAsList();
+    Assert.assertEquals(0, actual.get(0).getLong(0));
+    Assert.assertTrue(0 <= actual.get(0).getLong(1) && 
actual.get(0).getLong(1) <= 8);
+    Assert.assertEquals(1, actual.get(1).getLong(0));
+    Assert.assertTrue(2 <= actual.get(1).getLong(1) && 
actual.get(1).getLong(1) <= 13);
+  }
+
+  @Test
   public void pivot() {
     Dataset<Row> df = spark.table("courseSales");
     List<Row> actual = df.groupBy("year")

http://git-wip-us.apache.org/repos/asf/spark/blob/a86f8410/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 8eae353..589873b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.Matchers._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.execution.stat.StatFunctions
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions.{col, lit, struct}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
@@ -374,6 +374,24 @@ class DataFrameStatSuite extends QueryTest with 
SharedSQLContext {
       Seq(Row(0, 6), Row(1, 11)))
   }
 
+  test("sampleBy one column") {
+    val df = spark.range(0, 100).select((col("id") % 3).as("key"))
+    val sampled = df.stat.sampleBy($"key", Map(0 -> 0.1, 1 -> 0.2), 0L)
+    checkAnswer(
+      sampled.groupBy("key").count().orderBy("key"),
+      Seq(Row(0, 6), Row(1, 11)))
+  }
+
+  test("sampleBy multiple columns") {
+    val df = spark.range(0, 100)
+      .select(lit("Foo").as("name"), (col("id") % 3).as("key"))
+    val sampled = df.stat.sampleBy(
+      struct($"name", $"key"), Map(Row("Foo", 0) -> 0.1, Row("Foo", 1) -> 
0.2), 0L)
+    checkAnswer(
+      sampled.groupBy("key").count().orderBy("key"),
+      Seq(Row(0, 6), Row(1, 11)))
+  }
+
   // This test case only verifies that `DataFrame.countMinSketch()` methods do 
return
   // `CountMinSketch`es that meet required specs.  Test cases for 
`CountMinSketch` can be found in
   // `CountMinSketchSuite` in project spark-sketch.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to