This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4b7191d43b4b [SPARK-49367][PS] Parallelize the KDE computation for 
multiple columns (plotly backend)
4b7191d43b4b is described below

commit 4b7191d43b4b505faa1e26481311c5e83e6340e5
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Aug 26 09:11:16 2024 +0800

    [SPARK-49367][PS] Parallelize the KDE computation for multiple columns 
(plotly backend)
    
    ### What changes were proposed in this pull request?
    Parallelize the KDE computation for `plotly` backend.
    
    Note that the `matplotlib` backend is not optimized in this PR, because 
the computation logic differs slightly between `plotly` and `matplotlib`:
    1. `plotly`: compute a single global `ind` across all input columns, and 
then compute every curve based on it;
    2. `matplotlib`: for each input column, compute its own `ind` and then its 
curve.
    
    I think `matplotlib`'s approach is more reasonable, but it prevents this 
optimization from being applied directly to `matplotlib`, so that case needs 
more investigation.
    
    ### Why are the changes needed?
    The existing implementation computes the curve for each column separately, 
triggering one Spark job per column; this PR computes the curves for all 
columns together in a single job.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    CI and manual testing
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #47854 from zhengruifeng/plot_parallelize_kde.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/pandas/plot/core.py   | 32 ++++++++++++++++++--------------
 python/pyspark/pandas/plot/plotly.py | 26 ++++++++++++++++----------
 2 files changed, 34 insertions(+), 24 deletions(-)

diff --git a/python/pyspark/pandas/plot/core.py 
b/python/pyspark/pandas/plot/core.py
index e5db0bd701f1..c1dc7d2dc621 100644
--- a/python/pyspark/pandas/plot/core.py
+++ b/python/pyspark/pandas/plot/core.py
@@ -474,7 +474,7 @@ class KdePlotBase(NumericPlotBase):
         return ind
 
     @staticmethod
-    def compute_kde(sdf, bw_method=None, ind=None):
+    def compute_kde_col(input_col, bw_method=None, ind=None):
         # refers to org.apache.spark.mllib.stat.KernelDensity
         assert bw_method is not None and isinstance(
             bw_method, (int, float)
@@ -497,21 +497,25 @@ class KdePlotBase(NumericPlotBase):
             log_density = -0.5 * x1 * x1 - log_std_plus_half_log2_pi
             return F.exp(log_density)
 
-        dataCol = F.col(sdf.columns[0]).cast("double")
-
-        estimated = [
-            F.avg(
-                norm_pdf(
-                    dataCol,
-                    F.lit(bandwidth),
-                    F.lit(log_std_plus_half_log2_pi),
-                    F.lit(point),
+        return F.array(
+            [
+                F.avg(
+                    norm_pdf(
+                        input_col.cast("double"),
+                        F.lit(bandwidth),
+                        F.lit(log_std_plus_half_log2_pi),
+                        F.lit(point),
+                    )
                 )
-            )
-            for point in points
-        ]
+                for point in points
+            ]
+        )
 
-        row = sdf.select(F.array(estimated)).first()
+    @staticmethod
+    def compute_kde(sdf, bw_method=None, ind=None):
+        input_col = F.col(sdf.columns[0])
+        kde_col = KdePlotBase.compute_kde_col(input_col, bw_method, 
ind).alias("kde")
+        row = sdf.select(kde_col).first()
         return row[0]
 
 
diff --git a/python/pyspark/pandas/plot/plotly.py 
b/python/pyspark/pandas/plot/plotly.py
index d54166a33a0a..4de313b1e831 100644
--- a/python/pyspark/pandas/plot/plotly.py
+++ b/python/pyspark/pandas/plot/plotly.py
@@ -239,22 +239,28 @@ def plot_kde(data: Union["ps.DataFrame", "ps.Series"], 
**kwargs):
     ind = KdePlotBase.get_ind(sdf.select(*data_columns), kwargs.pop("ind", 
None))
     bw_method = kwargs.pop("bw_method", None)
 
-    pdfs = []
-    for label in psdf._internal.column_labels:
-        pdfs.append(
+    kde_cols = [
+        KdePlotBase.compute_kde_col(
+            input_col=psdf._internal.spark_column_for(label),
+            ind=ind,
+            bw_method=bw_method,
+        ).alias(f"kde_{i}")
+        for i, label in enumerate(psdf._internal.column_labels)
+    ]
+    kde_results = sdf.select(*kde_cols).first()
+
+    pdf = pd.concat(
+        [
             pd.DataFrame(
                 {
-                    "Density": KdePlotBase.compute_kde(
-                        sdf.select(psdf._internal.spark_column_for(label)),
-                        ind=ind,
-                        bw_method=bw_method,
-                    ),
+                    "Density": kde_result,
                     "names": name_like_string(label),
                     "index": ind,
                 }
             )
-        )
-    pdf = pd.concat(pdfs)
+            for label, kde_result in zip(psdf._internal.column_labels, 
list(kde_results))
+        ]
+    )
 
     fig = express.line(pdf, x="index", y="Density", **kwargs)
     fig["layout"]["xaxis"]["title"] = None


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to