This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a7f191ba5947 [SPARK-49640][PS] Apply reservoir sampling in `SampledPlotBase`
a7f191ba5947 is described below

commit a7f191ba5947075066154a33da7908b24c412ccb
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Sep 18 08:44:22 2024 +0800

    [SPARK-49640][PS] Apply reservoir sampling in `SampledPlotBase`
    
    ### What changes were proposed in this pull request?
    Apply reservoir sampling in `SampledPlotBase`
    
    ### Why are the changes needed?
    The existing sampling approach has two drawbacks:
    
    1, it needs two jobs to sample `max_rows` rows:
    
    - df.count() to compute `fraction = max_rows / count`
    - df.sample(fraction).to_pandas() to do the sampling
    
    2, `df.sample` is based on Bernoulli sampling, which **cannot** guarantee that the sampled size equals the expected `max_rows`, e.g.
    ```
    In [1]: df = spark.range(10000)
    
    In [2]: [df.sample(0.01).count() for i in range(0, 10)]
    Out[2]: [96, 97, 95, 97, 105, 105, 105, 87, 95, 110]
    ```
    The size of the sampled data fluctuates around the target size 10000*0.01=100. This relative deviation cannot be ignored when the input dataset is large and the sampling fraction is small.
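
The fluctuation above follows directly from the fact that a Bernoulli sample of `n` rows with fraction `p` has a Binomial(n, p) size. A small sketch of the arithmetic (using the illustrative n=10000, p=0.01 from the example above):

```python
import math

# Bernoulli sampling keeps each of n rows independently with probability p,
# so the sampled size is Binomial(n, p): mean n*p, std dev sqrt(n*p*(1-p)).
n, p = 10_000, 0.01
mean = n * p                         # expected sample size: 100
std = math.sqrt(n * p * (1 - p))     # absolute deviation: ~9.95 rows
rel_std = std / mean                 # relative deviation: ~10% of the target
print(mean, std, rel_std)
```

The relative deviation `sqrt((1-p)/(n*p))` shrinks only as the *expected sample size* `n*p` grows, which is exactly why a small fraction over a large dataset still produces a noticeably variable sample size.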
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    CI and manual checks
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #48105 from zhengruifeng/ps_sampling.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/pandas/plot/core.py | 51 +++++++++++++++++++++++++++++++-------
 1 file changed, 42 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py
index 067c7db664de..7630ecc39895 100644
--- a/python/pyspark/pandas/plot/core.py
+++ b/python/pyspark/pandas/plot/core.py
@@ -68,19 +68,52 @@ class SampledPlotBase:
     def get_sampled(self, data):
         from pyspark.pandas import DataFrame, Series
 
+        if not isinstance(data, (DataFrame, Series)):
+            raise TypeError("Only DataFrame and Series are supported for plotting.")
+        if isinstance(data, Series):
+            data = data.to_frame()
+
         fraction = get_option("plotting.sample_ratio")
-        if fraction is None:
-            fraction = 1 / (len(data) / get_option("plotting.max_rows"))
-            fraction = min(1.0, fraction)
-        self.fraction = fraction
-
-        if isinstance(data, (DataFrame, Series)):
-            if isinstance(data, Series):
-                data = data.to_frame()
+        if fraction is not None:
+            self.fraction = fraction
             sampled = data._internal.resolved_copy.spark_frame.sample(fraction=self.fraction)
             return DataFrame(data._internal.with_new_sdf(sampled))._to_pandas()
         else:
-            raise TypeError("Only DataFrame and Series are supported for plotting.")
+            from pyspark.sql import Observation
+
+            max_rows = get_option("plotting.max_rows")
+            observation = Observation("ps plotting")
+            sdf = data._internal.resolved_copy.spark_frame.observe(
+                observation, F.count(F.lit(1)).alias("count")
+            )
+
+            rand_col_name = "__ps_plotting_sampled_plot_base_rand__"
+            id_col_name = "__ps_plotting_sampled_plot_base_id__"
+
+            sampled = (
+                sdf.select(
+                    "*",
+                    F.rand().alias(rand_col_name),
+                    F.monotonically_increasing_id().alias(id_col_name),
+                )
+                .sort(rand_col_name)
+                .limit(max_rows + 1)
+                .coalesce(1)
+                .sortWithinPartitions(id_col_name)
+                .drop(rand_col_name, id_col_name)
+            )
+
+            pdf = DataFrame(data._internal.with_new_sdf(sampled))._to_pandas()
+
+            if len(pdf) > max_rows:
+                try:
+                    self.fraction = float(max_rows) / observation.get["count"]
+                except Exception:
+                    pass
+                return pdf[:max_rows]
+            else:
+                self.fraction = 1.0
+                return pdf
 
     def set_result_text(self, ax):
         assert hasattr(self, "fraction")


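For context, the diff achieves the effect of reservoir sampling in Spark by sorting on a `rand()` column and taking `limit(max_rows + 1)`, which yields exactly `min(max_rows, n)` rows (plus one row used to detect truncation) in a single job. The classic single-pass formulation of the same idea is Vitter's Algorithm R; a minimal plain-Python sketch (conceptual only, not the Spark implementation above):

```python
import random

def reservoir_sample(iterable, k, seed=None):
    """Uniformly sample up to k items from a stream of unknown length
    in one pass (Algorithm R). Always returns exactly min(k, n) items."""
    rng = random.Random(seed)
    reservoir = []
    for i, item in enumerate(iterable):
        if i < k:
            # Fill the reservoir with the first k items.
            reservoir.append(item)
        else:
            # Item i+1 replaces a random slot with probability k / (i + 1),
            # which keeps every item equally likely to end up in the sample.
            j = rng.randrange(i + 1)
            if j < k:
                reservoir[j] = item
    return reservoir
```

Unlike Bernoulli sampling, the output size here is deterministic, which is the guarantee the commit message calls out as missing from `df.sample`.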
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
