This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8409da3d1815 [SPARK-49382][PS] Make frame box plot properly render the
fliers/outliers
8409da3d1815 is described below
commit 8409da3d1815832132cd1006290679c0bed7d9f4
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Aug 26 13:12:55 2024 +0800
[SPARK-49382][PS] Make frame box plot properly render the fliers/outliers
### What changes were proposed in this pull request?
fliers/outliers were ignored in the initial implementation
https://github.com/apache/spark/pull/36317
### Why are the changes needed?
feature parity for Pandas and Series box plot
### Does this PR introduce _any_ user-facing change?
```
import pyspark.pandas as ps
df = ps.DataFrame([[5.1, 3.5, 0], [4.9, 3.0, 0], [7.0, 3.2, 1], [6.4, 3.2,
1], [5.9, 3.0, 2], [100, 200, 300]], columns=['length', 'width', 'species'])
df.boxplot()
```
`df.length.plot.box()`

before:
`df.boxplot()`

after:
`df.boxplot()`

### How was this patch tested?
CI and manually check
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #47866 from zhengruifeng/plot_hist_fly.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/pandas/plot/core.py | 32 ++++++++++++++++++++++
python/pyspark/pandas/plot/plotly.py | 10 ++++++-
python/pyspark/pandas/spark/functions.py | 13 +++++++++
.../spark/sql/api/python/PythonSQLUtils.scala | 3 ++
4 files changed, 57 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/pandas/plot/core.py
b/python/pyspark/pandas/plot/core.py
index c1dc7d2dc621..2e188b411df1 100644
--- a/python/pyspark/pandas/plot/core.py
+++ b/python/pyspark/pandas/plot/core.py
@@ -26,6 +26,7 @@ from pandas.core.dtypes.inference import is_integer
from pyspark.sql import functions as F, Column
from pyspark.sql.types import DoubleType
+from pyspark.pandas.spark import functions as SF
from pyspark.pandas.missing import unsupported_function
from pyspark.pandas.config import get_option
from pyspark.pandas.utils import name_like_string
@@ -437,6 +438,37 @@ class BoxPlotBase:
return fliers
+ @staticmethod
+ def get_multicol_fliers(colnames, multicol_outliers, multicol_whiskers):
+ scols = []
+ extract_colnames = []
+ for i, colname in enumerate(colnames):
+ formated_colname = "`{}`".format(colname)
+ outlier_colname = "__{}_outlier".format(colname)
+ min_val = multicol_whiskers[colname]["min"]
+ pair_col = F.struct(
+ F.abs(F.col(formated_colname) - F.lit(min_val)).alias("ord"),
+ F.col(formated_colname).alias("val"),
+ )
+ scols.append(
+ SF.collect_top_k(
+ F.when(F.col(outlier_colname), pair_col)
+ .otherwise(F.lit(None))
+ .alias(f"pair_{i}"),
+ 1001,
+ False,
+ ).alias(f"top_{i}")
+ )
+ extract_colnames.append(f"top_{i}.val")
+
+ results =
multicol_outliers.select(scols).select(extract_colnames).first()
+
+ fliers = {}
+ for i, colname in enumerate(colnames):
+ fliers[colname] = results[i]
+
+ return fliers
+
class KdePlotBase(NumericPlotBase):
@staticmethod
diff --git a/python/pyspark/pandas/plot/plotly.py
b/python/pyspark/pandas/plot/plotly.py
index 4de313b1e831..0afcd6d7e869 100644
--- a/python/pyspark/pandas/plot/plotly.py
+++ b/python/pyspark/pandas/plot/plotly.py
@@ -199,11 +199,19 @@ def plot_box(data: Union["ps.DataFrame", "ps.Series"],
**kwargs):
# Computes min and max values of non-outliers - the whiskers
whiskers = BoxPlotBase.calc_multicol_whiskers(numeric_column_names,
outliers)
+ fliers = None
+ if boxpoints:
+ fliers = BoxPlotBase.get_multicol_fliers(numeric_column_names,
outliers, whiskers)
+
i = 0
for colname in numeric_column_names:
col_stats = multicol_stats[colname]
col_whiskers = whiskers[colname]
+ col_fliers = None
+ if fliers is not None and colname in fliers and
len(fliers[colname]) > 0:
+ col_fliers = [fliers[colname]]
+
fig.add_trace(
go.Box(
x=[i],
@@ -214,7 +222,7 @@ def plot_box(data: Union["ps.DataFrame", "ps.Series"],
**kwargs):
mean=[col_stats["mean"]],
lowerfence=[col_whiskers["min"]],
upperfence=[col_whiskers["max"]],
- y=None, # todo: support y=fliers
+ y=col_fliers,
boxpoints=boxpoints,
notched=notched,
**kwargs,
diff --git a/python/pyspark/pandas/spark/functions.py
b/python/pyspark/pandas/spark/functions.py
index 8abeff655ea5..6bef3d9b87c0 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -174,6 +174,19 @@ def null_index(col: Column) -> Column:
return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
+def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
+ if is_remote():
+ from pyspark.sql.connect.functions.builtin import
_invoke_function_over_columns
+
+ return _invoke_function_over_columns("collect_top_k", col, F.lit(num),
F.lit(reverse))
+
+ else:
+ from pyspark import SparkContext
+
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.PythonSQLUtils.collect_top_k(col._jc, num,
reverse))
+
+
def make_interval(unit: str, e: Union[Column, int, float]) -> Column:
unit_mapping = {
"YEAR": "years",
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index 6b497553dcb0..c1c9af2ea427 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -149,6 +149,9 @@ private[sql] object PythonSQLUtils extends Logging {
def nullIndex(e: Column): Column = Column.internalFn("null_index", e)
+ def collect_top_k(e: Column, num: Int, reverse: Boolean): Column =
+ Column.internalFn("collect_top_k", e, lit(num), lit(reverse))
+
def pandasProduct(e: Column, ignoreNA: Boolean): Column =
Column.internalFn("pandas_product", e, lit(ignoreNA))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]