This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4fb4648ce3d [SPARK-38879][PYTHON][TEST] Improve the test coverage for
pyspark/rddsampler.py
4fb4648ce3d is described below
commit 4fb4648ce3d7fab65ccfceb86cb6c839d0c921da
Author: Kumar, Pralabh <[email protected]>
AuthorDate: Wed Apr 27 10:11:16 2022 +0900
[SPARK-38879][PYTHON][TEST] Improve the test coverage for
pyspark/rddsampler.py
### What changes were proposed in this pull request?
This PR adds test cases for rddsampler.
### Why are the changes needed?
To cover corner cases and increase test coverage.
### Does this PR introduce _any_ user-facing change?
No - test only
### How was this patch tested?
CI in this PR should test it out
Closes #36342 from pralabhkumar/rk_rdd_sampler.
Lead-authored-by: Kumar, Pralabh <[email protected]>
Co-authored-by: pralabhkumar <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
dev/sparktestsupport/modules.py | 1 +
python/pyspark/tests/test_rddsampler.py | 66 +++++++++++++++++++++++++++++++++
2 files changed, 67 insertions(+)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 5514df11f9a..ed1eeb9b807 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -397,6 +397,7 @@ pyspark_core = Module(
"pyspark.tests.test_profiler",
"pyspark.tests.test_rdd",
"pyspark.tests.test_rddbarrier",
+ "pyspark.tests.test_rddsampler",
"pyspark.tests.test_readwrite",
"pyspark.tests.test_serializers",
"pyspark.tests.test_shuffle",
diff --git a/python/pyspark/tests/test_rddsampler.py b/python/pyspark/tests/test_rddsampler.py
new file mode 100644
index 00000000000..b504c4ab980
--- /dev/null
+++ b/python/pyspark/tests/test_rddsampler.py
@@ -0,0 +1,66 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pyspark.testing.utils import ReusedPySparkTestCase
+from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler
+
+
+class RDDSamplerTests(ReusedPySparkTestCase):
+    def test_rdd_sampler_func(self):
+        # SPARK-38879: cover RDDSampler.func both without and with replacement,
+        # using a fixed seed (10) so that the sampled counts are deterministic.
+        rdd = self.sc.parallelize(range(20), 2)
+        sample_count = rdd.mapPartitionsWithIndex(RDDSampler(False, 0.4, 10).func).count()
+        # Without replacement at fraction 0.4, about 8 of the 20 elements
+        # (20 * 0.4) are expected to be kept.
+        self.assertGreater(sample_count, 3)
+        self.assertLess(sample_count, 10)
+        sample_data = rdd.mapPartitionsWithIndex(RDDSampler(True, 1, 10).func).collect()
+        # With replacement, at least one element should be drawn more than once,
+        # so the sample holds fewer distinct values than total values.
+        self.assertLess(len(set(sample_data)), len(sample_data))
+
+    def test_rdd_stratified_sampler_func(self):
+        # SPARK-38879: cover RDDStratifiedSampler.func, which samples each key
+        # with its own fraction taken from the `fractions` dict.
+
+        fractions = {"a": 0.8, "b": 0.2}
+        rdd = self.sc.parallelize(fractions.keys()).cartesian(self.sc.parallelize(range(100)))
+        sample_data = dict(
+            rdd.mapPartitionsWithIndex(
+                RDDStratifiedSampler(False, fractions, 10).func, True
+            ).countByKey()
+        )
+        # Each key is paired with 100 values; "a" has the higher sampling rate
+        # (0.8), so it should be sampled more often than "b" (0.2).
+        self.assertGreater(sample_data["a"], sample_data["b"])
+        self.assertGreater(sample_data["a"], 60)
+        self.assertLess(sample_data["a"], 90)
+        self.assertGreater(sample_data["b"], 15)
+        self.assertLess(sample_data["b"], 30)
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.tests.test_rddsampler import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None  # xmlrunner unavailable: use unittest's default runner
+    unittest.main(testRunner=testRunner, verbosity=2)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]