This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 488c3f604490 [SPARK-49776][PYTHON][CONNECT] Support pie plots
488c3f604490 is described below

commit 488c3f604490c8632dde67a00118d49ccfcbf578
Author: Xinrong Meng <[email protected]>
AuthorDate: Fri Sep 27 08:35:10 2024 +0800

    [SPARK-49776][PYTHON][CONNECT] Support pie plots
    
    ### What changes were proposed in this pull request?
    Support pie plots with the plotly backend on both Spark Connect and Spark Classic.
    
    ### Why are the changes needed?
    While Pandas on Spark supports plotting, PySpark currently lacks this 
feature. The proposed API will enable users to generate visualizations. This 
will provide users with an intuitive, interactive way to explore and understand 
large datasets directly from PySpark DataFrames, streamlining the data analysis 
workflow in distributed environments.
    
    See the in-progress [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing) for more details.
    
    Part of https://issues.apache.org/jira/browse/SPARK-49530.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. Pie plots are supported as shown below.
    
    ```py
    >>> from datetime import datetime
    >>> data = [
    ...     (3, 5, 20, datetime(2018, 1, 31)),
    ...     (2, 5, 42, datetime(2018, 2, 28)),
    ...     (3, 6, 28, datetime(2018, 3, 31)),
    ...     (9, 12, 62, datetime(2018, 4, 30))]
    >>> columns = ["sales", "signups", "visits", "date"]
    >>> df = spark.createDataFrame(data, columns)
    >>> fig = df.plot(kind="pie", x="date", y="sales")  # or df.plot.pie(x="date", y="sales")
    >>> fig.show()
    ```
    ![Pie plot of sales by date](https://github.com/user-attachments/assets/c4078bb7-4d84-4607-bcd7-bdd6fbbf8e28)
    
    ### How was this patch tested?
    Unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #48256 from xinrong-meng/plot_pie.
    
    Authored-by: Xinrong Meng <[email protected]>
    Signed-off-by: Xinrong Meng <[email protected]>
---
 python/pyspark/errors/error-conditions.json        |  5 +++
 python/pyspark/sql/plot/core.py                    | 41 +++++++++++++++++++++-
 python/pyspark/sql/plot/plotly.py                  | 15 ++++++++
 .../sql/tests/plot/test_frame_plot_plotly.py       | 25 +++++++++++++
 4 files changed, 85 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/errors/error-conditions.json 
b/python/pyspark/errors/error-conditions.json
index 115ad658e32f..ed62ea117d36 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -812,6 +812,11 @@
       "Pipe function `<func_name>` exited with error code <error_code>."
     ]
   },
+  "PLOT_NOT_NUMERIC_COLUMN": {
+    "message": [
+      "Argument <arg_name> must be a numerical column for plotting, got 
<arg_type>."
+    ]
+  },
   "PYTHON_HASH_SEED_NOT_SET": {
     "message": [
       "Randomness of hash of string should be disabled via PYTHONHASHSEED."
diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py
index 9f83d0069652..f9667ee2c0d6 100644
--- a/python/pyspark/sql/plot/core.py
+++ b/python/pyspark/sql/plot/core.py
@@ -17,7 +17,8 @@
 
 from typing import Any, TYPE_CHECKING, Optional, Union
 from types import ModuleType
-from pyspark.errors import PySparkRuntimeError, PySparkValueError
+from pyspark.errors import PySparkRuntimeError, PySparkTypeError, 
PySparkValueError
+from pyspark.sql.types import NumericType
 from pyspark.sql.utils import require_minimum_plotly_version
 
 
@@ -97,6 +98,7 @@ class PySparkPlotAccessor:
         "bar": PySparkTopNPlotBase().get_top_n,
         "barh": PySparkTopNPlotBase().get_top_n,
         "line": PySparkSampledPlotBase().get_sampled,
+        "pie": PySparkTopNPlotBase().get_top_n,
         "scatter": PySparkSampledPlotBase().get_sampled,
     }
     _backends = {}  # type: ignore[var-annotated]
@@ -299,3 +301,40 @@ class PySparkPlotAccessor:
         >>> df.plot.area(x='date', y=['sales', 'signups', 'visits'])  # 
doctest: +SKIP
         """
         return self(kind="area", x=x, y=y, **kwargs)
+
+    def pie(self, x: str, y: str, **kwargs: Any) -> "Figure":
+        """
+        Generate a pie plot of column ``y``, with slices labeled by column ``x``.
+
+        Parameters
+        ----------
+        x : str
+            Name of the column whose values label the pie slices.
+        y : str
+            Name of the numeric column whose values size the pie slices.
+        **kwargs
+            Additional keyword arguments forwarded to the plotting backend.
+
+        Returns
+        -------
+        :class:`plotly.graph_objs.Figure`
+
+        Raises
+        ------
+        PySparkTypeError
+            If ``y`` is not a column of the DataFrame or is not numeric
+            (error class ``PLOT_NOT_NUMERIC_COLUMN``).
+        """
+        schema = self.data.schema
+        # 'y' must name an existing column whose data type is a NumericType.
+        # NOTE(review): plot(kind="pie") does not appear to run this check - confirm.
+        y_field = schema[y] if y in schema.names else None  # None when 'y' is absent
+        if y_field is None or not isinstance(y_field.dataType, NumericType):
+            raise PySparkTypeError(
+                errorClass="PLOT_NOT_NUMERIC_COLUMN",
+                messageParameters={
+                    "arg_name": "y",
+                    "arg_type": str(y_field.dataType) if y_field else "None",
+                },
+            )
+        return self(kind="pie", x=x, y=y, **kwargs)
diff --git a/python/pyspark/sql/plot/plotly.py 
b/python/pyspark/sql/plot/plotly.py
index 5efc19476057..91f536346471 100644
--- a/python/pyspark/sql/plot/plotly.py
+++ b/python/pyspark/sql/plot/plotly.py
@@ -27,4 +27,19 @@ if TYPE_CHECKING:
 def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure":
     import plotly
 
+    if kind == "pie":  # pie is drawn via plotly.express in plot_pie, not plotly.plot
+        return plot_pie(data, **kwargs)
+
     return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, 
**kwargs)
+
+
+def plot_pie(data: "DataFrame", **kwargs: Any) -> "Figure":
+    """Draw ``data`` as a plotly.express pie: values from ``y``, names from ``x``."""
+    from plotly import express
+    # TODO(SPARK-49530): Support pie subplots with plotly backend
+    pdf = PySparkPlotAccessor.plot_data_map["pie"](data)  # top-N rows via get_top_n
+    # x/y are popped so they are not forwarded a second time through **kwargs.
+    x = kwargs.pop("x", None)
+    y = kwargs.pop("y", None)
+    fig = express.pie(pdf, values=y, names=x, **kwargs)
+    return fig
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py 
b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
index 6176525b4955..70a1b336f734 100644
--- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
+++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
@@ -19,6 +19,7 @@ import unittest
 from datetime import datetime
 
 import pyspark.sql.plot  # noqa: F401
+from pyspark.errors import PySparkTypeError
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, 
plotly_requirement_message
 
 
@@ -64,6 +65,11 @@ class DataFramePlotPlotlyTestsMixin:
             self.assertEqual(fig_data["type"], "scatter")
             self.assertEqual(fig_data["orientation"], "v")
             self.assertEqual(fig_data["mode"], "lines")
+        elif kind == "pie":
+            self.assertEqual(fig_data["type"], "pie")
+            self.assertEqual(list(fig_data["labels"]), expected_x)
+            self.assertEqual(list(fig_data["values"]), expected_y)
+            return
 
         self.assertEqual(fig_data["xaxis"], "x")
         self.assertEqual(list(fig_data["x"]), expected_x)
@@ -133,6 +139,25 @@ class DataFramePlotPlotlyTestsMixin:
         self._check_fig_data("area", fig["data"][1], expected_x, [5, 5, 6, 
12], "signups")
         self._check_fig_data("area", fig["data"][2], expected_x, [20, 42, 28, 
62], "visits")
 
+    def test_pie_plot(self):
+        fig = self.sdf3.plot(kind="pie", x="date", y="sales")  # generic kind= entry point
+        expected_x = [  # pie labels come from the x="date" column
+            datetime(2018, 1, 31, 0, 0),
+            datetime(2018, 2, 28, 0, 0),
+            datetime(2018, 3, 31, 0, 0),
+            datetime(2018, 4, 30, 0, 0),
+        ]
+        self._check_fig_data("pie", fig["data"][0], expected_x, [3, 2, 3, 9])
+
+        # A non-numeric y is rejected by .pie() with PLOT_NOT_NUMERIC_COLUMN.
+        with self.assertRaises(PySparkTypeError) as pe:
+            self.sdf.plot.pie(x="int_val", y="category")
+        self.check_error(
+            exception=pe.exception,
+            errorClass="PLOT_NOT_NUMERIC_COLUMN",
+            messageParameters={"arg_name": "y", "arg_type": "StringType()"},
+        )
+
 
 class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, 
ReusedSQLTestCase):
     pass


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to