This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a30a3fd0e74d [SPARK-49530][PYTHON] Support pie subplots in pyspark plotting
a30a3fd0e74d is described below

commit a30a3fd0e74d36c744af26ac1931dfa3c2883552
Author: Xinrong Meng <[email protected]>
AuthorDate: Tue Dec 24 09:47:51 2024 +0900

    [SPARK-49530][PYTHON] Support pie subplots in pyspark plotting
    
    ### What changes were proposed in this pull request?
    Support pie subplots in pyspark plotting.
    
    ### Why are the changes needed?
    API parity with pandas.DataFrame.plot.pie, see [here](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.plot.pie.html)
    
    ### Does this PR introduce _any_ user-facing change?
    Pie subplots are supported as shown below:
    
    ```py
    >>> from datetime import datetime
    >>> data = [
    ...             (3, 5, 20, datetime(2018, 1, 31)),
    ...             (2, 5, 42, datetime(2018, 2, 28)),
    ...             (3, 6, 28, datetime(2018, 3, 31)),
    ...             (9, 12, 62, datetime(2018, 4, 30)),
    ...         ]
    >>> columns = ["sales", "signups", "visits", "date"]
    >>> df = spark.createDataFrame(data, columns)
    >>> fig = df.plot(kind="pie", x="date", subplots=True)
    >>> fig.show()
    ```
    ![newplot (2)](https://github.com/user-attachments/assets/2b019c6a-82da-4c12-b1ff-096786801f56)
    
    ### How was this patch tested?
    Unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #49268 from xinrong-meng/pie_subplot.
    
    Lead-authored-by: Xinrong Meng <[email protected]>
    Co-authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/errors/error-conditions.json        |  5 +++
 python/pyspark/sql/plot/core.py                    | 22 +++---------
 python/pyspark/sql/plot/plotly.py                  | 25 ++++++++++++--
 .../sql/tests/plot/test_frame_plot_plotly.py       | 39 +++++++++++++++++++---
 4 files changed, 68 insertions(+), 23 deletions(-)

diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json
index c4ad3f8d5feb..b7c1ec23c3af 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -1103,6 +1103,11 @@
       "Function `<func_name>` should use only POSITIONAL or POSITIONAL OR 
KEYWORD arguments."
     ]
   },
+  "UNSUPPORTED_PIE_PLOT_PARAM": {
+    "message": [
+      "Pie plot requires either a `y` column or `subplots=True`."
+    ]
+  },
   "UNSUPPORTED_PLOT_BACKEND": {
     "message": [
       "`<backend>` is not supported, it should be one of the values from 
<supported_backends>"
diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py
index f7133bdb70ed..e565a5d1ebf3 100644
--- a/python/pyspark/sql/plot/core.py
+++ b/python/pyspark/sql/plot/core.py
@@ -19,11 +19,10 @@ import math
 
 from typing import Any, TYPE_CHECKING, List, Optional, Union, Sequence
 from types import ModuleType
-from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.errors import PySparkValueError
 from pyspark.sql import Column, functions as F
 from pyspark.sql.internal import InternalFunction as SF
 from pyspark.sql.pandas.utils import require_minimum_pandas_version
-from pyspark.sql.types import NumericType
 from pyspark.sql.utils import NumpyHelper, require_minimum_plotly_version
 
 if TYPE_CHECKING:
@@ -295,7 +294,7 @@ class PySparkPlotAccessor:
         """
         return self(kind="area", x=x, y=y, **kwargs)
 
-    def pie(self, x: str, y: str, **kwargs: Any) -> "Figure":
+    def pie(self, x: str, y: Optional[str], **kwargs: Any) -> "Figure":
         """
         Generate a pie plot.
 
@@ -306,8 +305,8 @@ class PySparkPlotAccessor:
         ----------
         x : str
             Name of column to be used as the category labels for the pie plot.
-        y : str
-            Name of the column to plot.
+        y : str, optional
+            Name of the column to plot. If not provided, `subplots=True` must be passed at `kwargs`.
         **kwargs
             Additional keyword arguments.
 
@@ -327,19 +326,8 @@ class PySparkPlotAccessor:
         >>> columns = ["sales", "signups", "visits", "date"]
         >>> df = spark.createDataFrame(data, columns)
         >>> df.plot.pie(x='date', y='sales')  # doctest: +SKIP
+        >>> df.plot.pie(x='date', subplots=True)  # doctest: +SKIP
         """
-        schema = self.data.schema
-
-        # Check if 'y' is a numerical column
-        y_field = schema[y] if y in schema.names else None
-        if y_field is None or not isinstance(y_field.dataType, NumericType):
-            raise PySparkTypeError(
-                errorClass="PLOT_NOT_NUMERIC_COLUMN_ARGUMENT",
-                messageParameters={
-                    "arg_name": "y",
-                    "arg_type": str(y_field.dataType.__class__.__name__) if 
y_field else "None",
-                },
-            )
         return self(kind="pie", x=x, y=y, **kwargs)
 
     def box(self, column: Optional[Union[str, List[str]]] = None, **kwargs: Any) -> "Figure":
diff --git a/python/pyspark/sql/plot/plotly.py b/python/pyspark/sql/plot/plotly.py
index 959562b43552..c7691f144ffa 100644
--- a/python/pyspark/sql/plot/plotly.py
+++ b/python/pyspark/sql/plot/plotly.py
@@ -48,13 +48,34 @@ def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure":
 
 
 def plot_pie(data: "DataFrame", **kwargs: Any) -> "Figure":
-    # TODO(SPARK-49530): Support pie subplots with plotly backend
     from plotly import express
 
     pdf = PySparkPlotAccessor.plot_data_map["pie"](data)
     x = kwargs.pop("x", None)
     y = kwargs.pop("y", None)
-    fig = express.pie(pdf, values=y, names=x, **kwargs)
+    subplots = kwargs.pop("subplots", False)
+    if y is None and not subplots:
+        raise PySparkValueError(errorClass="UNSUPPORTED_PIE_PLOT_PARAM", messageParameters={})
+
+    numeric_ys = process_column_param(y, data)
+
+    if subplots:
+        # One pie chart per numeric column
+        from plotly.subplots import make_subplots
+
+        fig = make_subplots(
+            rows=1,
+            cols=len(numeric_ys),
+            # To accommodate domain-based trace - pie chart
+            specs=[[{"type": "domain"}] * len(numeric_ys)],
+        )
+        for i, y_col in enumerate(numeric_ys):
+            subplot_fig = express.pie(pdf, values=y_col, names=x, **kwargs)
+            fig.add_trace(
+                subplot_fig.data[0], row=1, col=i + 1
+            )  # A single pie chart has only one trace
+    else:
+        fig = express.pie(pdf, values=numeric_ys[0], names=x, **kwargs)
 
     return fig
 
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
index fd264c348882..3dafd71c1a32 100644
--- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
+++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
@@ -301,6 +301,7 @@ class DataFramePlotPlotlyTestsMixin:
         self._check_fig_data(fig["data"][2], **expected_fig_data)
 
     def test_pie_plot(self):
+        # single column as 'y'
         fig = self.sdf3.plot(kind="pie", x="date", y="sales")
         expected_x = [
             datetime(2018, 1, 31, 0, 0),
@@ -308,13 +309,39 @@ class DataFramePlotPlotlyTestsMixin:
             datetime(2018, 3, 31, 0, 0),
             datetime(2018, 4, 30, 0, 0),
         ]
-        expected_fig_data = {
+        expected_fig_data_sales = {
             "name": "",
             "labels": expected_x,
             "values": [3, 2, 3, 9],
             "type": "pie",
         }
-        self._check_fig_data(fig["data"][0], **expected_fig_data)
+        self._check_fig_data(fig["data"][0], **expected_fig_data_sales)
+
+        # all numeric columns as 'y'
+        expected_fig_data_signups = {
+            "name": "",
+            "labels": expected_x,
+            "values": [5, 5, 6, 12],
+            "type": "pie",
+        }
+        expected_fig_data_visits = {
+            "name": "",
+            "labels": expected_x,
+            "values": [20, 42, 28, 62],
+            "type": "pie",
+        }
+        fig = self.sdf3.plot(kind="pie", x="date", subplots=True)
+        self._check_fig_data(fig["data"][0], **expected_fig_data_sales)
+        self._check_fig_data(fig["data"][1], **expected_fig_data_signups)
+        self._check_fig_data(fig["data"][2], **expected_fig_data_visits)
+
+        # not specify subplots
+        with self.assertRaises(PySparkValueError) as pe:
+            self.sdf3.plot(kind="pie", x="date")
+
+        self.check_error(
+            exception=pe.exception, errorClass="UNSUPPORTED_PIE_PLOT_PARAM", messageParameters={}
+        )
 
         # y is not a numerical column
         with self.assertRaises(PySparkTypeError) as pe:
@@ -322,8 +349,12 @@ class DataFramePlotPlotlyTestsMixin:
 
         self.check_error(
             exception=pe.exception,
-            errorClass="PLOT_NOT_NUMERIC_COLUMN_ARGUMENT",
-            messageParameters={"arg_name": "y", "arg_type": "StringType"},
+            errorClass="PLOT_INVALID_TYPE_COLUMN",
+            messageParameters={
+                "col_name": "category",
+                "valid_types": "NumericType",
+                "col_type": "StringType",
+            },
         )
 
     def test_box_plot(self):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to