This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3b8dddac65bc [SPARK-49531][PYTHON][CONNECT] Support line plot with plotly backend
3b8dddac65bc is described below

commit 3b8dddac65bce6f88f51e23e777d521d65fa3373
Author: Xinrong Meng <[email protected]>
AuthorDate: Fri Sep 13 09:21:20 2024 +0800

    [SPARK-49531][PYTHON][CONNECT] Support line plot with plotly backend
    
    ### What changes were proposed in this pull request?
    Support line plot with plotly backend on both Spark Connect and Spark 
classic.
    
    ### Why are the changes needed?
    While the pandas API on Spark supports plotting, PySpark currently lacks this
feature. The proposed API will enable users to generate visualizations, such as
line plots, by leveraging libraries like Plotly. This will provide users with
an intuitive, interactive way to explore and understand large datasets directly
from PySpark DataFrames, streamlining the data analysis workflow in distributed
environments.
    
    See more at [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing) (in progress).
    
    Part of https://issues.apache.org/jira/browse/SPARK-49530.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes.
    
    ```python
    >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
    >>> columns = ["category", "int_val", "float_val"]
    >>> sdf = spark.createDataFrame(data, columns)
    >>> sdf.show()
    +--------+-------+---------+
    |category|int_val|float_val|
    +--------+-------+---------+
    |       A|     10|      1.5|
    |       B|     30|      2.5|
    |       C|     20|      3.5|
    +--------+-------+---------+
    
    >>> f = sdf.plot(kind="line", x="category", y="int_val")
    >>> f.show()  # see below
    >>> g = sdf.plot.line(x="category", y=["int_val", "float_val"])
    >>> g.show()  # see below
    ```
    `f.show()`:
    ![newplot](https://github.com/user-attachments/assets/ebd50bbc-0dd1-437f-ae0c-0b4de8f3c722)
    
    `g.show()`:
    ![newplot (1)](https://github.com/user-attachments/assets/46d28840-a147-428f-8d88-d424aa76ad06)
    
    ### How was this patch tested?
    Unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #48008 from xinrong-meng/plot_line.
    
    Authored-by: Xinrong Meng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 dev/sparktestsupport/modules.py                    |   4 +
 python/pyspark/errors/error-conditions.json        |   5 +
 python/pyspark/sql/classic/dataframe.py            |   5 +
 python/pyspark/sql/connect/dataframe.py            |   5 +
 python/pyspark/sql/dataframe.py                    |  27 +++++
 python/pyspark/sql/plot/__init__.py                |  21 ++++
 python/pyspark/sql/plot/core.py                    | 135 +++++++++++++++++++++
 python/pyspark/sql/plot/plotly.py                  |  30 +++++
 .../sql/tests/connect/test_parity_frame_plot.py    |  36 ++++++
 .../tests/connect/test_parity_frame_plot_plotly.py |  36 ++++++
 python/pyspark/sql/tests/plot/__init__.py          |  16 +++
 python/pyspark/sql/tests/plot/test_frame_plot.py   |  79 ++++++++++++
 .../sql/tests/plot/test_frame_plot_plotly.py       |  64 ++++++++++
 python/pyspark/sql/utils.py                        |  17 +++
 python/pyspark/testing/sqlutils.py                 |   7 ++
 .../org/apache/spark/sql/internal/SQLConf.scala    |  27 +++++
 16 files changed, 514 insertions(+)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 34fbb8450d54..b9a4bed715f6 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -548,6 +548,8 @@ pyspark_sql = Module(
         "pyspark.sql.tests.test_udtf",
         "pyspark.sql.tests.test_utils",
         "pyspark.sql.tests.test_resources",
+        "pyspark.sql.tests.plot.test_frame_plot",
+        "pyspark.sql.tests.plot.test_frame_plot_plotly",
     ],
 )
 
@@ -1051,6 +1053,8 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.test_parity_arrow_cogrouped_map",
         "pyspark.sql.tests.connect.test_parity_python_datasource",
         "pyspark.sql.tests.connect.test_parity_python_streaming_datasource",
+        "pyspark.sql.tests.connect.test_parity_frame_plot",
+        "pyspark.sql.tests.connect.test_parity_frame_plot_plotly",
         "pyspark.sql.tests.connect.test_utils",
         "pyspark.sql.tests.connect.client.test_artifact",
         "pyspark.sql.tests.connect.client.test_artifact_localcluster",
diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json
index 4061d024a83c..92aeb15e21d1 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -1088,6 +1088,11 @@
       "Function `<func_name>` should use only POSITIONAL or POSITIONAL OR KEYWORD arguments."
     ]
   },
+  "UNSUPPORTED_PLOT_BACKEND": {
+    "message": [
+      "`<backend>` is not supported, it should be one of the values from <supported_backends>"
+    ]
+  },
   "UNSUPPORTED_SIGNATURE": {
     "message": [
       "Unsupported signature: <signature>."
diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py
index 91b959162590..d174f7774cc5 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -58,6 +58,7 @@ from pyspark.sql.column import Column
 from pyspark.sql.classic.column import _to_seq, _to_list, _to_java_column
 from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
 from pyspark.sql.merge import MergeIntoWriter
+from pyspark.sql.plot import PySparkPlotAccessor
 from pyspark.sql.streaming import DataStreamWriter
 from pyspark.sql.types import (
     StructType,
@@ -1862,6 +1863,10 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
             messageParameters={"member": "queryExecution"},
         )
 
+    @property
+    def plot(self) -> PySparkPlotAccessor:
+        return PySparkPlotAccessor(self)  # a new accessor bound to this DataFrame on each access
+
 
 class DataFrameNaFunctions(ParentDataFrameNaFunctions):
     def __init__(self, df: ParentDataFrame):
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 768abd655d49..e3b1d35b2d5d 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -83,6 +83,7 @@ from pyspark.sql.connect.expressions import (
     UnresolvedStar,
 )
 from pyspark.sql.connect.functions import builtin as F
+from pyspark.sql.plot import PySparkPlotAccessor
 from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
 from pyspark.sql.pandas.functions import _validate_pandas_udf  # type: ignore[attr-defined]
 
@@ -2239,6 +2240,10 @@ class DataFrame(ParentDataFrame):
     def executionInfo(self) -> Optional["ExecutionInfo"]:
         return self._execution_info
 
+    @property
+    def plot(self) -> PySparkPlotAccessor:
+        return PySparkPlotAccessor(self)  # a new accessor bound to this DataFrame on each access
+
 
 class DataFrameNaFunctions(ParentDataFrameNaFunctions):
     def __init__(self, df: ParentDataFrame):
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index ef35b7333257..7748510258ea 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -39,6 +39,7 @@ from pyspark.resource import ResourceProfile
 from pyspark.sql.column import Column
 from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
 from pyspark.sql.merge import MergeIntoWriter
+from pyspark.sql.plot import PySparkPlotAccessor
 from pyspark.sql.streaming import DataStreamWriter
 from pyspark.sql.types import StructType, Row
 from pyspark.sql.utils import dispatch_df_method
@@ -6394,6 +6395,32 @@ class DataFrame:
         """
         ...
 
+    @property
+    def plot(self) -> PySparkPlotAccessor:
+        """
+        Returns a :class:`PySparkPlotAccessor` for plotting functions.
+
+        .. versionadded:: 4.0.0
+
+        Returns
+        -------
+        :class:`PySparkPlotAccessor`
+
+        Notes
+        -----
+        This API is experimental.
+
+        Examples
+        --------
+        >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
+        >>> columns = ["category", "int_val", "float_val"]
+        >>> df = spark.createDataFrame(data, columns)
+        >>> type(df.plot)
+        <class 'pyspark.sql.plot.core.PySparkPlotAccessor'>
+        >>> df.plot.line(x="category", y=["int_val", "float_val"])  # doctest: +SKIP
+        """
+        ...
+
 
 class DataFrameNaFunctions:
     """Functionality for working with missing data in :class:`DataFrame`.
diff --git a/python/pyspark/sql/plot/__init__.py b/python/pyspark/sql/plot/__init__.py
new file mode 100644
index 000000000000..6da07061b2a0
--- /dev/null
+++ b/python/pyspark/sql/plot/__init__.py
@@ -0,0 +1,21 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+This package includes the plotting APIs for PySpark DataFrame.
+"""
+from pyspark.sql.plot.core import *  # noqa: F403, F401
diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py
new file mode 100644
index 000000000000..baee610dc6bd
--- /dev/null
+++ b/python/pyspark/sql/plot/core.py
@@ -0,0 +1,135 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Any, TYPE_CHECKING, Optional, Union
+from types import ModuleType
+from pyspark.errors import PySparkRuntimeError, PySparkValueError
+from pyspark.sql.utils import require_minimum_plotly_version
+
+
+if TYPE_CHECKING:
+    from pyspark.sql import DataFrame
+    import pandas as pd
+    from plotly.graph_objs import Figure
+
+
+class PySparkTopNPlotBase:
+    def get_top_n(self, sdf: "DataFrame") -> "pd.DataFrame":
+        from pyspark.sql import SparkSession
+
+        session = SparkSession.getActiveSession()
+        if session is None:
+            raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict())
+
+        max_rows = int(
+            session.conf.get("spark.sql.pyspark.plotting.max_rows")  # type: ignore[arg-type]
+        )
+        pdf = sdf.limit(max_rows + 1).toPandas()  # the extra row reveals truncation
+
+        self.partial = False
+        if len(pdf) > max_rows:
+            self.partial = True
+            pdf = pdf.iloc[:max_rows]
+
+        return pdf
+
+
+class PySparkSampledPlotBase:
+    def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame":
+        from pyspark.sql import SparkSession
+
+        session = SparkSession.getActiveSession()
+        if session is None:
+            raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict())
+
+        sample_ratio = session.conf.get("spark.sql.pyspark.plotting.sample_ratio")
+        max_rows = int(
+            session.conf.get("spark.sql.pyspark.plotting.max_rows")  # type: ignore[arg-type]
+        )
+
+        if sample_ratio is None:
+            row_count = sdf.count()  # can be 0 for an empty DataFrame
+            fraction = 1.0 if row_count <= max_rows else max_rows / row_count
+        else:
+            fraction = float(sample_ratio)
+
+        sampled_sdf = sdf.sample(fraction=fraction)
+        pdf = sampled_sdf.toPandas()
+
+        return pdf
+
+
+class PySparkPlotAccessor:
+    plot_data_map = {
+        "line": PySparkSampledPlotBase().get_sampled,
+    }
+    _backends = {}  # type: ignore[var-annotated]
+
+    def __init__(self, data: "DataFrame"):
+        self.data = data
+
+    def __call__(
+        self, kind: str = "line", backend: Optional[str] = None, **kwargs: Any
+    ) -> "Figure":
+        plot_backend = PySparkPlotAccessor._get_plot_backend(backend)
+
+        return plot_backend.plot_pyspark(self.data, kind=kind, **kwargs)
+
+    @staticmethod
+    def _get_plot_backend(backend: Optional[str] = None) -> ModuleType:
+        backend = backend or "plotly"
+
+        if backend in PySparkPlotAccessor._backends:
+            return PySparkPlotAccessor._backends[backend]
+
+        if backend == "plotly":
+            require_minimum_plotly_version()
+        else:
+            raise PySparkValueError(
+                errorClass="UNSUPPORTED_PLOT_BACKEND",
+                messageParameters={"backend": backend, "supported_backends": ", ".join(["plotly"])},
+            )
+        from pyspark.sql.plot import plotly as module
+
+        return PySparkPlotAccessor._backends.setdefault(backend, module)  # cache for reuse
+
+    def line(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure":
+        """
+        Plot DataFrame as lines.
+
+        Parameters
+        ----------
+        x : str
+            Name of column to use for the horizontal axis.
+        y : str or list of str
+            Name(s) of the column(s) to use for the vertical axis. Multiple columns can be plotted.
+        **kwargs : optional
+            Additional keyword arguments passed to the plotting backend.
+
+        Returns
+        -------
+        :class:`plotly.graph_objs.Figure`
+
+        Examples
+        --------
+        >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
+        >>> columns = ["category", "int_val", "float_val"]
+        >>> df = spark.createDataFrame(data, columns)
+        >>> df.plot.line(x="category", y="int_val")  # doctest: +SKIP
+        >>> df.plot.line(x="category", y=["int_val", "float_val"])  # doctest: +SKIP
+        """
+        return self(kind="line", x=x, y=y, **kwargs)
diff --git a/python/pyspark/sql/plot/plotly.py b/python/pyspark/sql/plot/plotly.py
new file mode 100644
index 000000000000..5efc19476057
--- /dev/null
+++ b/python/pyspark/sql/plot/plotly.py
@@ -0,0 +1,30 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import TYPE_CHECKING, Any
+
+from pyspark.sql.plot import PySparkPlotAccessor
+
+if TYPE_CHECKING:
+    from pyspark.sql import DataFrame
+    from plotly.graph_objs import Figure
+
+
+def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure":
+    import plotly  # imported on first use rather than at module load
+
+    return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs)
diff --git a/python/pyspark/sql/tests/connect/test_parity_frame_plot.py b/python/pyspark/sql/tests/connect/test_parity_frame_plot.py
new file mode 100644
index 000000000000..c69e438bf7eb
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_frame_plot.py
@@ -0,0 +1,36 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.sql.tests.plot.test_frame_plot import DataFramePlotTestsMixin
+
+
+class FramePlotParityTests(DataFramePlotTestsMixin, ReusedConnectTestCase):
+    """Runs DataFramePlotTestsMixin with ReusedConnectTestCase (Spark Connect)."""
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.connect.test_parity_frame_plot import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py b/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py
new file mode 100644
index 000000000000..78508fe53337
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py
@@ -0,0 +1,36 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.sql.tests.plot.test_frame_plot_plotly import DataFramePlotPlotlyTestsMixin
+
+
+class FramePlotPlotlyParityTests(DataFramePlotPlotlyTestsMixin, ReusedConnectTestCase):
+    """Runs DataFramePlotPlotlyTestsMixin with ReusedConnectTestCase (Spark Connect)."""
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.connect.test_parity_frame_plot_plotly import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/plot/__init__.py b/python/pyspark/sql/tests/plot/__init__.py
new file mode 100644
index 000000000000..cce3acad34a4
--- /dev/null
+++ b/python/pyspark/sql/tests/plot/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot.py b/python/pyspark/sql/tests/plot/test_frame_plot.py
new file mode 100644
index 000000000000..19ef53e46b2f
--- /dev/null
+++ b/python/pyspark/sql/tests/plot/test_frame_plot.py
@@ -0,0 +1,79 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.errors import PySparkValueError
+from pyspark.sql import Row
+from pyspark.sql.plot import PySparkSampledPlotBase, PySparkTopNPlotBase
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class DataFramePlotTestsMixin:
+    def test_backend(self):
+        accessor = self.spark.range(2).plot
+        backend = accessor._get_plot_backend()
+        self.assertEqual(backend.__name__, "pyspark.sql.plot.plotly")
+
+        with self.assertRaises(PySparkValueError) as pe:
+            accessor._get_plot_backend("matplotlib")
+
+        self.check_error(
+            exception=pe.exception,
+            errorClass="UNSUPPORTED_PLOT_BACKEND",
+            messageParameters={"backend": "matplotlib", "supported_backends": "plotly"},
+        )
+
+    def test_topn_max_rows(self):
+        try:
+            self.spark.conf.set("spark.sql.pyspark.plotting.max_rows", "1000")
+            sdf = self.spark.range(2500)
+            pdf = PySparkTopNPlotBase().get_top_n(sdf)
+            self.assertEqual(len(pdf), 1000)
+        finally:
+            self.spark.conf.unset("spark.sql.pyspark.plotting.max_rows")
+
+    def test_sampled_plot_with_ratio(self):
+        try:
+            self.spark.conf.set("spark.sql.pyspark.plotting.sample_ratio", "0.5")
+            data = [Row(a=i, b=i + 1, c=i + 2, d=i + 3) for i in range(2500)]
+            sdf = self.spark.createDataFrame(data)
+            pdf = PySparkSampledPlotBase().get_sampled(sdf)
+            self.assertEqual(round(len(pdf) / 2500, 1), 0.5)
+        finally:
+            self.spark.conf.unset("spark.sql.pyspark.plotting.sample_ratio")
+
+    def test_sampled_plot_with_max_rows(self):
+        data = [Row(a=i, b=i + 1, c=i + 2, d=i + 3) for i in range(2000)]
+        sdf = self.spark.createDataFrame(data)
+        pdf = PySparkSampledPlotBase().get_sampled(sdf)
+        self.assertEqual(round(len(pdf) / 2000, 1), 0.5)
+
+
+class DataFramePlotTests(DataFramePlotTestsMixin, ReusedSQLTestCase):
+    """Runs DataFramePlotTestsMixin with ReusedSQLTestCase."""
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.plot.test_frame_plot import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
new file mode 100644
index 000000000000..72a3ed267d19
--- /dev/null
+++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
@@ -0,0 +1,64 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import pyspark.sql.plot  # noqa: F401
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message
+
+
+@unittest.skipIf(not have_plotly, plotly_requirement_message)
+class DataFramePlotPlotlyTestsMixin:
+    @property
+    def sdf(self):
+        data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
+        columns = ["category", "int_val", "float_val"]
+        return self.spark.createDataFrame(data, columns)
+
+    def _check_fig_data(self, fig_data, expected_x, expected_y, expected_name=""):
+        self.assertEqual(fig_data["mode"], "lines")
+        self.assertEqual(fig_data["type"], "scatter")
+        self.assertEqual(fig_data["xaxis"], "x")
+        self.assertEqual(list(fig_data["x"]), expected_x)
+        self.assertEqual(fig_data["yaxis"], "y")
+        self.assertEqual(list(fig_data["y"]), expected_y)
+        self.assertEqual(fig_data["name"], expected_name)
+
+    def test_line_plot(self):
+        # single column as vertical axis
+        fig = self.sdf.plot(kind="line", x="category", y="int_val")
+        self._check_fig_data(fig["data"][0], ["A", "B", "C"], [10, 30, 20])
+
+        # multiple columns as vertical axis
+        fig = self.sdf.plot.line(x="category", y=["int_val", "float_val"])
+        self._check_fig_data(fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val")
+        self._check_fig_data(fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val")
+
+
+class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase):
+    """Runs DataFramePlotPlotlyTestsMixin with ReusedSQLTestCase."""
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.plot.test_frame_plot_plotly import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 11b91612419a..5d9ec92cbc83 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -41,6 +41,7 @@ from pyspark.errors import (  # noqa: F401
     PythonException,
     UnknownException,
     SparkUpgradeException,
+    PySparkImportError,
     PySparkNotImplementedError,
     PySparkRuntimeError,
 )
@@ -115,6 +116,22 @@ def require_test_compiled() -> None:
         )
 
 
+def require_minimum_plotly_version() -> None:
+    """Raise PySparkImportError if plotly cannot be imported. The version is not checked."""
+    minimum_plotly_version = "4.8"  # only reported in the error message; not enforced
+
+    try:
+        import plotly  # noqa: F401
+    except ImportError as error:
+        raise PySparkImportError(
+            errorClass="PACKAGE_NOT_INSTALLED",
+            messageParameters={
+                "package_name": "plotly",
+                "minimum_version": str(minimum_plotly_version),
+            },
+        ) from error
+
+
 class ForeachBatchFunction:
     """
     This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps
diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
index 9f07c44c084c..00ad40e68bd7 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -48,6 +48,13 @@ try:
 except Exception as e:
     test_not_compiled_message = str(e)
 
+plotly_requirement_message = None  # set to the ImportError message when plotly is missing
+try:
+    import plotly
+except ImportError as e:
+    plotly_requirement_message = str(e)
+have_plotly = plotly_requirement_message is None
+
 from pyspark.sql import SparkSession
 from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
 from pyspark.testing.utils import ReusedPySparkTestCase, PySparkErrorTestUtils
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index a87b0613292c..5853e4b66dcc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3169,6 +3169,29 @@ object SQLConf {
       .version("4.0.0")
       .fallbackConf(Python.PYTHON_WORKER_FAULTHANLDER_ENABLED)
 
+  val PYSPARK_PLOT_MAX_ROWS =
+    buildConf("spark.sql.pyspark.plotting.max_rows")
+      .doc("The visual limit on top-n-based plots. If set to 1000, the first 1000 data " +
+        "points will be used for plotting. Also used to derive the sampling ratio of " +
+        "sample-based plots when spark.sql.pyspark.plotting.sample_ratio is not set.")
+      .version("4.0.0")
+      .intConf
+      .createWithDefault(1000)
+
+  val PYSPARK_PLOT_SAMPLE_RATIO =
+    buildConf("spark.sql.pyspark.plotting.sample_ratio")
+      .doc(
+        "The proportion of data that will be plotted for sample-based plots. It is determined " +
+          "based on spark.sql.pyspark.plotting.max_rows if not explicitly set."
+      )
+      .version("4.0.0")
+      .doubleConf
+      .checkValue(
+        ratio => ratio >= 0.0 && ratio <= 1.0,
+        "The value should be between 0.0 and 1.0 inclusive."
+      )
+      .createOptional
+
   val ARROW_SPARKR_EXECUTION_ENABLED =
     buildConf("spark.sql.execution.arrow.sparkr.enabled")
       .doc("When true, make use of Apache Arrow for columnar data transfers in SparkR. " +
@@ -5855,6 +5878,10 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def pythonUDFWorkerFaulthandlerEnabled: Boolean = getConf(PYTHON_UDF_WORKER_FAULTHANLDER_ENABLED)
 
+  def pysparkPlotMaxRows: Int = getConf(PYSPARK_PLOT_MAX_ROWS)
+
+  def pysparkPlotSampleRatio: Option[Double] = getConf(PYSPARK_PLOT_SAMPLE_RATIO)
+
   def arrowSparkREnabled: Boolean = getConf(ARROW_SPARKR_EXECUTION_ENABLED)
 
   def arrowPySparkFallbackEnabled: Boolean = getConf(ARROW_PYSPARK_FALLBACK_ENABLED)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to