This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9b165793d53 [SPARK-38993][PYTHON] Impl DataFrame.boxplot and DataFrame.plot.box
9b165793d53 is described below

commit 9b165793d53fd6190173c54383ec3373222231cf
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Fri Apr 29 13:06:54 2022 +0900

    [SPARK-38993][PYTHON] Impl DataFrame.boxplot and DataFrame.plot.box
    
    ### What changes were proposed in this pull request?
    Implement DataFrame.boxplot and DataFrame.plot.box.
    
    ### Why are the changes needed?
    To increase pandas API coverage in PySpark.
    
    ### Does this PR introduce _any_ user-facing change?
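    Yes. `DataFrame.boxplot` and `DataFrame.plot.box` are now supported and draw one box per numeric column; previously both raised an unsupported-function error.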
    
    ```
    In [2]: df = ps.DataFrame([[5.1, 3.5, 0], [4.9, 3.0, 0], [7.0, 3.2, 1], 
[6.4, 3.2, 1], [5.9, 3.0, 2]], columns=['length', 'width', 'species'])
    
    In [3]: df.boxplot()
    Out[3]:
    In [4]: df.plot.box()
    Out[4]:
    ```
    
    
![image](https://user-images.githubusercontent.com/7322292/164674307-d7622e22-bbfb-45d0-9fd8-318a1a11258f.png)
    
    ### How was this patch tested?
    Added unit tests and tested manually.
    
    Closes #36317 from zhengruifeng/impl_box_plot.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../docs/source/reference/pyspark.pandas/frame.rst |   2 +
 .../pandas_on_spark/supported_pandas_api.rst       |   4 +-
 python/pyspark/pandas/frame.py                     |   6 ++
 python/pyspark/pandas/missing/frame.py             |   1 -
 python/pyspark/pandas/plot/core.py                 |  85 +++++++++++++++-
 python/pyspark/pandas/plot/plotly.py               | 110 ++++++++++++++-------
 .../pyspark/pandas/tests/plot/test_frame_plot.py   |  46 ++++++++-
 7 files changed, 210 insertions(+), 44 deletions(-)

diff --git a/python/docs/source/reference/pyspark.pandas/frame.rst 
b/python/docs/source/reference/pyspark.pandas/frame.rst
index 05c215110c6..75a8941ad78 100644
--- a/python/docs/source/reference/pyspark.pandas/frame.rst
+++ b/python/docs/source/reference/pyspark.pandas/frame.rst
@@ -323,11 +323,13 @@ specific plotting methods of the form 
``DataFrame.plot.<kind>``.
    DataFrame.plot.barh
    DataFrame.plot.bar
    DataFrame.plot.hist
+   DataFrame.plot.box
    DataFrame.plot.line
    DataFrame.plot.pie
    DataFrame.plot.scatter
    DataFrame.plot.density
    DataFrame.hist
+   DataFrame.boxplot
    DataFrame.kde
 
 Pandas-on-Spark specific
diff --git 
a/python/docs/source/user_guide/pandas_on_spark/supported_pandas_api.rst 
b/python/docs/source/user_guide/pandas_on_spark/supported_pandas_api.rst
index 450742a20f7..d2ac0b78861 100644
--- a/python/docs/source/user_guide/pandas_on_spark/supported_pandas_api.rst
+++ b/python/docs/source/user_guide/pandas_on_spark/supported_pandas_api.rst
@@ -103,7 +103,7 @@ Supported DataFrame APIs
 
+--------------------------------------------+-------------+--------------------------------------+
 | :func:`bool`                               | Y           |                   
                   |
 
+--------------------------------------------+-------------+--------------------------------------+
-| boxplot                                    | N           |                   
                   |
+| boxplot                                    | Y           |                   
                   |
 
+--------------------------------------------+-------------+--------------------------------------+
 | :func:`clip`                               | P           | ``axis``, 
``inplace``                |
 
+--------------------------------------------+-------------+--------------------------------------+
@@ -315,7 +315,7 @@ Supported DataFrame APIs
 
+--------------------------------------------+-------------+--------------------------------------+
 | :func:`plot.barh`                          | Y           |                   
                   |
 
+--------------------------------------------+-------------+--------------------------------------+
-| :func:`plot.box`                           | N           |                   
                   |
+| :func:`plot.box`                           | Y           |                   
                   |
 
+--------------------------------------------+-------------+--------------------------------------+
 | :func:`plot.density`                       | Y           |                   
                   |
 
+--------------------------------------------+-------------+--------------------------------------+
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 9880e2a18d8..4ec0c9e0605 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -841,6 +841,12 @@ class DataFrame(Frame, Generic[T]):
 
     hist.__doc__ = PandasOnSparkPlotAccessor.hist.__doc__
 
+    @no_type_check
+    def boxplot(self, **kwds):
+        return self.plot.box(**kwds)
+
+    boxplot.__doc__ = PandasOnSparkPlotAccessor.box.__doc__
+
     @no_type_check
     def kde(self, bw_method=None, ind=None, **kwds):
         return self.plot.kde(bw_method, ind, **kwds)
diff --git a/python/pyspark/pandas/missing/frame.py 
b/python/pyspark/pandas/missing/frame.py
index ba2d01c5225..cd5e447cf0b 100644
--- a/python/pyspark/pandas/missing/frame.py
+++ b/python/pyspark/pandas/missing/frame.py
@@ -36,7 +36,6 @@ class _MissingPandasLikeDataFrame:
     # Functions
     asfreq = _unsupported_function("asfreq")
     asof = _unsupported_function("asof")
-    boxplot = _unsupported_function("boxplot")
     combine = _unsupported_function("combine")
     compare = _unsupported_function("compare")
     convert_dtypes = _unsupported_function("convert_dtypes")
diff --git a/python/pyspark/pandas/plot/core.py 
b/python/pyspark/pandas/plot/core.py
index 8ee959db481..57f62e22e51 100644
--- a/python/pyspark/pandas/plot/core.py
+++ b/python/pyspark/pandas/plot/core.py
@@ -272,6 +272,45 @@ class HistogramPlotBase(NumericPlotBase):
 
 
 class BoxPlotBase:
+    @staticmethod
+    def compute_multicol_stats(data, colnames, whis, precision):
+        # Computes mean, median, Q1 and Q3 with approx_percentile and precision
+        scol = []
+        for colname in colnames:
+            scol.append(
+                F.percentile_approx(
+                    "`%s`" % colname, [0.25, 0.50, 0.75], int(1.0 / precision)
+                ).alias("{}_percentiles%".format(colname))
+            )
+            scol.append(F.mean("`%s`" % 
colname).alias("{}_mean".format(colname)))
+
+        #      a_percentiles  a_mean    b_percentiles  b_mean
+        # 0  [3.0, 3.2, 3.2]    3.18  [5.1, 5.9, 6.4]    5.86
+        pdf = data._internal.resolved_copy.spark_frame.select(*scol).toPandas()
+
+        i = 0
+        multicol_stats = {}
+        for colname in colnames:
+            q1, med, q3 = pdf.iloc[0, i]
+            iqr = q3 - q1
+            lfence = q1 - whis * iqr
+            ufence = q3 + whis * iqr
+            i += 1
+
+            mean = pdf.iloc[0, i]
+            i += 1
+
+            multicol_stats[colname] = {
+                "mean": mean,
+                "med": med,
+                "q1": q1,
+                "q3": q3,
+                "lfence": lfence,
+                "ufence": ufence,
+            }
+
+        return multicol_stats
+
     @staticmethod
     def compute_stats(data, colname, whis, precision):
         # Computes mean, median, Q1 and Q3 with approx_percentile and precision
@@ -307,6 +346,15 @@ class BoxPlotBase:
 
         return stats, (lfence.values[0], ufence.values[0])
 
+    @staticmethod
+    def multicol_outliers(data, multicol_stats):
+        scols = {}
+        for colname, stats in multicol_stats.items():
+            scols["__{}_outlier".format(colname)] = ~F.col("`%s`" % 
colname).between(
+                stats["lfence"], stats["ufence"]
+            )
+        return data._internal.resolved_copy.spark_frame.withColumns(scols)
+
     @staticmethod
     def outliers(data, colname, lfence, ufence):
         # Builds expression to identify outliers
@@ -316,6 +364,39 @@ class BoxPlotBase:
             "__{}_outlier".format(colname), ~expression
         )
 
+    @staticmethod
+    def calc_multicol_whiskers(colnames, multicol_outliers):
+        # Computes min and max values of non-outliers - the whiskers
+        scols = []
+        for colname in colnames:
+            outlier_colname = "__{}_outlier".format(colname)
+            scols.append(
+                F.min(
+                    F.when(~F.col(outlier_colname), 
F.col(colname)).otherwise(SF.lit(None))
+                ).alias("__{}_min".format(colname))
+            )
+            scols.append(
+                F.max(
+                    F.when(~F.col(outlier_colname), 
F.col(colname)).otherwise(SF.lit(None))
+                ).alias("__{}_max".format(colname))
+            )
+
+        pdf = multicol_outliers.select(*scols).toPandas()
+
+        i = 0
+        whiskers = {}
+        for colname in colnames:
+            min = pdf.iloc[0, i]
+            i += 1
+            max = pdf.iloc[0, i]
+            i += 1
+            whiskers[colname] = {
+                "min": min,
+                "max": max,
+            }
+
+        return whiskers
+
     @staticmethod
     def calc_whiskers(colname, outliers):
         # Computes min and max values of non-outliers - the whiskers
@@ -815,10 +896,8 @@ class PandasOnSparkPlotAccessor(PandasObject):
         """
         from pyspark.pandas import DataFrame, Series
 
-        if isinstance(self.data, Series):
+        if isinstance(self.data, (Series, DataFrame)):
             return self(kind="box", **kwds)
-        elif isinstance(self.data, DataFrame):
-            return unsupported_function(class_name="pd.DataFrame", 
method_name="box")()
 
     def hist(self, bins=10, **kwds):
         """
diff --git a/python/pyspark/pandas/plot/plotly.py 
b/python/pyspark/pandas/plot/plotly.py
index ebf23416344..d54166a33a0 100644
--- a/python/pyspark/pandas/plot/plotly.py
+++ b/python/pyspark/pandas/plot/plotly.py
@@ -123,11 +123,7 @@ def plot_histogram(data: Union["ps.DataFrame", 
"ps.Series"], **kwargs):
 def plot_box(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
     import plotly.graph_objs as go
     import pyspark.pandas as ps
-
-    if isinstance(data, ps.DataFrame):
-        raise RuntimeError(
-            "plotly does not support a box plot with pandas-on-Spark 
DataFrame. Use Series instead."
-        )
+    from pyspark.sql.types import NumericType
 
     # 'whis' isn't actually an argument in plotly (but in matplotlib). But 
seems like
     # plotly doesn't expose the reach of the whiskers to the beyond the first 
and
@@ -150,40 +146,82 @@ def plot_box(data: Union["ps.DataFrame", "ps.Series"], 
**kwargs):
             "Set to False." % notched
         )
 
-    colname = name_like_string(data.name)
-    spark_column_name = 
data._internal.spark_column_name_for(data._column_label)
-
-    # Computes mean, median, Q1 and Q3 with approx_percentile and precision
-    col_stats, col_fences = BoxPlotBase.compute_stats(data, spark_column_name, 
whis, precision)
-
-    # Creates a column to flag rows as outliers or not
-    outliers = BoxPlotBase.outliers(data, spark_column_name, *col_fences)
+    fig = go.Figure()
+    if isinstance(data, ps.Series):
+        colname = name_like_string(data.name)
+        spark_column_name = 
data._internal.spark_column_name_for(data._column_label)
+
+        # Computes mean, median, Q1 and Q3 with approx_percentile and precision
+        col_stats, col_fences = BoxPlotBase.compute_stats(data, 
spark_column_name, whis, precision)
+
+        # Creates a column to flag rows as outliers or not
+        outliers = BoxPlotBase.outliers(data, spark_column_name, *col_fences)
+
+        # Computes min and max values of non-outliers - the whiskers
+        whiskers = BoxPlotBase.calc_whiskers(spark_column_name, outliers)
+
+        fliers = None
+        if boxpoints:
+            fliers = BoxPlotBase.get_fliers(spark_column_name, outliers, 
whiskers[0])
+            fliers = [fliers] if len(fliers) > 0 else None
+
+        fig.add_trace(
+            go.Box(
+                name=colname,
+                q1=[col_stats["q1"]],
+                median=[col_stats["med"]],
+                q3=[col_stats["q3"]],
+                mean=[col_stats["mean"]],
+                lowerfence=[whiskers[0]],
+                upperfence=[whiskers[1]],
+                y=fliers,
+                boxpoints=boxpoints,
+                notched=notched,
+                **kwargs,  # this is for workarounds. Box takes different 
options from express.box.
+            )
+        )
+        fig["layout"]["xaxis"]["title"] = colname
 
-    # Computes min and max values of non-outliers - the whiskers
-    whiskers = BoxPlotBase.calc_whiskers(spark_column_name, outliers)
+    else:
+        numeric_column_names = []
+        for column_label in data._internal.column_labels:
+            if isinstance(data._internal.spark_type_for(column_label), 
NumericType):
+                numeric_column_names.append(name_like_string(column_label))
+
+        # Computes mean, median, Q1 and Q3 with approx_percentile and precision
+        multicol_stats = BoxPlotBase.compute_multicol_stats(
+            data, numeric_column_names, whis, precision
+        )
 
-    fliers = None
-    if boxpoints:
-        fliers = BoxPlotBase.get_fliers(spark_column_name, outliers, 
whiskers[0])
-        fliers = [fliers] if len(fliers) > 0 else None
+        # Creates a column to flag rows as outliers or not
+        outliers = BoxPlotBase.multicol_outliers(data, multicol_stats)
+
+        # Computes min and max values of non-outliers - the whiskers
+        whiskers = BoxPlotBase.calc_multicol_whiskers(numeric_column_names, 
outliers)
+
+        i = 0
+        for colname in numeric_column_names:
+            col_stats = multicol_stats[colname]
+            col_whiskers = whiskers[colname]
+
+            fig.add_trace(
+                go.Box(
+                    x=[i],
+                    name=colname,
+                    q1=[col_stats["q1"]],
+                    median=[col_stats["med"]],
+                    q3=[col_stats["q3"]],
+                    mean=[col_stats["mean"]],
+                    lowerfence=[col_whiskers["min"]],
+                    upperfence=[col_whiskers["max"]],
+                    y=None,  # todo: support y=fliers
+                    boxpoints=boxpoints,
+                    notched=notched,
+                    **kwargs,
+                )
+            )
+            i += 1
 
-    fig = go.Figure()
-    fig.add_trace(
-        go.Box(
-            name=colname,
-            q1=[col_stats["q1"]],
-            median=[col_stats["med"]],
-            q3=[col_stats["q3"]],
-            mean=[col_stats["mean"]],
-            lowerfence=[whiskers[0]],
-            upperfence=[whiskers[1]],
-            y=fliers,
-            boxpoints=boxpoints,
-            notched=notched,
-            **kwargs,  # this is for workarounds. Box takes different options 
from express.box.
-        )
-    )
-    fig["layout"]["xaxis"]["title"] = colname
     fig["layout"]["yaxis"]["title"] = "value"
     return fig
 
diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot.py 
b/python/pyspark/pandas/tests/plot/test_frame_plot.py
index 4b457f80788..5d265ff2eee 100644
--- a/python/pyspark/pandas/tests/plot/test_frame_plot.py
+++ b/python/pyspark/pandas/tests/plot/test_frame_plot.py
@@ -20,7 +20,7 @@ import numpy as np
 
 from pyspark import pandas as ps
 from pyspark.pandas.config import set_option, reset_option, option_context
-from pyspark.pandas.plot import TopNPlotBase, SampledPlotBase, 
HistogramPlotBase
+from pyspark.pandas.plot import TopNPlotBase, SampledPlotBase, 
HistogramPlotBase, BoxPlotBase
 from pyspark.pandas.exceptions import PandasNotImplementedError
 from pyspark.testing.pandasutils import PandasOnSparkTestCase
 
@@ -41,7 +41,7 @@ class DataFramePlotTest(PandasOnSparkTestCase):
     def test_missing(self):
         psdf = ps.DataFrame(np.random.rand(2500, 4), columns=["a", "b", "c", 
"d"])
 
-        unsupported_functions = ["box", "hexbin"]
+        unsupported_functions = ["hexbin"]
 
         for name in unsupported_functions:
             with self.assertRaisesRegex(
@@ -110,6 +110,48 @@ class DataFramePlotTest(PandasOnSparkTestCase):
                 pd.Series(expected_histogram, name=expected_name), histogram, 
almost=True
             )
 
+    def test_compute_box_multi_columns(self):
+        # compare compute_multicol_stats with compute_stats
+        def check_box_multi_columns(psdf):
+            k = 1.5
+            multicol_stats = BoxPlotBase.compute_multicol_stats(
+                psdf, ["a", "b", "c"], whis=k, precision=0.01
+            )
+            multicol_outliers = BoxPlotBase.multicol_outliers(psdf, 
multicol_stats)
+            multicol_whiskers = BoxPlotBase.calc_multicol_whiskers(
+                ["a", "b", "c"], multicol_outliers
+            )
+
+            for col in ["a", "b", "c"]:
+                col_stats = multicol_stats[col]
+                col_whiskers = multicol_whiskers[col]
+
+                stats, fences = BoxPlotBase.compute_stats(psdf[col], col, 
whis=k, precision=0.01)
+                outliers = BoxPlotBase.outliers(psdf[col], col, *fences)
+                whiskers = BoxPlotBase.calc_whiskers(col, outliers)
+
+                self.assertEqual(stats["mean"], col_stats["mean"])
+                self.assertEqual(stats["med"], col_stats["med"])
+                self.assertEqual(stats["q1"], col_stats["q1"])
+                self.assertEqual(stats["q3"], col_stats["q3"])
+                self.assertEqual(fences[0], col_stats["lfence"])
+                self.assertEqual(fences[1], col_stats["ufence"])
+                self.assertEqual(whiskers[0], col_whiskers["min"])
+                self.assertEqual(whiskers[1], col_whiskers["max"])
+
+        pdf = pd.DataFrame(
+            {
+                "a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 50],
+                "b": [3, 2, 5, 4, 5, 6, 8, 8, 11, 60, 90],
+                "c": [-30, -2, 5, 4, 5, 6, -8, 8, 11, 12, 18],
+            },
+            index=[0, 1, 3, 5, 6, 8, 9, 9, 9, 10, 10],
+        )
+        psdf = ps.from_pandas(pdf)
+
+        check_box_multi_columns(psdf)
+        check_box_multi_columns(-psdf)
+
 
 if __name__ == "__main__":
     import unittest


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to