This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 44d165113dd [SPARK-41830][CONNECT][PYTHON] Make `DataFrame.sample` accept the same parameters as PySpark
44d165113dd is described below

commit 44d165113ddce621f0090d89624bcff554ae49bb
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Jan 5 19:19:00 2023 +0900

    [SPARK-41830][CONNECT][PYTHON] Make `DataFrame.sample` accept the same parameters as PySpark
    
    ### What changes were proposed in this pull request?
    Make `DataFrame.sample` accept the same parameters as PySpark.
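    
    A minimal usage sketch (illustrative only, not part of the original PR description; assumes a Spark Connect DataFrame `df`, and the seed values are arbitrary) of the call forms that now work, matching PySpark:
    
        df.sample(0.5)                                  # fraction only
        df.sample(0.5, 3)                               # positional fraction and seed
        df.sample(fraction=0.5, seed=3)                 # keyword form
        df.sample(True, 0.5, 3)                         # withReplacement, fraction, seed
        df.sample(withReplacement=False, fraction=0.5)  # explicit keywords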
    
    ### Why are the changes needed?
    For consistency with the `DataFrame.sample` API in PySpark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, `DataFrame.sample` in Spark Connect now accepts the same parameters as PySpark.
    
    ### How was this patch tested?
    Enabled the `DataFrame.sample` doctests that were previously skipped.
    
    Closes #39403 from zhengruifeng/connect_fix_41830.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/dataframe.py | 55 ++++++++++++++++++++++++---------
 1 file changed, 41 insertions(+), 14 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 13a421ca72a..639e3faa748 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -405,26 +405,56 @@ class DataFrame:
 
     def sample(
         self,
-        fraction: float,
-        *,
-        withReplacement: bool = False,
+        withReplacement: Optional[Union[float, bool]] = None,
+        fraction: Optional[Union[int, float]] = None,
         seed: Optional[int] = None,
     ) -> "DataFrame":
-        if not isinstance(fraction, float):
-            raise TypeError(f"'fraction' must be float, but got {type(fraction).__name__}")
-        if not isinstance(withReplacement, bool):
+
+        # For the cases below:
+        #   sample(True, 0.5 [, seed])
+        #   sample(True, fraction=0.5 [, seed])
+        #   sample(withReplacement=False, fraction=0.5 [, seed])
+        is_withReplacement_set = type(withReplacement) == bool and isinstance(fraction, float)
+
+        # For the case below:
+        #   sample(fraction=0.5 [, seed])
+        is_withReplacement_omitted_kwargs = withReplacement is None and isinstance(fraction, float)
+
+        # For the case below:
+        #   sample(0.5 [, seed])
+        is_withReplacement_omitted_args = isinstance(withReplacement, float)
+
+        if not (
+            is_withReplacement_set
+            or is_withReplacement_omitted_kwargs
+            or is_withReplacement_omitted_args
+        ):
+            argtypes = [
+                str(type(arg)) for arg in [withReplacement, fraction, seed] if arg is not None
+            ]
             raise TypeError(
-                f"'withReplacement' must be bool, but got {type(withReplacement).__name__}"
+                "withReplacement (optional), fraction (required) and seed (optional)"
+                " should be a bool, float and number; however, "
+                "got [%s]." % ", ".join(argtypes)
             )
-        if seed is not None and not isinstance(seed, int):
-            raise TypeError(f"'seed' must be None or int, but got {type(seed).__name__}")
+
+        if is_withReplacement_omitted_args:
+            if fraction is not None:
+                seed = cast(int, fraction)
+            fraction = withReplacement
+            withReplacement = None
+
+        if withReplacement is None:
+            withReplacement = False
+
+        seed = int(seed) if seed is not None else None
 
         return DataFrame.withPlan(
             plan.Sample(
                 child=self._plan,
                 lower_bound=0.0,
-                upper_bound=fraction,
-                with_replacement=withReplacement,
+                upper_bound=fraction,  # type: ignore[arg-type]
+                with_replacement=withReplacement,  # type: ignore[arg-type]
                 seed=seed,
             ),
             session=self._session,
@@ -1485,9 +1515,6 @@ def _test() -> None:
         # TODO(SPARK-41827): groupBy requires all cols be Column or str
         del pyspark.sql.connect.dataframe.DataFrame.groupBy.__doc__
 
-        # TODO(SPARK-41830): fix sample parameters
-        del pyspark.sql.connect.dataframe.DataFrame.sample.__doc__
-
         # TODO(SPARK-41831): fix transform to accept ColumnReference
         del pyspark.sql.connect.dataframe.DataFrame.transform.__doc__
 

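As a reading aid, below is a minimal standalone sketch of the argument normalization the new `sample` performs. The helper name `_resolve_sample_args` is hypothetical and not part of the commit; the validation and TypeError message are omitted, and only the accepted argument shapes from the diff above are mirrored.

    from typing import Optional, Tuple, Union

    def _resolve_sample_args(
        withReplacement: Optional[Union[float, bool]] = None,
        fraction: Optional[Union[int, float]] = None,
        seed: Optional[int] = None,
    ) -> Tuple[bool, float, Optional[int]]:
        # Illustrative only: mirrors the dispatch added in the diff above.
        if isinstance(withReplacement, float):
            # sample(0.5 [, seed]): the first positional argument is really the
            # fraction, and the value bound to `fraction` is really the seed.
            if fraction is not None:
                seed = int(fraction)
            fraction = withReplacement
            withReplacement = None
        if withReplacement is None:
            withReplacement = False
        return bool(withReplacement), float(fraction), (int(seed) if seed is not None else None)

    # Expected behavior of the sketch:
    #   _resolve_sample_args(0.5)           -> (False, 0.5, None)
    #   _resolve_sample_args(0.5, 3)        -> (False, 0.5, 3)
    #   _resolve_sample_args(True, 0.5, 3)  -> (True, 0.5, 3)
    #   _resolve_sample_args(fraction=0.5)  -> (False, 0.5, None)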

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
