This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0aa854129d37 [SPARK-50280][PYTHON] Refactor result sorting and empty bin filling in `compute_hist`
0aa854129d37 is described below

commit 0aa854129d37f672230a8a7a80d63dcc733b7c51
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Nov 11 08:31:24 2024 -0800

    [SPARK-50280][PYTHON] Refactor result sorting and empty bin filling in `compute_hist`
    
    ### What changes were proposed in this pull request?
    Refactor result sorting and empty bin filling in `compute_hist`
    
    ### Why are the changes needed?
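    To simplify the computation that happens on the driver: result sorting and empty-bin filling are now done in Spark (before `toPandas`) instead of in pandas afterwards.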
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    existing tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #48814 from zhengruifeng/plt_hist.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 python/pyspark/sql/plot/core.py | 44 ++++++++++++++++++++++++++---------------
 1 file changed, 28 insertions(+), 16 deletions(-)

diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py
index 934509f4fcd3..f7133bdb70ed 100644
--- a/python/pyspark/sql/plot/core.py
+++ b/python/pyspark/sql/plot/core.py
@@ -560,10 +560,12 @@ class PySparkHistogramPlotBase:
     @staticmethod
     def compute_hist(sdf: "DataFrame", bins: Sequence[float]) -> 
List["pd.Series"]:
         require_minimum_pandas_version()
-        import pandas as pd
 
         assert isinstance(bins, list)
 
+        spark = sdf._session
+        assert spark is not None
+
         # 1. Make the bucket output flat to:
         #     +----------+--------+
         #     |__group_id|__bucket|
@@ -608,7 +610,7 @@ class PySparkHistogramPlotBase:
             )
         )
 
-        # 2. Calculate the count based on each group and bucket.
+        # 2. Calculate the count based on each group and bucket, also fill 
empty bins.
         #     +----------+--------+------+
         #     |__group_id|__bucket| count|
         #     +----------+--------+------+
@@ -619,15 +621,29 @@ class PySparkHistogramPlotBase:
         #     |1         |0       |2     |
         #     |1         |1       |3     |
         #     |1         |2       |1     |
+        #     |1         |3       |0     | <- fill empty bins with zeros (by 
joining with bin_df)
         #     +----------+--------+------+
-        result = (
-            output_df.groupby("__group_id", "__bucket")
-            .agg(F.count("*").alias("count"))
-            .toPandas()
-            .sort_values(by=["__group_id", "__bucket"])
+        output_df = output_df.groupby("__group_id", 
"__bucket").agg(F.count("*").alias("count"))
+
+        # Generate all possible combinations of group id and bucket
+        bin_df = (
+            spark.range(len(colnames))
+            .select(
+                F.col("id").alias("__group_id"),
+                F.explode(F.lit(list(range(len(bins) - 1)))).alias("__bucket"),
+            )
+            .hint("broadcast")
         )
 
-        # 3. Fill empty bins and calculate based on each group id. From:
+        output_df = (
+            bin_df.join(output_df, ["__group_id", "__bucket"], "left")
+            .select("__group_id", "__bucket", F.nvl(F.col("count"), 
F.lit(0)).alias("count"))
+            .coalesce(1)
+            .sortWithinPartitions("__group_id", "__bucket")
+            .select("__group_id", "count")
+        )
+
+        # 3. Calculate based on each group id. From:
         #     +----------+--------+------+
         #     |__group_id|__bucket| count|
         #     +----------+--------+------+
@@ -642,6 +658,7 @@ class PySparkHistogramPlotBase:
         #     |1         |0       |2     |
         #     |1         |1       |3     |
         #     |1         |2       |1     |
+        #     |1         |3       |0     |
         #     +----------+--------+------+
         #
         # to:
@@ -663,16 +680,11 @@ class PySparkHistogramPlotBase:
         #     |0                |
         #     |0                |
         #     +-----------------+
+        result = output_df.toPandas()
         output_series = []
         for i, input_column_name in enumerate(colnames):
-            current_bucket_result = result[result["__group_id"] == i]
-            # generates a pandas DF with one row for each bin
-            # we need this as some of the bins may be empty
-            indexes = pd.DataFrame({"__bucket": list(range(0, len(bins) - 1))})
-            # merges the bins with counts on it and fills remaining ones with 
zeros
-            pdf = indexes.merge(current_bucket_result, how="left", 
on=["__bucket"]).fillna(0)[
-                ["count"]
-            ]
+            pdf = result[result["__group_id"] == i]
+            pdf = pdf[["count"]]
             pdf.columns = [input_column_name]
             output_series.append(pdf[input_column_name])
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to