This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1f9a12cf7511 [SPARK-50255][PYTHON] Avoid unnecessary casting in 
`compute_hist`
1f9a12cf7511 is described below

commit 1f9a12cf7511807895a6c4fadc1d359f508178e5
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Nov 7 10:44:12 2024 +0800

    [SPARK-50255][PYTHON] Avoid unnecessary casting in `compute_hist`
    
    ### What changes were proposed in this pull request?
    Avoid unnecessary casting in `compute_hist`
    
    ### Why are the changes needed?
    the `__bucket` should be an integer by its nature; it was double just because 
of the output type of `Bucketizer` in MLlib (almost all ML implementations 
return double in transformation).
    After reimplementing it with Spark SQL, it no longer needs to be float.
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #48785 from zhengruifeng/plt_hist_cast.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Kent Yao <[email protected]>
---
 python/pyspark/sql/plot/core.py | 73 ++++++++++++++++++++---------------------
 1 file changed, 36 insertions(+), 37 deletions(-)

diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py
index 9e67b6bac8b5..934509f4fcd3 100644
--- a/python/pyspark/sql/plot/core.py
+++ b/python/pyspark/sql/plot/core.py
@@ -565,24 +565,23 @@ class PySparkHistogramPlotBase:
         assert isinstance(bins, list)
 
         # 1. Make the bucket output flat to:
-        #     +----------+-------+
-        #     |__group_id|buckets|
-        #     +----------+-------+
-        #     |0         |0.0    |
-        #     |0         |0.0    |
-        #     |0         |1.0    |
-        #     |0         |2.0    |
-        #     |0         |3.0    |
-        #     |0         |3.0    |
-        #     |1         |0.0    |
-        #     |1         |1.0    |
-        #     |1         |1.0    |
-        #     |1         |2.0    |
-        #     |1         |1.0    |
-        #     |1         |0.0    |
-        #     +----------+-------+
+        #     +----------+--------+
+        #     |__group_id|__bucket|
+        #     +----------+--------+
+        #     |0         |0       |
+        #     |0         |0       |
+        #     |0         |1       |
+        #     |0         |2       |
+        #     |0         |3       |
+        #     |0         |3       |
+        #     |1         |0       |
+        #     |1         |1       |
+        #     |1         |1       |
+        #     |1         |2       |
+        #     |1         |1       |
+        #     |1         |0       |
+        #     +----------+--------+
         colnames = sdf.columns
-        bucket_names = ["__{}_bucket".format(colname) for colname in colnames]
 
         # determines which bucket a given value falls into, based on 
predefined bin intervals
         # refers to 
org.apache.spark.ml.feature.Bucketizer#binarySearchForBuckets
@@ -605,22 +604,22 @@ class PySparkHistogramPlotBase:
             .where(F.col("__value").isNotNull() & ~F.col("__value").isNaN())
             .select(
                 F.col("__group_id"),
-                
binary_search_for_buckets(F.col("__value")).cast("double").alias("__bucket"),
+                binary_search_for_buckets(F.col("__value")).alias("__bucket"),
             )
         )
 
         # 2. Calculate the count based on each group and bucket.
-        #     +----------+-------+------+
-        #     |__group_id|buckets| count|
-        #     +----------+-------+------+
-        #     |0         |0.0    |2     |
-        #     |0         |1.0    |1     |
-        #     |0         |2.0    |1     |
-        #     |0         |3.0    |2     |
-        #     |1         |0.0    |2     |
-        #     |1         |1.0    |3     |
-        #     |1         |2.0    |1     |
-        #     +----------+-------+------+
+        #     +----------+--------+------+
+        #     |__group_id|__bucket| count|
+        #     +----------+--------+------+
+        #     |0         |0       |2     |
+        #     |0         |1       |1     |
+        #     |0         |2       |1     |
+        #     |0         |3       |2     |
+        #     |1         |0       |2     |
+        #     |1         |1       |3     |
+        #     |1         |2       |1     |
+        #     +----------+--------+------+
         result = (
             output_df.groupby("__group_id", "__bucket")
             .agg(F.count("*").alias("count"))
@@ -632,17 +631,17 @@ class PySparkHistogramPlotBase:
         #     +----------+--------+------+
         #     |__group_id|__bucket| count|
         #     +----------+--------+------+
-        #     |0         |0.0     |2     |
-        #     |0         |1.0     |1     |
-        #     |0         |2.0     |1     |
-        #     |0         |3.0     |2     |
+        #     |0         |0       |2     |
+        #     |0         |1       |1     |
+        #     |0         |2       |1     |
+        #     |0         |3       |2     |
         #     +----------+--------+------+
         #     +----------+--------+------+
         #     |__group_id|__bucket| count|
         #     +----------+--------+------+
-        #     |1         |0.0     |2     |
-        #     |1         |1.0     |3     |
-        #     |1         |2.0     |1     |
+        #     |1         |0       |2     |
+        #     |1         |1       |3     |
+        #     |1         |2       |1     |
         #     +----------+--------+------+
         #
         # to:
@@ -665,7 +664,7 @@ class PySparkHistogramPlotBase:
         #     |0                |
         #     +-----------------+
         output_series = []
-        for i, (input_column_name, bucket_name) in enumerate(zip(colnames, 
bucket_names)):
+        for i, input_column_name in enumerate(colnames):
             current_bucket_result = result[result["__group_id"] == i]
             # generates a pandas DF with one row for each bin
             # we need this as some of the bins may be empty


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to