This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c14320497d5 [SPARK-40301][PYTHON] Add parameter validations in pyspark.rdd
c14320497d5 is described below

commit c14320497d5415616b3f65b120336f963d8ab46b
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Sun Sep 4 14:29:42 2022 +0800

    [SPARK-40301][PYTHON] Add parameter validations in pyspark.rdd
    
    ### What changes were proposed in this pull request?
    1. Compared with the Scala side, some parameter validations were missing in `pyspark.rdd` (the `randomSplit` weights and the `coalesce` partition count).
    2. `rdd.sample` now raises `ValueError` instead of `AssertionError` when the `fraction` check fails.
    
    ### Why are the changes needed?
    To add the parameter validations that `pyspark.rdd` was missing.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. When `fraction` is invalid, `rdd.sample` raises `ValueError` instead of `AssertionError`.
    
    ### How was this patch tested?
    Existing test suites.
    
    Closes #37752 from zhengruifeng/py_rdd_check.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
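
The behaviour change can be reproduced without a SparkContext. The following is a minimal pure-Python sketch, not PySpark code: it only copies the conditions from the diff below, and the helpers `old_check_fraction`, `check_fraction`, `check_num_partitions` and `raised` are hypothetical names. Writing the test as `not fraction >= 0` rather than `fraction < 0` also rejects NaN, since every comparison involving NaN is false; and unlike an `assert`, an explicit `raise` still runs when Python is started with `-O`.

```python
import math
from typing import Callable, Optional


def old_check_fraction(fraction: float) -> None:
    # Hypothetical copy of the pre-patch check: a bare assert (removed under `python -O`).
    assert fraction >= 0.0, "Negative fraction value: %s" % fraction


def check_fraction(fraction: float) -> None:
    # Hypothetical copy of the condition the patch adds to RDD.sample.
    # NaN fails `>=`, so it is rejected as well.
    if not fraction >= 0:
        raise ValueError("Fraction must be nonnegative.")


def check_num_partitions(num_partitions: int) -> None:
    # Hypothetical copy of the condition the patch adds to RDD.coalesce.
    if not num_partitions > 0:
        raise ValueError("Number of partitions must be positive.")


def raised(check: Callable[..., None], value: object) -> Optional[str]:
    # Return the name of the exception raised by `check`, or None if it passed.
    try:
        check(value)
    except Exception as exc:
        return type(exc).__name__
    return None


if __name__ == "__main__":
    for value in (0.5, 0.0, -0.1, math.nan):
        print(value, raised(old_check_fraction, value), raised(check_fraction, value))
    for n in (1, 0, -2):
        print(n, raised(check_num_partitions, n))
```

Code that has to work against PySpark releases from both before and after this change can catch `(ValueError, AssertionError)` around `rdd.sample`.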
---
 python/pyspark/rdd.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 5fe463233a2..7ef0014ae75 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1039,7 +1039,8 @@ class RDD(Generic[T_co]):
         >>> 6 <= rdd.sample(False, 0.1, 81).count() <= 14
         True
         """
-        assert fraction >= 0.0, "Negative fraction value: %s" % fraction
+        if not fraction >= 0:
+            raise ValueError("Fraction must be nonnegative.")
         return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
 
     def randomSplit(
@@ -1077,7 +1078,11 @@ class RDD(Generic[T_co]):
         >>> 250 < rdd2.count() < 350
         True
         """
+        if not all(w >= 0 for w in weights):
+            raise ValueError("Weights must be nonnegative")
         s = float(sum(weights))
+        if not s > 0:
+            raise ValueError("Sum of weights must be positive")
         cweights = [0.0]
         for w in weights:
             cweights.append(cweights[-1] + w / s)
@@ -4565,6 +4570,8 @@ class RDD(Generic[T_co]):
         >>> sc.parallelize([1, 2, 3, 4, 5], 3).coalesce(1).glom().collect()
         [[1, 2, 3, 4, 5]]
         """
+        if not numPartitions > 0:
+            raise ValueError("Number of partitions must be positive.")
         if shuffle:
             # Decrease the batch size in order to distribute evenly the elements across output
             # partitions. Otherwise, repartition will possibly produce highly skewed partitions.
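
For context on the two `randomSplit` checks: the method turns `weights` into cumulative boundaries `cweights` by dividing running sums by `s`, as in the hunk above. The pure-Python sketch below copies that loop together with the new checks; the function name `cumulative_weights` is hypothetical and this is not the PySpark implementation itself. It shows why both checks matter: a negative weight makes the boundaries non-monotonic, and an all-zero list would otherwise fail with `ZeroDivisionError` in `w / s`.

```python
from typing import List, Sequence


def cumulative_weights(weights: Sequence[float]) -> List[float]:
    # Hypothetical helper: the same validations the patch adds to RDD.randomSplit.
    if not all(w >= 0 for w in weights):
        raise ValueError("Weights must be nonnegative")
    s = float(sum(weights))
    if not s > 0:
        raise ValueError("Sum of weights must be positive")
    # Same normalization loop as randomSplit: running sums divided by the total.
    cweights = [0.0]
    for w in weights:
        cweights.append(cweights[-1] + w / s)
    return cweights


if __name__ == "__main__":
    print(cumulative_weights([1, 1, 2]))  # [0.0, 0.25, 0.5, 1.0]
    for bad in ([2, -1], [0, 0], []):
        try:
            cumulative_weights(bad)
        except ValueError as exc:
            print(bad, "->", exc)
```

Without the first check, `[2, -1]` would pass the sum test (`s = 1.0`) and yield the boundaries `[0.0, 2.0, 1.0]`, which are not a valid partition of [0, 1].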

