This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new c14320497d5 [SPARK-40301][PYTHON] Add parameter validations in pyspark.rdd c14320497d5 is described below commit c14320497d5415616b3f65b120336f963d8ab46b Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Sun Sep 4 14:29:42 2022 +0800 [SPARK-40301][PYTHON] Add parameter validations in pyspark.rdd ### What changes were proposed in this pull request? 1, compared with the scala side, some parameter validations were missing in `pyspark.rdd` 2, `rdd.sample` checking fraction will raise `ValueError` instead of `AssertionError` ### Why are the changes needed? add missing parameter validations in `pyspark.rdd` ### Does this PR introduce _any_ user-facing change? yes, when fraction is invalid, `ValueError` is raised instead of `AssertionError` ### How was this patch tested? existing test suites Closes #37752 from zhengruifeng/py_rdd_check. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/rdd.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 5fe463233a2..7ef0014ae75 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1039,7 +1039,8 @@ class RDD(Generic[T_co]): >>> 6 <= rdd.sample(False, 0.1, 81).count() <= 14 True """ - assert fraction >= 0.0, "Negative fraction value: %s" % fraction + if not fraction >= 0: + raise ValueError("Fraction must be nonnegative.") return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) def randomSplit( @@ -1077,7 +1078,11 @@ class RDD(Generic[T_co]): >>> 250 < rdd2.count() < 350 True """ + if not all(w >= 0 for w in weights): + raise ValueError("Weights must be nonnegative") s = float(sum(weights)) + if not s > 0: + raise ValueError("Sum of weights must be positive") cweights = [0.0] for w in weights: cweights.append(cweights[-1] + w / s) @@ -4565,6 +4570,8 @@ class 
RDD(Generic[T_co]): >>> sc.parallelize([1, 2, 3, 4, 5], 3).coalesce(1).glom().collect() [[1, 2, 3, 4, 5]] """ + if not numPartitions > 0: + raise ValueError("Number of partitions must be positive.") if shuffle: # Decrease the batch size in order to distribute evenly the elements across output # partitions. Otherwise, repartition will possibly produce highly skewed partitions. --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org