This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8409da3d1815 [SPARK-49382][PS] Make frame box plot properly render the fliers/outliers
8409da3d1815 is described below

commit 8409da3d1815832132cd1006290679c0bed7d9f4
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Aug 26 13:12:55 2024 +0800

    [SPARK-49382][PS] Make frame box plot properly render the fliers/outliers
    
    ### What changes were proposed in this pull request?
    Fliers/outliers were ignored in the initial implementation (https://github.com/apache/spark/pull/36317).
    
    ### Why are the changes needed?
    Feature parity with pandas and with the pandas-on-Spark Series box plot.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. DataFrame box plots now render fliers/outliers. For example:
    
    ```
    import pyspark.pandas as ps
    df = ps.DataFrame([[5.1, 3.5, 0], [4.9, 3.0, 0], [7.0, 3.2, 1], [6.4, 3.2, 1], [5.9, 3.0, 2], [100, 200, 300]], columns=['length', 'width', 'species'])
    df.boxplot()
    ```
    
    For comparison, the Series box plot `df.length.plot.box()`:
    
![image](https://github.com/user-attachments/assets/43da563c-5f68-4305-ad27-a4f04815dfd1)
    
    Before this change, `df.boxplot()`:
    
![image](https://github.com/user-attachments/assets/e25c2760-c12a-4801-a730-3987a020f889)
    
    After this change, `df.boxplot()`:
    
![image](https://github.com/user-attachments/assets/c19f13b1-b9e4-423e-bcec-0c47c1c8df32)
    
    ### How was this patch tested?
    CI and manual checks.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #47866 from zhengruifeng/plot_hist_fly.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/pandas/plot/core.py                 | 32 ++++++++++++++++++++++
 python/pyspark/pandas/plot/plotly.py               | 10 ++++++-
 python/pyspark/pandas/spark/functions.py           | 13 +++++++++
 .../spark/sql/api/python/PythonSQLUtils.scala      |  3 ++
 4 files changed, 57 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/pandas/plot/core.py 
b/python/pyspark/pandas/plot/core.py
index c1dc7d2dc621..2e188b411df1 100644
--- a/python/pyspark/pandas/plot/core.py
+++ b/python/pyspark/pandas/plot/core.py
@@ -26,6 +26,7 @@ from pandas.core.dtypes.inference import is_integer
 
 from pyspark.sql import functions as F, Column
 from pyspark.sql.types import DoubleType
+from pyspark.pandas.spark import functions as SF
 from pyspark.pandas.missing import unsupported_function
 from pyspark.pandas.config import get_option
 from pyspark.pandas.utils import name_like_string
@@ -437,6 +438,37 @@ class BoxPlotBase:
 
         return fliers
 
+    @staticmethod
+    def get_multicol_fliers(colnames, multicol_outliers, multicol_whiskers):
+        scols = []
+        extract_colnames = []
+        for i, colname in enumerate(colnames):
+            formated_colname = "`{}`".format(colname)
+            outlier_colname = "__{}_outlier".format(colname)
+            min_val = multicol_whiskers[colname]["min"]
+            pair_col = F.struct(
+                F.abs(F.col(formated_colname) - F.lit(min_val)).alias("ord"),
+                F.col(formated_colname).alias("val"),
+            )
+            scols.append(
+                SF.collect_top_k(
+                    F.when(F.col(outlier_colname), pair_col)
+                    .otherwise(F.lit(None))
+                    .alias(f"pair_{i}"),
+                    1001,
+                    False,
+                ).alias(f"top_{i}")
+            )
+            extract_colnames.append(f"top_{i}.val")
+
+        results = 
multicol_outliers.select(scols).select(extract_colnames).first()
+
+        fliers = {}
+        for i, colname in enumerate(colnames):
+            fliers[colname] = results[i]
+
+        return fliers
+
 
 class KdePlotBase(NumericPlotBase):
     @staticmethod
diff --git a/python/pyspark/pandas/plot/plotly.py 
b/python/pyspark/pandas/plot/plotly.py
index 4de313b1e831..0afcd6d7e869 100644
--- a/python/pyspark/pandas/plot/plotly.py
+++ b/python/pyspark/pandas/plot/plotly.py
@@ -199,11 +199,19 @@ def plot_box(data: Union["ps.DataFrame", "ps.Series"], 
**kwargs):
         # Computes min and max values of non-outliers - the whiskers
         whiskers = BoxPlotBase.calc_multicol_whiskers(numeric_column_names, 
outliers)
 
+        fliers = None
+        if boxpoints:
+            fliers = BoxPlotBase.get_multicol_fliers(numeric_column_names, 
outliers, whiskers)
+
         i = 0
         for colname in numeric_column_names:
             col_stats = multicol_stats[colname]
             col_whiskers = whiskers[colname]
 
+            col_fliers = None
+            if fliers is not None and colname in fliers and 
len(fliers[colname]) > 0:
+                col_fliers = [fliers[colname]]
+
             fig.add_trace(
                 go.Box(
                     x=[i],
@@ -214,7 +222,7 @@ def plot_box(data: Union["ps.DataFrame", "ps.Series"], 
**kwargs):
                     mean=[col_stats["mean"]],
                     lowerfence=[col_whiskers["min"]],
                     upperfence=[col_whiskers["max"]],
-                    y=None,  # todo: support y=fliers
+                    y=col_fliers,
                     boxpoints=boxpoints,
                     notched=notched,
                     **kwargs,
diff --git a/python/pyspark/pandas/spark/functions.py 
b/python/pyspark/pandas/spark/functions.py
index 8abeff655ea5..6bef3d9b87c0 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -174,6 +174,19 @@ def null_index(col: Column) -> Column:
         return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
 
 
+def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions.builtin import 
_invoke_function_over_columns
+
+        return _invoke_function_over_columns("collect_top_k", col, F.lit(num), 
F.lit(reverse))
+
+    else:
+        from pyspark import SparkContext
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.collect_top_k(col._jc, num, 
reverse))
+
+
 def make_interval(unit: str, e: Union[Column, int, float]) -> Column:
     unit_mapping = {
         "YEAR": "years",
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index 6b497553dcb0..c1c9af2ea427 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -149,6 +149,9 @@ private[sql] object PythonSQLUtils extends Logging {
 
   def nullIndex(e: Column): Column = Column.internalFn("null_index", e)
 
+  def collect_top_k(e: Column, num: Int, reverse: Boolean): Column =
+    Column.internalFn("collect_top_k", e, lit(num), lit(reverse))
+
   def pandasProduct(e: Column, ignoreNA: Boolean): Column =
     Column.internalFn("pandas_product", e, lit(ignoreNA))
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to