This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7f0ecd4221a7 [SPARK-49764][PYTHON][CONNECT] Support area plots
7f0ecd4221a7 is described below

commit 7f0ecd4221a7043b539fb20a792c00f379a5885e
Author: Xinrong Meng <[email protected]>
AuthorDate: Wed Sep 25 19:24:05 2024 +0900

    [SPARK-49764][PYTHON][CONNECT] Support area plots
    
    ### What changes were proposed in this pull request?
    Support area plots with the Plotly backend on both Spark Connect and Spark Classic.
    
    ### Why are the changes needed?
    While the pandas API on Spark supports plotting, PySpark currently lacks this
    feature. The proposed API will enable users to generate visualizations, providing
    an intuitive, interactive way to explore and understand large datasets directly
    from PySpark DataFrames and streamlining the data analysis workflow in distributed
    environments.
    
    See the [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing)
    (work in progress) for more details.
    
    Part of https://issues.apache.org/jira/browse/SPARK-49530.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. Area plots are supported as shown below.
    
    ```py
    >>> from datetime import datetime
    >>> data = [
    ...     (3, 5, 20, datetime(2018, 1, 31)),
    ...     (2, 5, 42, datetime(2018, 2, 28)),
    ...     (3, 6, 28, datetime(2018, 3, 31)),
    ...     (9, 12, 62, datetime(2018, 4, 30))]
    >>> columns = ["sales", "signups", "visits", "date"]
    >>> df = spark.createDataFrame(data, columns)
    >>> # equivalent to: df.plot(kind="area", x="date", y=["sales", "signups", "visits"])
    >>> fig = df.plot.area(x="date", y=["sales", "signups", "visits"])
    >>> fig.show()
    ```
    ![Area plot produced by the example above](https://github.com/user-attachments/assets/e603cd99-ce8b-4448-8e1f-cbc093097c45)
    
    ### How was this patch tested?
    Unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #48236 from xinrong-meng/plot_area.
    
    Authored-by: Xinrong Meng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/plot/core.py                    | 35 ++++++++++++++++++++++
 .../sql/tests/plot/test_frame_plot_plotly.py       | 35 ++++++++++++++++++++++
 2 files changed, 70 insertions(+)

diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py
index 0a3a0101e189..9f83d0069652 100644
--- a/python/pyspark/sql/plot/core.py
+++ b/python/pyspark/sql/plot/core.py
@@ -93,6 +93,7 @@ class PySparkSampledPlotBase:
 
 class PySparkPlotAccessor:
     plot_data_map = {
+        "area": PySparkSampledPlotBase().get_sampled,
         "bar": PySparkTopNPlotBase().get_top_n,
         "barh": PySparkTopNPlotBase().get_top_n,
         "line": PySparkSampledPlotBase().get_sampled,
@@ -264,3 +265,37 @@ class PySparkPlotAccessor:
         >>> df.plot.scatter(x='length', y='width')  # doctest: +SKIP
         """
         return self(kind="scatter", x=x, y=y, **kwargs)
+
+    def area(self, x: str, y: str, **kwargs: Any) -> "Figure":  # y may also be a list of str
+        """
+        Draw a stacked area plot.
+
+        An area plot displays quantitative data visually, drawing one area per ``y`` column.
+
+        Parameters
+        ----------
+        x : str
+            Name of column to use for the horizontal axis.
+        y : str or list of str
+            Name(s) of the column(s) to plot.
+        **kwargs: Optional
+            Additional keyword arguments, forwarded to :meth:`PySparkPlotAccessor.__call__`.
+
+        Returns
+        -------
+        :class:`plotly.graph_objs.Figure`
+
+        Examples
+        --------
+        >>> from datetime import datetime
+        >>> data = [
+        ...     (3, 5, 20, datetime(2018, 1, 31)),
+        ...     (2, 5, 42, datetime(2018, 2, 28)),
+        ...     (3, 6, 28, datetime(2018, 3, 31)),
+        ...     (9, 12, 62, datetime(2018, 4, 30))
+        ... ]
+        >>> columns = ["sales", "signups", "visits", "date"]
+        >>> df = spark.createDataFrame(data, columns)
+        >>> df.plot.area(x='date', y=['sales', 'signups', 'visits'])  # doctest: +SKIP
+        """
+        return self(kind="area", x=x, y=y, **kwargs)
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
index ccfe1a75424e..6176525b4955 100644
--- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
+++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
@@ -16,6 +16,8 @@
 #
 
 import unittest
+from datetime import datetime
+
 import pyspark.sql.plot  # noqa: F401
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, 
plotly_requirement_message
 
@@ -34,6 +36,17 @@ class DataFramePlotPlotlyTestsMixin:
         columns = ["length", "width", "species"]
         return self.spark.createDataFrame(data, columns)
 
+    @property
+    def sdf3(self):  # time-series fixture (datetime x axis) for area plots
+        data = [
+            (3, 5, 20, datetime(2018, 1, 31)),
+            (2, 5, 42, datetime(2018, 2, 28)),
+            (3, 6, 28, datetime(2018, 3, 31)),
+            (9, 12, 62, datetime(2018, 4, 30)),
+        ]
+        columns = ["sales", "signups", "visits", "date"]
+        return self.spark.createDataFrame(data, columns)  # rebuilt on every access
+
     def _check_fig_data(self, kind, fig_data, expected_x, expected_y, 
expected_name=""):
         if kind == "line":
             self.assertEqual(fig_data["mode"], "lines")
@@ -46,6 +59,11 @@ class DataFramePlotPlotlyTestsMixin:
         elif kind == "scatter":
             self.assertEqual(fig_data["type"], "scatter")
             self.assertEqual(fig_data["orientation"], "v")
+            self.assertEqual(fig_data["mode"], "markers")
+        elif kind == "area":
+            self.assertEqual(fig_data["type"], "scatter")
+            self.assertEqual(fig_data["orientation"], "v")
+            self.assertEqual(fig_data["mode"], "lines")
 
         self.assertEqual(fig_data["xaxis"], "x")
         self.assertEqual(list(fig_data["x"]), expected_x)
@@ -98,6 +116,23 @@ class DataFramePlotPlotlyTestsMixin:
             "scatter", fig["data"][0], [3.5, 3.0, 3.2, 3.2, 3.0], [5.1, 4.9, 
7.0, 6.4, 5.9]
         )
 
+    def test_area_plot(self):
+        # single column as vertical axis, via the generic plot(kind=...) entry point
+        fig = self.sdf3.plot(kind="area", x="date", y="sales")
+        expected_x = [
+            datetime(2018, 1, 31, 0, 0),
+            datetime(2018, 2, 28, 0, 0),
+            datetime(2018, 3, 31, 0, 0),
+            datetime(2018, 4, 30, 0, 0),
+        ]
+        self._check_fig_data("area", fig["data"][0], expected_x, [3, 2, 3, 9])
+
+        # multiple columns as vertical axis, via plot.area: one trace per column, in y order
+        fig = self.sdf3.plot.area(x="date", y=["sales", "signups", "visits"])
+        self._check_fig_data("area", fig["data"][0], expected_x, [3, 2, 3, 9], "sales")
+        self._check_fig_data("area", fig["data"][1], expected_x, [5, 5, 6, 12], "signups")
+        self._check_fig_data("area", fig["data"][2], expected_x, [20, 42, 28, 62], "visits")
+
 
 class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase):
     pass


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to