Repository: spark
Updated Branches:
  refs/heads/master f5e10a34e -> 5cd8ea99f


[SPARK-21779][PYTHON] Simpler DataFrame.sample API in Python

## What changes were proposed in this pull request?

This PR lets `DataFrame.sample(...)` omit `withReplacement`, which then defaults to 
`False`, consistent with the equivalent Scala / Java API.

In short, the following examples are allowed:

```python
>>> df = spark.range(10)
>>> df.sample(0.5).count()
7
>>> df.sample(fraction=0.5).count()
3
>>> df.sample(0.5, seed=42).count()
5
>>> df.sample(fraction=0.5, seed=42).count()
5
```
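
The call forms above differ only in how positional arguments are interpreted. As a minimal, Spark-free sketch of that mapping (the helper name `resolve_sample_args` is hypothetical and not part of PySpark; it mirrors the argument shifting in this patch and assumes the arguments have already passed type checking):

```python
def resolve_sample_args(withReplacement=None, fraction=None, seed=None):
    """Hypothetical helper: return (withReplacement, fraction, seed) as the
    patched DataFrame.sample would interpret them. Assumes valid types."""
    # A float in the first position means withReplacement was omitted,
    # so the positional arguments shift one slot to the right.
    if isinstance(withReplacement, float):
        if fraction is not None:
            seed = fraction
        fraction = withReplacement
        withReplacement = None
    return withReplacement, fraction, seed


print(resolve_sample_args(0.5))            # (None, 0.5, None)
print(resolve_sample_args(0.5, 42))        # (None, 0.5, 42)
print(resolve_sample_args(0.5, seed=42))   # (None, 0.5, 42)
print(resolve_sample_args(fraction=0.5))   # (None, 0.5, None)
print(resolve_sample_args(True, 0.5, 42))  # (True, 0.5, 42)
```

In the patch, any value left as `None` is dropped before the call into the JVM, so an omitted `withReplacement` falls through to the Scala overload that samples without replacement.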

In addition, this PR adds some type checking, as shown below:

```python
>>> df = spark.range(10)
>>> df.sample().count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) 
should be a bool, float and number; however, got [].
>>> df.sample(True).count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) 
should be a bool, float and number; however, got [<type 'bool'>].
>>> df.sample(42).count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) 
should be a bool, float and number; however, got [<type 'int'>].
>>> df.sample(fraction=False, seed="a").count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) 
should be a bool, float and number; however, got [<type 'bool'>, <type 'str'>].
>>> df.sample(seed=[1]).count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) 
should be a bool, float and number; however, got [<type 'list'>].
>>> df.sample(withReplacement="a", fraction=0.5, seed=1)
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) 
should be a bool, float and number; however, got [<type 'str'>, <type 'float'>, 
<type 'int'>].
```
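
The check behind these messages can be sketched without Spark. The function name `check_sample_args` below is hypothetical; the three predicates follow the patch. On Python 3 the type names render as `<class 'bool'>` rather than the `<type 'bool'>` shown above.

```python
def check_sample_args(withReplacement=None, fraction=None, seed=None):
    """Hypothetical stand-alone version of the validation in this patch.
    Raises TypeError for unsupported argument combinations."""
    # sample(True, 0.5 [, seed]) and its keyword equivalents
    is_set = isinstance(withReplacement, bool) and isinstance(fraction, float)
    # sample(fraction=0.5 [, seed])
    is_omitted_kwargs = withReplacement is None and isinstance(fraction, float)
    # sample(0.5 [, seed])
    is_omitted_args = isinstance(withReplacement, float)

    if not (is_set or is_omitted_kwargs or is_omitted_args):
        argtypes = [str(type(arg))
                    for arg in (withReplacement, fraction, seed)
                    if arg is not None]
        raise TypeError(
            "withReplacement (optional), fraction (required) and seed (optional)"
            " should be a bool, float and number; however, "
            "got [%s]." % ", ".join(argtypes))


# Each of these argument sets is rejected, as in the examples above.
for bad_call in ({}, {"withReplacement": True}, {"withReplacement": 42},
                 {"fraction": False, "seed": "a"}):
    try:
        check_sample_args(**bad_call)
    except TypeError:
        print("rejected:", bad_call)
```

The predicates do not inspect `seed` on their own; in the patch, `seed` is only converted with `long(seed)` after this check passes.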

## How was this patch tested?

Manually tested, added unit tests and doctests, and manually checked the built 
documentation for Python.

Author: hyukjinkwon <gurwls...@gmail.com>

Closes #18999 from HyukjinKwon/SPARK-21779.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5cd8ea99
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5cd8ea99
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5cd8ea99

Branch: refs/heads/master
Commit: 5cd8ea99f084bee40ee18a0c8e33d0ca0aa6bb60
Parents: f5e10a3
Author: hyukjinkwon <gurwls...@gmail.com>
Authored: Fri Sep 1 13:01:23 2017 +0900
Committer: hyukjinkwon <gurwls...@gmail.com>
Committed: Fri Sep 1 13:01:23 2017 +0900

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 | 64 +++++++++++++++++---
 python/pyspark/sql/tests.py                     | 18 ++++++
 .../scala/org/apache/spark/sql/Dataset.scala    |  3 +-
 3 files changed, 77 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5cd8ea99/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index d1b2a9c..c19e599 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -659,19 +659,69 @@ class DataFrame(object):
         return DataFrame(self._jdf.distinct(), self.sql_ctx)
 
     @since(1.3)
-    def sample(self, withReplacement, fraction, seed=None):
+    def sample(self, withReplacement=None, fraction=None, seed=None):
         """Returns a sampled subset of this :class:`DataFrame`.
 
+        :param withReplacement: Sample with replacement or not (default False).
+        :param fraction: Fraction of rows to generate, range [0.0, 1.0].
+        :param seed: Seed for sampling (default a random seed).
+
         .. note:: This is not guaranteed to provide exactly the fraction specified of the total
             count of the given :class:`DataFrame`.
 
-        >>> df.sample(False, 0.5, 42).count()
-        2
+        .. note:: `fraction` is required, and `withReplacement` and `seed` are optional.
+
+        >>> df = spark.range(10)
+        >>> df.sample(0.5, 3).count()
+        4
+        >>> df.sample(fraction=0.5, seed=3).count()
+        4
+        >>> df.sample(withReplacement=True, fraction=0.5, seed=3).count()
+        1
+        >>> df.sample(1.0).count()
+        10
+        >>> df.sample(fraction=1.0).count()
+        10
+        >>> df.sample(False, fraction=1.0).count()
+        10
         """
-        assert fraction >= 0.0, "Negative fraction value: %s" % fraction
-        seed = seed if seed is not None else random.randint(0, sys.maxsize)
-        rdd = self._jdf.sample(withReplacement, fraction, long(seed))
-        return DataFrame(rdd, self.sql_ctx)
+
+        # For the cases below:
+        #   sample(True, 0.5 [, seed])
+        #   sample(True, fraction=0.5 [, seed])
+        #   sample(withReplacement=False, fraction=0.5 [, seed])
+        is_withReplacement_set = \
+            type(withReplacement) == bool and isinstance(fraction, float)
+
+        # For the case below:
+        #   sample(fraction=0.5 [, seed])
+        is_withReplacement_omitted_kwargs = \
+            withReplacement is None and isinstance(fraction, float)
+
+        # For the case below:
+        #   sample(0.5 [, seed])
+        is_withReplacement_omitted_args = isinstance(withReplacement, float)
+
+        if not (is_withReplacement_set
+                or is_withReplacement_omitted_kwargs
+                or is_withReplacement_omitted_args):
+            argtypes = [
+                str(type(arg)) for arg in [withReplacement, fraction, seed] if arg is not None]
+            raise TypeError(
+                "withReplacement (optional), fraction (required) and seed (optional)"
+                " should be a bool, float and number; however, "
+                "got [%s]." % ", ".join(argtypes))
+
+        if is_withReplacement_omitted_args:
+            if fraction is not None:
+                seed = fraction
+            fraction = withReplacement
+            withReplacement = None
+
+        seed = long(seed) if seed is not None else None
+        args = [arg for arg in [withReplacement, fraction, seed] if arg is not None]
+        jdf = self._jdf.sample(*args)
+        return DataFrame(jdf, self.sql_ctx)
 
     @since(1.5)
     def sampleBy(self, col, fractions, seed=None):

http://git-wip-us.apache.org/repos/asf/spark/blob/5cd8ea99/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b310285..a2a3ceb 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2108,6 +2108,24 @@ class SQLTests(ReusedPySparkTestCase):
         plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan()
         self.assertEqual(1, plan.toString().count("BroadcastHashJoin"))
 
+    def test_sample(self):
+        self.assertRaisesRegexp(
+            TypeError,
+            "should be a bool, float and number",
+            lambda: self.spark.range(1).sample())
+
+        self.assertRaises(
+            TypeError,
+            lambda: self.spark.range(1).sample("a"))
+
+        self.assertRaises(
+            TypeError,
+            lambda: self.spark.range(1).sample(seed="abc"))
+
+        self.assertRaises(
+            IllegalArgumentException,
+            lambda: self.spark.range(1).sample(-1.0))
+
     def test_toDF_with_schema_string(self):
         data = [Row(key=i, value=str(i)) for i in range(100)]
         rdd = self.sc.parallelize(data, 5)

http://git-wip-us.apache.org/repos/asf/spark/blob/5cd8ea99/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index c670739..5d8a183 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1867,7 +1867,8 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement).
+   * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement),
+   * using a random seed.
    *
    * @param fraction Fraction of rows to generate, range [0.0, 1.0].
    *

