This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0aa854129d37 [SPARK-50280][PYTHON] Refactor result sorting and empty bin filling in `compute_hist`
0aa854129d37 is described below
commit 0aa854129d37f672230a8a7a80d63dcc733b7c51
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Nov 11 08:31:24 2024 -0800
[SPARK-50280][PYTHON] Refactor result sorting and empty bin filling in `compute_hist`
### What changes were proposed in this pull request?
Refactor result sorting and empty bin filling in `compute_hist`
### Why are the changes needed?
To simplify the computation happening on the driver (before `toPandas`)
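As a rough illustration (a minimal standalone sketch, not the patched `compute_hist` itself; the `counts` rows, the `num_groups`/`num_bins` values and the `filled` name are invented for this example), the empty bins are now filled in Spark by left-joining the per-bucket counts against a small broadcast DataFrame of every `(__group_id, __bucket)` pair, so the driver no longer needs a per-column pandas merge:

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()

# Toy per-(group, bucket) counts; in compute_hist these come from the bucketized input.
counts = spark.createDataFrame(
    [(0, 0, 2), (0, 1, 3), (1, 0, 2), (1, 1, 3), (1, 2, 1)],
    ["__group_id", "__bucket", "count"],
)
num_groups, num_bins = 2, 4  # e.g. two plotted columns, len(bins) - 1 buckets

# Every possible (group id, bucket) pair; tiny, so broadcast it for the join.
bin_df = (
    spark.range(num_groups)
    .select(
        F.col("id").alias("__group_id"),
        F.explode(F.lit(list(range(num_bins)))).alias("__bucket"),
    )
    .hint("broadcast")
)

# Left join + nvl turns missing buckets into 0, and the sort happens in Spark,
# so toPandas() already returns a sorted, zero-filled frame.
# F.nvl needs Spark 3.5+; F.coalesce(F.col("count"), F.lit(0)) is an equivalent fallback.
filled = (
    bin_df.join(counts, ["__group_id", "__bucket"], "left")
    .select("__group_id", "__bucket", F.nvl(F.col("count"), F.lit(0)).alias("count"))
    .coalesce(1)
    .sortWithinPartitions("__group_id", "__bucket")
)
filled.show()
```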
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
existing tests
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #48814 from zhengruifeng/plt_hist.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
python/pyspark/sql/plot/core.py | 44 ++++++++++++++++++++++++++---------------
1 file changed, 28 insertions(+), 16 deletions(-)
diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py
index 934509f4fcd3..f7133bdb70ed 100644
--- a/python/pyspark/sql/plot/core.py
+++ b/python/pyspark/sql/plot/core.py
@@ -560,10 +560,12 @@ class PySparkHistogramPlotBase:
@staticmethod
def compute_hist(sdf: "DataFrame", bins: Sequence[float]) -> List["pd.Series"]:
require_minimum_pandas_version()
- import pandas as pd
assert isinstance(bins, list)
+ spark = sdf._session
+ assert spark is not None
+
# 1. Make the bucket output flat to:
# +----------+--------+
# |__group_id|__bucket|
@@ -608,7 +610,7 @@ class PySparkHistogramPlotBase:
)
)
- # 2. Calculate the count based on each group and bucket.
+ # 2. Calculate the count based on each group and bucket, also fill empty bins.
# +----------+--------+------+
# |__group_id|__bucket| count|
# +----------+--------+------+
@@ -619,15 +621,29 @@ class PySparkHistogramPlotBase:
# |1 |0 |2 |
# |1 |1 |3 |
# |1 |2 |1 |
+ # |1 |3 |0 | <- fill empty bins with zeros (by joining with bin_df)
# +----------+--------+------+
- result = (
- output_df.groupby("__group_id", "__bucket")
- .agg(F.count("*").alias("count"))
- .toPandas()
- .sort_values(by=["__group_id", "__bucket"])
+ output_df = output_df.groupby("__group_id", "__bucket").agg(F.count("*").alias("count"))
+
+ # Generate all possible combinations of group id and bucket
+ bin_df = (
+ spark.range(len(colnames))
+ .select(
+ F.col("id").alias("__group_id"),
+ F.explode(F.lit(list(range(len(bins) - 1)))).alias("__bucket"),
+ )
+ .hint("broadcast")
)
- # 3. Fill empty bins and calculate based on each group id. From:
+ output_df = (
+ bin_df.join(output_df, ["__group_id", "__bucket"], "left")
+ .select("__group_id", "__bucket", F.nvl(F.col("count"), F.lit(0)).alias("count"))
+ .coalesce(1)
+ .sortWithinPartitions("__group_id", "__bucket")
+ .select("__group_id", "count")
+ )
+
+ # 3. Calculate based on each group id. From:
# +----------+--------+------+
# |__group_id|__bucket| count|
# +----------+--------+------+
@@ -642,6 +658,7 @@ class PySparkHistogramPlotBase:
# |1 |0 |2 |
# |1 |1 |3 |
# |1 |2 |1 |
+ # |1 |3 |0 |
# +----------+--------+------+
#
# to:
@@ -663,16 +680,11 @@ class PySparkHistogramPlotBase:
# |0 |
# |0 |
# +-----------------+
+ result = output_df.toPandas()
output_series = []
for i, input_column_name in enumerate(colnames):
- current_bucket_result = result[result["__group_id"] == i]
- # generates a pandas DF with one row for each bin
- # we need this as some of the bins may be empty
- indexes = pd.DataFrame({"__bucket": list(range(0, len(bins) - 1))})
- # merges the bins with counts on it and fills remaining ones with zeros
- pdf = indexes.merge(current_bucket_result, how="left", on=["__bucket"]).fillna(0)[
-     ["count"]
- ]
+ pdf = result[result["__group_id"] == i]
+ pdf = pdf[["count"]]
pdf.columns = [input_column_name]
output_series.append(pdf[input_column_name])
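For readers skimming the diff, a hedged pandas-only sketch of the driver side after this change (the `result` contents and `colnames` below are invented for illustration): because every `(__group_id, __bucket)` pair is already present and sorted before `toPandas()`, each column's series is now a plain slice instead of the old per-column `merge(...).fillna(0)`:

```python
import pandas as pd

# Shape of `result` after toPandas(): one row per (group, bucket), zeros already filled.
result = pd.DataFrame({"__group_id": [0, 0, 0, 1, 1, 1], "count": [2, 3, 0, 2, 1, 0]})
colnames = ["age", "salary"]  # hypothetical plotted columns

output_series = []
for i, input_column_name in enumerate(colnames):
    pdf = result[result["__group_id"] == i][["count"]]
    pdf.columns = [input_column_name]
    output_series.append(pdf[input_column_name])

print(output_series[0].tolist())  # [2, 3, 0]
```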
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]