This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 22a7edce0a7c [SPARK-49531][PYTHON][CONNECT] Support line plot with plotly backend
22a7edce0a7c is described below

commit 22a7edce0a7c70d6c1a5dcf995c6c723f0c3352b
Author: Xinrong Meng <[email protected]>
AuthorDate: Fri Sep 20 08:53:52 2024 -0700

    [SPARK-49531][PYTHON][CONNECT] Support line plot with plotly backend
    
    ### What changes were proposed in this pull request?
    Support line plots with the Plotly backend on both Spark Connect and Spark Classic.
    
    ### Why are the changes needed?
    While the pandas API on Spark supports plotting, PySpark currently lacks this
    feature. The proposed API will enable users to generate visualizations, such as
    line plots, by leveraging libraries like Plotly. This will give users an
    intuitive, interactive way to explore and understand large datasets directly
    from PySpark DataFrames, streamlining the data analysis workflow in distributed
    environments.
    
    See the work-in-progress [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing) for more details.
    
    Part of https://issues.apache.org/jira/browse/SPARK-49530.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes.
    
    ```python
    >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
    >>> columns = ["category", "int_val", "float_val"]
    >>> sdf = spark.createDataFrame(data, columns)
    >>> sdf.show()
    +--------+-------+---------+
    |category|int_val|float_val|
    +--------+-------+---------+
    |       A|     10|      1.5|
    |       B|     30|      2.5|
    |       C|     20|      3.5|
    +--------+-------+---------+
    
    >>> f = sdf.plot(kind="line", x="category", y="int_val")
    >>> f.show()  # see below
    >>> g = sdf.plot.line(x="category", y=["int_val", "float_val"])
    >>> g.show()  # see below
    ```
    `f.show()`:
    
![newplot](https://github.com/user-attachments/assets/ebd50bbc-0dd1-437f-ae0c-0b4de8f3c722)
    
    `g.show()`:
    ![newplot (1)](https://github.com/user-attachments/assets/46d28840-a147-428f-8d88-d424aa76ad06)
    
    ### How was this patch tested?
    Unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #48139 from xinrong-meng/plot_line_w_dep.
    
    Authored-by: Xinrong Meng <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .github/workflows/build_python_connect.yml         |   2 +-
 dev/requirements.txt                               |   2 +-
 dev/sparktestsupport/modules.py                    |   4 +
 python/docs/source/getting_started/install.rst     |   1 +
 python/packaging/classic/setup.py                  |   1 +
 python/packaging/connect/setup.py                  |   2 +
 python/pyspark/errors/error-conditions.json        |   5 +
 python/pyspark/sql/classic/dataframe.py            |   9 ++
 python/pyspark/sql/connect/dataframe.py            |   8 ++
 python/pyspark/sql/dataframe.py                    |  28 +++++
 python/pyspark/sql/plot/__init__.py                |  21 ++++
 python/pyspark/sql/plot/core.py                    | 135 +++++++++++++++++++++
 python/pyspark/sql/plot/plotly.py                  |  30 +++++
 .../sql/tests/connect/test_parity_frame_plot.py    |  36 ++++++
 .../tests/connect/test_parity_frame_plot_plotly.py |  36 ++++++
 python/pyspark/sql/tests/plot/__init__.py          |  16 +++
 python/pyspark/sql/tests/plot/test_frame_plot.py   |  80 ++++++++++++
 .../sql/tests/plot/test_frame_plot_plotly.py       |  64 ++++++++++
 python/pyspark/sql/utils.py                        |  17 +++
 python/pyspark/testing/sqlutils.py                 |   7 ++
 .../org/apache/spark/sql/internal/SQLConf.scala    |  27 +++++
 21 files changed, 529 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/build_python_connect.yml 
b/.github/workflows/build_python_connect.yml
index 3ac1a0117e41..f668d813ef26 100644
--- a/.github/workflows/build_python_connect.yml
+++ b/.github/workflows/build_python_connect.yml
@@ -71,7 +71,7 @@ jobs:
           python packaging/connect/setup.py sdist
           cd dist
           pip install pyspark*connect-*.tar.gz
-          pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'plotly>=4.8' 
'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 
'scikit-learn>=1.3.2' 'graphviz==0.20.3' torch torchvision torcheval deepspeed 
unittest-xml-reporting
+          pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'plotly>=4.8' 
'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 
'scikit-learn>=1.3.2' 'graphviz==0.20.3' torch torchvision torcheval deepspeed 
unittest-xml-reporting 'plotly>=4.8'
       - name: Run tests
         env:
           SPARK_TESTING: 1
diff --git a/dev/requirements.txt b/dev/requirements.txt
index 5486c98ab8f8..cafc73405aaa 100644
--- a/dev/requirements.txt
+++ b/dev/requirements.txt
@@ -7,7 +7,7 @@ pyarrow>=10.0.0
 six==1.16.0
 pandas>=2.0.0
 scipy
-plotly
+plotly>=4.8
 mlflow>=2.3.1
 scikit-learn
 matplotlib
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 34fbb8450d54..b9a4bed715f6 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -548,6 +548,8 @@ pyspark_sql = Module(
         "pyspark.sql.tests.test_udtf",
         "pyspark.sql.tests.test_utils",
         "pyspark.sql.tests.test_resources",
+        "pyspark.sql.tests.plot.test_frame_plot",
+        "pyspark.sql.tests.plot.test_frame_plot_plotly",
     ],
 )
 
@@ -1051,6 +1053,8 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.test_parity_arrow_cogrouped_map",
         "pyspark.sql.tests.connect.test_parity_python_datasource",
         "pyspark.sql.tests.connect.test_parity_python_streaming_datasource",
+        "pyspark.sql.tests.connect.test_parity_frame_plot",
+        "pyspark.sql.tests.connect.test_parity_frame_plot_plotly",
         "pyspark.sql.tests.connect.test_utils",
         "pyspark.sql.tests.connect.client.test_artifact",
         "pyspark.sql.tests.connect.client.test_artifact_localcluster",
diff --git a/python/docs/source/getting_started/install.rst 
b/python/docs/source/getting_started/install.rst
index 549656bea103..88c0a8c26cc9 100644
--- a/python/docs/source/getting_started/install.rst
+++ b/python/docs/source/getting_started/install.rst
@@ -183,6 +183,7 @@ Package                    Supported version         Note
 Additional libraries that enhance functionality but are not included in the 
installation packages:
 
 - **memory-profiler**: Used for PySpark UDF memory profiling, 
``spark.profile.show(...)`` and ``spark.sql.pyspark.udf.profiler``.
+- **plotly**: Used for PySpark plotting, ``DataFrame.plot``.
 
 Note that PySpark requires Java 17 or later with ``JAVA_HOME`` properly set 
and refer to |downloading|_.
 
diff --git a/python/packaging/classic/setup.py 
b/python/packaging/classic/setup.py
index 79b74483f00d..17cca326d024 100755
--- a/python/packaging/classic/setup.py
+++ b/python/packaging/classic/setup.py
@@ -288,6 +288,7 @@ try:
             "pyspark.sql.connect.streaming.worker",
             "pyspark.sql.functions",
             "pyspark.sql.pandas",
+            "pyspark.sql.plot",
             "pyspark.sql.protobuf",
             "pyspark.sql.streaming",
             "pyspark.sql.worker",
diff --git a/python/packaging/connect/setup.py 
b/python/packaging/connect/setup.py
index ab166c79747d..6ae16e9a9ad3 100755
--- a/python/packaging/connect/setup.py
+++ b/python/packaging/connect/setup.py
@@ -77,6 +77,7 @@ if "SPARK_TESTING" in os.environ:
         "pyspark.sql.tests.connect.client",
         "pyspark.sql.tests.connect.shell",
         "pyspark.sql.tests.pandas",
+        "pyspark.sql.tests.plot",
         "pyspark.sql.tests.streaming",
         "pyspark.ml.tests.connect",
         "pyspark.pandas.tests",
@@ -161,6 +162,7 @@ try:
         "pyspark.sql.connect.streaming.worker",
         "pyspark.sql.functions",
         "pyspark.sql.pandas",
+        "pyspark.sql.plot",
         "pyspark.sql.protobuf",
         "pyspark.sql.streaming",
         "pyspark.sql.worker",
diff --git a/python/pyspark/errors/error-conditions.json 
b/python/pyspark/errors/error-conditions.json
index 4061d024a83c..92aeb15e21d1 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -1088,6 +1088,11 @@
       "Function `<func_name>` should use only POSITIONAL or POSITIONAL OR 
KEYWORD arguments."
     ]
   },
+  "UNSUPPORTED_PLOT_BACKEND": {
+    "message": [
+      "`<backend>` is not supported, it should be one of the values from 
<supported_backends>"
+    ]
+  },
   "UNSUPPORTED_SIGNATURE": {
     "message": [
       "Unsupported signature: <signature>."
diff --git a/python/pyspark/sql/classic/dataframe.py 
b/python/pyspark/sql/classic/dataframe.py
index 91b959162590..a2778cbc32c4 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -73,6 +73,11 @@ from pyspark.sql.utils import get_active_spark_context, 
to_java_array, to_scala_
 from pyspark.sql.pandas.conversion import PandasConversionMixin
 from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
 
+try:
+    from pyspark.sql.plot import PySparkPlotAccessor
+except ImportError:
+    PySparkPlotAccessor = None  # type: ignore
+
 if TYPE_CHECKING:
     from py4j.java_gateway import JavaObject
     import pyarrow as pa
@@ -1862,6 +1867,10 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin, 
PandasConversionMixin):
             messageParameters={"member": "queryExecution"},
         )
 
+    @property
+    def plot(self) -> PySparkPlotAccessor:
+        return PySparkPlotAccessor(self)
+
 
 class DataFrameNaFunctions(ParentDataFrameNaFunctions):
     def __init__(self, df: ParentDataFrame):
diff --git a/python/pyspark/sql/connect/dataframe.py 
b/python/pyspark/sql/connect/dataframe.py
index 768abd655d49..59d79decf669 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -86,6 +86,10 @@ from pyspark.sql.connect.functions import builtin as F
 from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
 from pyspark.sql.pandas.functions import _validate_pandas_udf  # type: 
ignore[attr-defined]
 
+try:
+    from pyspark.sql.plot import PySparkPlotAccessor
+except ImportError:
+    PySparkPlotAccessor = None  # type: ignore
 
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import (
@@ -2239,6 +2243,10 @@ class DataFrame(ParentDataFrame):
     def executionInfo(self) -> Optional["ExecutionInfo"]:
         return self._execution_info
 
+    @property
+    def plot(self) -> PySparkPlotAccessor:
+        return PySparkPlotAccessor(self)
+
 
 class DataFrameNaFunctions(ParentDataFrameNaFunctions):
     def __init__(self, df: ParentDataFrame):
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index ef35b7333257..2179a844b1e5 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -43,6 +43,7 @@ from pyspark.sql.streaming import DataStreamWriter
 from pyspark.sql.types import StructType, Row
 from pyspark.sql.utils import dispatch_df_method
 
+
 if TYPE_CHECKING:
     from py4j.java_gateway import JavaObject
     import pyarrow as pa
@@ -65,6 +66,7 @@ if TYPE_CHECKING:
         ArrowMapIterFunction,
         DataFrameLike as PandasDataFrameLike,
     )
+    from pyspark.sql.plot import PySparkPlotAccessor
     from pyspark.sql.metrics import ExecutionInfo
 
 
@@ -6394,6 +6396,32 @@ class DataFrame:
         """
         ...
 
+    @property
+    def plot(self) -> "PySparkPlotAccessor":
+        """
+        Returns a :class:`PySparkPlotAccessor` for plotting functions.
+
+        The accessor can be called directly, e.g. ``df.plot(kind="line", x=..., y=...)``,
+        or through kind-specific methods such as ``df.plot.line(...)``.
+
+        .. versionadded:: 4.0.0
+
+        Returns
+        -------
+        :class:`PySparkPlotAccessor`
+
+        Notes
+        -----
+        This API is experimental.
+
+        Examples
+        --------
+        >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
+        >>> columns = ["category", "int_val", "float_val"]
+        >>> df = spark.createDataFrame(data, columns)
+        >>> type(df.plot)
+        <class 'pyspark.sql.plot.core.PySparkPlotAccessor'>
+        >>> df.plot.line(x="category", y=["int_val", "float_val"])  # doctest: +SKIP
+        """
+        ...
+
 
 class DataFrameNaFunctions:
     """Functionality for working with missing data in :class:`DataFrame`.
diff --git a/python/pyspark/sql/plot/__init__.py 
b/python/pyspark/sql/plot/__init__.py
new file mode 100644
index 000000000000..6da07061b2a0
--- /dev/null
+++ b/python/pyspark/sql/plot/__init__.py
@@ -0,0 +1,21 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+This package includes the plotting APIs for PySpark DataFrame.
+"""
+from pyspark.sql.plot.core import *  # noqa: F403, F401
diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py
new file mode 100644
index 000000000000..392ef73b3884
--- /dev/null
+++ b/python/pyspark/sql/plot/core.py
@@ -0,0 +1,135 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Any, TYPE_CHECKING, Optional, Union
+from types import ModuleType
+from pyspark.errors import PySparkRuntimeError, PySparkValueError
+from pyspark.sql.utils import require_minimum_plotly_version
+
+
+if TYPE_CHECKING:
+    from pyspark.sql import DataFrame
+    import pandas as pd
+    from plotly.graph_objs import Figure
+
+
+class PySparkTopNPlotBase:
+    """Helper that collects the first ``max_rows`` rows of a DataFrame for top-n plots."""
+
+    def get_top_n(self, sdf: "DataFrame") -> "pd.DataFrame":
+        """
+        Collect at most ``spark.sql.pyspark.plotting.max_rows`` rows of ``sdf`` into pandas.
+
+        One extra row is fetched so that truncation can be detected; ``self.partial`` is
+        set to True when ``sdf`` had more than ``max_rows`` rows, False otherwise.
+
+        Parameters
+        ----------
+        sdf : :class:`DataFrame`
+            The PySpark DataFrame to collect.
+
+        Returns
+        -------
+        :class:`pandas.DataFrame`
+            At most ``max_rows`` rows of ``sdf``.
+
+        Raises
+        ------
+        PySparkRuntimeError
+            If there is no active SparkSession.
+        """
+        from pyspark.sql import SparkSession
+
+        session = SparkSession.getActiveSession()
+        if session is None:
+            raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict())
+
+        max_rows = int(
+            session.conf.get("spark.sql.pyspark.plotting.max_rows")  # type: ignore[arg-type]
+        )
+        # Fetch one row beyond the limit purely to detect whether the data was truncated.
+        pdf = sdf.limit(max_rows + 1).toPandas()
+
+        self.partial = False
+        if len(pdf) > max_rows:
+            self.partial = True
+            pdf = pdf.iloc[:max_rows]
+
+        return pdf
+
+
+class PySparkSampledPlotBase:
+    """Helper that collects a random sample of a DataFrame for sample-based plots."""
+
+    def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame":
+        """
+        Collect a random sample of ``sdf`` into a pandas DataFrame for plotting.
+
+        The sampling fraction is ``spark.sql.pyspark.plotting.sample_ratio`` when it is
+        set. Otherwise it is derived from ``spark.sql.pyspark.plotting.max_rows`` as
+        ``max_rows / count`` (capped at 1.0), so that roughly ``max_rows`` rows are kept.
+
+        Parameters
+        ----------
+        sdf : :class:`DataFrame`
+            The PySpark DataFrame to sample.
+
+        Returns
+        -------
+        :class:`pandas.DataFrame`
+            The sampled rows. Sampling is random, so the row count is approximate.
+
+        Raises
+        ------
+        PySparkRuntimeError
+            If there is no active SparkSession.
+        """
+        from pyspark.sql import SparkSession
+
+        session = SparkSession.getActiveSession()
+        if session is None:
+            raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict())
+
+        sample_ratio = session.conf.get("spark.sql.pyspark.plotting.sample_ratio")
+        max_rows = int(
+            session.conf.get("spark.sql.pyspark.plotting.max_rows")  # type: ignore[arg-type]
+        )
+
+        if sample_ratio is None:
+            total_rows = sdf.count()
+            # Guard the empty DataFrame: dividing by a zero row count would raise
+            # ZeroDivisionError, and there is nothing to thin out anyway.
+            fraction = 1.0 if total_rows == 0 else min(1.0, max_rows / total_rows)
+        else:
+            fraction = float(sample_ratio)
+
+        sampled_sdf = sdf.sample(fraction=fraction)
+        return sampled_sdf.toPandas()
+
+
+class PySparkPlotAccessor:
+    """
+    Accessor returned by ``DataFrame.plot``.
+
+    It can be called directly, e.g. ``df.plot(kind="line", x=..., y=...)``, or through
+    kind-specific methods such as :meth:`line`. Either way the call is forwarded to the
+    selected backend module's ``plot_pyspark`` function.
+    """
+
+    # Maps each supported plot kind to the function that collects the data to plot.
+    plot_data_map = {
+        "line": PySparkSampledPlotBase().get_sampled,
+    }
+    # Backend modules that have already been resolved, keyed by backend name.
+    _backends = {}  # type: ignore[var-annotated]
+
+    def __init__(self, data: "DataFrame"):
+        self.data = data
+
+    def __call__(
+        self, kind: str = "line", backend: Optional[str] = None, **kwargs: Any
+    ) -> "Figure":
+        """
+        Plot the wrapped DataFrame.
+
+        Parameters
+        ----------
+        kind : str, default "line"
+            The kind of plot to produce.
+        backend : str, optional
+            The plotting backend to use; defaults to "plotly".
+        **kwargs : optional
+            Additional keyword arguments forwarded to the backend.
+
+        Returns
+        -------
+        :class:`plotly.graph_objs.Figure`
+        """
+        plot_backend = PySparkPlotAccessor._get_plot_backend(backend)
+
+        return plot_backend.plot_pyspark(self.data, kind=kind, **kwargs)
+
+    @staticmethod
+    def _get_plot_backend(backend: Optional[str] = None) -> ModuleType:
+        """
+        Resolve the backend module for ``backend`` (default "plotly").
+
+        Raises
+        ------
+        PySparkValueError
+            If ``backend`` is not a supported backend.
+        PySparkImportError
+            If the package required by the backend is not available.
+        """
+        backend = backend or "plotly"
+
+        if backend in PySparkPlotAccessor._backends:
+            return PySparkPlotAccessor._backends[backend]
+
+        if backend == "plotly":
+            require_minimum_plotly_version()
+        else:
+            raise PySparkValueError(
+                errorClass="UNSUPPORTED_PLOT_BACKEND",
+                messageParameters={"backend": backend, "supported_backends": ", ".join(["plotly"])},
+            )
+        from pyspark.sql.plot import plotly as module
+
+        # Cache the resolved module so later calls skip the dependency check and import.
+        PySparkPlotAccessor._backends[backend] = module
+        return module
+
+    def line(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure":
+        """
+        Plot DataFrame as lines.
+
+        Parameters
+        ----------
+        x : str
+            Name of column to use for the horizontal axis.
+        y : str or list of str
+            Name(s) of the column(s) to use for the vertical axis. Multiple columns can be plotted.
+        **kwargs : optional
+            Additional keyword arguments.
+
+        Returns
+        -------
+        :class:`plotly.graph_objs.Figure`
+
+        Examples
+        --------
+        >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
+        >>> columns = ["category", "int_val", "float_val"]
+        >>> df = spark.createDataFrame(data, columns)
+        >>> df.plot.line(x="category", y="int_val")  # doctest: +SKIP
+        >>> df.plot.line(x="category", y=["int_val", "float_val"])  # doctest: +SKIP
+        """
+        return self(kind="line", x=x, y=y, **kwargs)
diff --git a/python/pyspark/sql/plot/plotly.py 
b/python/pyspark/sql/plot/plotly.py
new file mode 100644
index 000000000000..5efc19476057
--- /dev/null
+++ b/python/pyspark/sql/plot/plotly.py
@@ -0,0 +1,30 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import TYPE_CHECKING, Any
+
+from pyspark.sql.plot import PySparkPlotAccessor
+
+if TYPE_CHECKING:
+    from pyspark.sql import DataFrame
+    from plotly.graph_objs import Figure
+
+
+def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure":
+    """
+    Plot a PySpark DataFrame with plotly.
+
+    The rows to plot are collected by the function registered for ``kind`` in
+    ``PySparkPlotAccessor.plot_data_map`` and then passed to ``plotly.plot``.
+
+    Parameters
+    ----------
+    data : :class:`DataFrame`
+        The PySpark DataFrame to plot.
+    kind : str
+        The kind of plot. It must be a key of ``PySparkPlotAccessor.plot_data_map``;
+        otherwise a ``KeyError`` is raised.
+    **kwargs : optional
+        Additional keyword arguments forwarded to ``plotly.plot``.
+
+    Returns
+    -------
+    :class:`plotly.graph_objs.Figure`
+    """
+    import plotly
+
+    return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs)
diff --git a/python/pyspark/sql/tests/connect/test_parity_frame_plot.py 
b/python/pyspark/sql/tests/connect/test_parity_frame_plot.py
new file mode 100644
index 000000000000..c69e438bf7eb
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_frame_plot.py
@@ -0,0 +1,36 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.sql.tests.plot.test_frame_plot import DataFramePlotTestsMixin
+
+
+class FramePlotParityTests(DataFramePlotTestsMixin, ReusedConnectTestCase):
+    pass
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.connect.test_parity_frame_plot import *  # noqa: 
F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py 
b/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py
new file mode 100644
index 000000000000..78508fe53337
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py
@@ -0,0 +1,36 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.sql.tests.plot.test_frame_plot_plotly import 
DataFramePlotPlotlyTestsMixin
+
+
+class FramePlotPlotlyParityTests(DataFramePlotPlotlyTestsMixin, 
ReusedConnectTestCase):
+    pass
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.connect.test_parity_frame_plot_plotly import *  # 
noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/plot/__init__.py 
b/python/pyspark/sql/tests/plot/__init__.py
new file mode 100644
index 000000000000..cce3acad34a4
--- /dev/null
+++ b/python/pyspark/sql/tests/plot/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot.py 
b/python/pyspark/sql/tests/plot/test_frame_plot.py
new file mode 100644
index 000000000000..f753b5ab3db7
--- /dev/null
+++ b/python/pyspark/sql/tests/plot/test_frame_plot.py
@@ -0,0 +1,80 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from pyspark.errors import PySparkValueError
+from pyspark.sql import Row
+from pyspark.sql.plot import PySparkSampledPlotBase, PySparkTopNPlotBase
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, 
plotly_requirement_message
+
+
[email protected](not have_plotly, plotly_requirement_message)
+class DataFramePlotTestsMixin:
+    def test_backend(self):
+        accessor = self.spark.range(2).plot
+        backend = accessor._get_plot_backend()
+        self.assertEqual(backend.__name__, "pyspark.sql.plot.plotly")
+
+        with self.assertRaises(PySparkValueError) as pe:
+            accessor._get_plot_backend("matplotlib")
+
+        self.check_error(
+            exception=pe.exception,
+            errorClass="UNSUPPORTED_PLOT_BACKEND",
+            messageParameters={"backend": "matplotlib", "supported_backends": 
"plotly"},
+        )
+
+    def test_topn_max_rows(self):
+        try:
+            self.spark.conf.set("spark.sql.pyspark.plotting.max_rows", "1000")
+            sdf = self.spark.range(2500)
+            pdf = PySparkTopNPlotBase().get_top_n(sdf)
+            self.assertEqual(len(pdf), 1000)
+        finally:
+            self.spark.conf.unset("spark.sql.pyspark.plotting.max_rows")
+
+    def test_sampled_plot_with_ratio(self):
+        try:
+            self.spark.conf.set("spark.sql.pyspark.plotting.sample_ratio", 
"0.5")
+            data = [Row(a=i, b=i + 1, c=i + 2, d=i + 3) for i in range(2500)]
+            sdf = self.spark.createDataFrame(data)
+            pdf = PySparkSampledPlotBase().get_sampled(sdf)
+            self.assertEqual(round(len(pdf) / 2500, 1), 0.5)
+        finally:
+            self.spark.conf.unset("spark.sql.pyspark.plotting.sample_ratio")
+
+    def test_sampled_plot_with_max_rows(self):
+        data = [Row(a=i, b=i + 1, c=i + 2, d=i + 3) for i in range(2000)]
+        sdf = self.spark.createDataFrame(data)
+        pdf = PySparkSampledPlotBase().get_sampled(sdf)
+        self.assertEqual(round(len(pdf) / 2000, 1), 0.5)
+
+
+class DataFramePlotTests(DataFramePlotTestsMixin, ReusedSQLTestCase):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.plot.test_frame_plot import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py 
b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
new file mode 100644
index 000000000000..72a3ed267d19
--- /dev/null
+++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
@@ -0,0 +1,64 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import pyspark.sql.plot  # noqa: F401
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, 
plotly_requirement_message
+
+
[email protected](not have_plotly, plotly_requirement_message)
+class DataFramePlotPlotlyTestsMixin:
+    @property
+    def sdf(self):
+        data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
+        columns = ["category", "int_val", "float_val"]
+        return self.spark.createDataFrame(data, columns)
+
+    def _check_fig_data(self, fig_data, expected_x, expected_y, 
expected_name=""):
+        self.assertEqual(fig_data["mode"], "lines")
+        self.assertEqual(fig_data["type"], "scatter")
+        self.assertEqual(fig_data["xaxis"], "x")
+        self.assertEqual(list(fig_data["x"]), expected_x)
+        self.assertEqual(fig_data["yaxis"], "y")
+        self.assertEqual(list(fig_data["y"]), expected_y)
+        self.assertEqual(fig_data["name"], expected_name)
+
+    def test_line_plot(self):
+        # single column as vertical axis
+        fig = self.sdf.plot(kind="line", x="category", y="int_val")
+        self._check_fig_data(fig["data"][0], ["A", "B", "C"], [10, 30, 20])
+
+        # multiple columns as vertical axis
+        fig = self.sdf.plot.line(x="category", y=["int_val", "float_val"])
+        self._check_fig_data(fig["data"][0], ["A", "B", "C"], [10, 30, 20], 
"int_val")
+        self._check_fig_data(fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], 
"float_val")
+
+
+class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, 
ReusedSQLTestCase):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.plot.test_frame_plot_plotly import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 11b91612419a..5d9ec92cbc83 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -41,6 +41,7 @@ from pyspark.errors import (  # noqa: F401
     PythonException,
     UnknownException,
     SparkUpgradeException,
+    PySparkImportError,
     PySparkNotImplementedError,
     PySparkRuntimeError,
 )
@@ -115,6 +116,22 @@ def require_test_compiled() -> None:
         )
 
 
+def require_minimum_plotly_version() -> None:
+    """
+    Raise ImportError if plotly is not installed or is older than the minimum version.
+
+    Raises
+    ------
+    PySparkImportError
+        ``PACKAGE_NOT_INSTALLED`` if plotly cannot be imported, or
+        ``UNSUPPORTED_PACKAGE_VERSION`` if the installed plotly is older than 4.8.
+    """
+    from pyspark.loose_version import LooseVersion
+
+    minimum_plotly_version = "4.8"
+
+    try:
+        import plotly
+    except ImportError as error:
+        raise PySparkImportError(
+            errorClass="PACKAGE_NOT_INSTALLED",
+            messageParameters={
+                "package_name": "plotly",
+                "minimum_version": str(minimum_plotly_version),
+            },
+        ) from error
+    # Enforce the minimum version, not just importability: older plotly releases
+    # lack the pandas plotting entry point (``plotly.plot``) this module relies on.
+    if LooseVersion(plotly.__version__) < LooseVersion(minimum_plotly_version):
+        raise PySparkImportError(
+            errorClass="UNSUPPORTED_PACKAGE_VERSION",
+            messageParameters={
+                "package_name": "plotly",
+                "minimum_version": str(minimum_plotly_version),
+                "current_version": str(plotly.__version__),
+            },
+        )
+
+
 class ForeachBatchFunction:
     """
     This is the Python implementation of Java interface 
'ForeachBatchFunction'. This wraps
diff --git a/python/pyspark/testing/sqlutils.py 
b/python/pyspark/testing/sqlutils.py
index 9f07c44c084c..00ad40e68bd7 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -48,6 +48,13 @@ try:
 except Exception as e:
     test_not_compiled_message = str(e)
 
+plotly_requirement_message = None
+try:
+    import plotly
+except ImportError as e:
+    plotly_requirement_message = str(e)
+have_plotly = plotly_requirement_message is None
+
 from pyspark.sql import SparkSession
 from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
 from pyspark.testing.utils import ReusedPySparkTestCase, PySparkErrorTestUtils
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 2eaafde52228..6c3e9bac1cfe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3169,6 +3169,29 @@ object SQLConf {
       .version("4.0.0")
       .fallbackConf(Python.PYTHON_WORKER_FAULTHANLDER_ENABLED)
 
+  // Upper bound on the number of rows used by PySpark plots. The Python side uses it as
+  // a row limit (top-n plots) and to derive a sampling fraction (sample-based plots), so
+  // only positive values are meaningful.
+  val PYSPARK_PLOT_MAX_ROWS =
+    buildConf("spark.sql.pyspark.plotting.max_rows")
+      .doc(
+        "The visual limit on top-n-based plots. If set to 1000, the first 1000 data points " +
+          "will be used for plotting. Must be a positive integer.")
+      .version("4.0.0")
+      .intConf
+      .checkValue(_ > 0, "The value should be a positive integer.")
+      .createWithDefault(1000)
+
+  // Optional fixed sampling fraction for sample-based PySpark plots; when unset, the
+  // fraction is derived from PYSPARK_PLOT_MAX_ROWS.
+  val PYSPARK_PLOT_SAMPLE_RATIO =
+    buildConf("spark.sql.pyspark.plotting.sample_ratio")
+      .doc(
+        "The proportion of data that will be plotted for sample-based plots. It is determined " +
+          "based on spark.sql.pyspark.plotting.max_rows if not explicitly set."
+      )
+      .version("4.0.0")
+      .doubleConf
+      .checkValue(
+        ratio => ratio >= 0.0 && ratio <= 1.0,
+        "The value should be between 0.0 and 1.0 inclusive."
+      )
+      .createOptional
+
   val ARROW_SPARKR_EXECUTION_ENABLED =
     buildConf("spark.sql.execution.arrow.sparkr.enabled")
       .doc("When true, make use of Apache Arrow for columnar data transfers in 
SparkR. " +
@@ -5873,6 +5896,10 @@ class SQLConf extends Serializable with Logging with 
SqlApiConf {
 
   def pythonUDFWorkerFaulthandlerEnabled: Boolean = 
getConf(PYTHON_UDF_WORKER_FAULTHANLDER_ENABLED)
 
+  def pysparkPlotMaxRows: Int = getConf(PYSPARK_PLOT_MAX_ROWS)
+
+  def pysparkPlotSampleRatio: Option[Double] = 
getConf(PYSPARK_PLOT_SAMPLE_RATIO)
+
   def arrowSparkREnabled: Boolean = getConf(ARROW_SPARKR_EXECUTION_ENABLED)
 
   def arrowPySparkFallbackEnabled: Boolean = 
getConf(ARROW_PYSPARK_FALLBACK_ENABLED)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to