This is an automated email from the ASF dual-hosted git repository.
yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1f9a12cf7511 [SPARK-50255][PYTHON] Avoid unnecessary casting in
`compute_hist`
1f9a12cf7511 is described below
commit 1f9a12cf7511807895a6c4fadc1d359f508178e5
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Nov 7 10:44:12 2024 +0800
[SPARK-50255][PYTHON] Avoid unnecessary casting in `compute_hist`
### What changes were proposed in this pull request?
Avoid unnecessary casting in `compute_hist`
### Why are the changes needed?
the `__bucket` column should be an integer by nature; it was a double only
because of the output type of `Bucketizer` in MLlib (almost all ML
implementations return doubles in their transformations).
After reimplementing it with Spark SQL, it no longer needs to be a float.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #48785 from zhengruifeng/plt_hist_cast.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Kent Yao <[email protected]>
---
python/pyspark/sql/plot/core.py | 73 ++++++++++++++++++++---------------------
1 file changed, 36 insertions(+), 37 deletions(-)
diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py
index 9e67b6bac8b5..934509f4fcd3 100644
--- a/python/pyspark/sql/plot/core.py
+++ b/python/pyspark/sql/plot/core.py
@@ -565,24 +565,23 @@ class PySparkHistogramPlotBase:
assert isinstance(bins, list)
# 1. Make the bucket output flat to:
- # +----------+-------+
- # |__group_id|buckets|
- # +----------+-------+
- # |0 |0.0 |
- # |0 |0.0 |
- # |0 |1.0 |
- # |0 |2.0 |
- # |0 |3.0 |
- # |0 |3.0 |
- # |1 |0.0 |
- # |1 |1.0 |
- # |1 |1.0 |
- # |1 |2.0 |
- # |1 |1.0 |
- # |1 |0.0 |
- # +----------+-------+
+ # +----------+--------+
+ # |__group_id|__bucket|
+ # +----------+--------+
+ # |0 |0 |
+ # |0 |0 |
+ # |0 |1 |
+ # |0 |2 |
+ # |0 |3 |
+ # |0 |3 |
+ # |1 |0 |
+ # |1 |1 |
+ # |1 |1 |
+ # |1 |2 |
+ # |1 |1 |
+ # |1 |0 |
+ # +----------+--------+
colnames = sdf.columns
- bucket_names = ["__{}_bucket".format(colname) for colname in colnames]
# determines which bucket a given value falls into, based on
predefined bin intervals
# refers to
org.apache.spark.ml.feature.Bucketizer#binarySearchForBuckets
@@ -605,22 +604,22 @@ class PySparkHistogramPlotBase:
.where(F.col("__value").isNotNull() & ~F.col("__value").isNaN())
.select(
F.col("__group_id"),
-
binary_search_for_buckets(F.col("__value")).cast("double").alias("__bucket"),
+ binary_search_for_buckets(F.col("__value")).alias("__bucket"),
)
)
# 2. Calculate the count based on each group and bucket.
- # +----------+-------+------+
- # |__group_id|buckets| count|
- # +----------+-------+------+
- # |0 |0.0 |2 |
- # |0 |1.0 |1 |
- # |0 |2.0 |1 |
- # |0 |3.0 |2 |
- # |1 |0.0 |2 |
- # |1 |1.0 |3 |
- # |1 |2.0 |1 |
- # +----------+-------+------+
+ # +----------+--------+------+
+ # |__group_id|__bucket| count|
+ # +----------+--------+------+
+ # |0 |0 |2 |
+ # |0 |1 |1 |
+ # |0 |2 |1 |
+ # |0 |3 |2 |
+ # |1 |0 |2 |
+ # |1 |1 |3 |
+ # |1 |2 |1 |
+ # +----------+--------+------+
result = (
output_df.groupby("__group_id", "__bucket")
.agg(F.count("*").alias("count"))
@@ -632,17 +631,17 @@ class PySparkHistogramPlotBase:
# +----------+--------+------+
# |__group_id|__bucket| count|
# +----------+--------+------+
- # |0 |0.0 |2 |
- # |0 |1.0 |1 |
- # |0 |2.0 |1 |
- # |0 |3.0 |2 |
+ # |0 |0 |2 |
+ # |0 |1 |1 |
+ # |0 |2 |1 |
+ # |0 |3 |2 |
# +----------+--------+------+
# +----------+--------+------+
# |__group_id|__bucket| count|
# +----------+--------+------+
- # |1 |0.0 |2 |
- # |1 |1.0 |3 |
- # |1 |2.0 |1 |
+ # |1 |0 |2 |
+ # |1 |1 |3 |
+ # |1 |2 |1 |
# +----------+--------+------+
#
# to:
@@ -665,7 +664,7 @@ class PySparkHistogramPlotBase:
# |0 |
# +-----------------+
output_series = []
- for i, (input_column_name, bucket_name) in enumerate(zip(colnames,
bucket_names)):
+ for i, input_column_name in enumerate(colnames):
current_bucket_result = result[result["__group_id"] == i]
# generates a pandas DF with one row for each bin
# we need this as some of the bins may be empty
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]