This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e8c79634b2a1 [SPARK-50171][PYTHON] Make numpy optional for KDE plot
e8c79634b2a1 is described below

commit e8c79634b2a1571667f9f390a5051d0d77114f45
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Oct 30 10:52:55 2024 -0700

    [SPARK-50171][PYTHON] Make numpy optional for KDE plot
    
    ### What changes were proposed in this pull request?
    Make numpy optional for the DataFrame KDE plot by replacing `np.linspace` with a pure-Python `linspace` helper.
    
    ### Why are the changes needed?
    To support KDE plots even when numpy is not installed.
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    Existing CI; `test_kde_plot` no longer skips when numpy is unavailable.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #48705 from zhengruifeng/kde_np.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 python/pyspark/sql/plot/core.py                        | 18 +++++++++++-------
 python/pyspark/sql/plot/plotly.py                      |  7 +++++++
 .../pyspark/sql/tests/plot/test_frame_plot_plotly.py   |  3 ---
 3 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py
index c20912dda90a..ab8f3463302b 100644
--- a/python/pyspark/sql/plot/core.py
+++ b/python/pyspark/sql/plot/core.py
@@ -17,7 +17,7 @@
 
 import math
 
-from typing import Any, TYPE_CHECKING, List, Optional, Union
+from typing import Any, TYPE_CHECKING, List, Optional, Union, Sequence
 from types import ModuleType
 from pyspark.errors import (
     PySparkRuntimeError,
@@ -489,10 +489,14 @@ class PySparkPlotAccessor:
 
 class PySparkKdePlotBase:
     @staticmethod
-    def get_ind(sdf: "DataFrame", ind: Optional[Union["np.ndarray", int]]) -> 
"np.ndarray":
-        require_minimum_numpy_version()
-        import numpy as np
+    def linspace(start, stop, num):  # type: ignore[no-untyped-def]
+        if num == 1:
+            return [float(start)]
+        step = float(stop - start) / (num - 1)
+        return [start + step * i for i in range(num)]
 
+    @staticmethod
+    def get_ind(sdf: "DataFrame", ind: Optional[Union[Sequence[float], int]]) 
-> Sequence[float]:
         def calc_min_max() -> "Row":
             if len(sdf.columns) > 1:
                 min_col = F.least(*map(F.min, sdf))  # type: ignore
@@ -505,7 +509,7 @@ class PySparkKdePlotBase:
         if ind is None:
             min_val, max_val = calc_min_max()
             sample_range = max_val - min_val
-            ind = np.linspace(
+            ind = PySparkKdePlotBase.linspace(
                 min_val - 0.5 * sample_range,
                 max_val + 0.5 * sample_range,
                 1000,
@@ -513,7 +517,7 @@ class PySparkKdePlotBase:
         elif is_integer(ind):
             min_val, max_val = calc_min_max()
             sample_range = max_val - min_val
-            ind = np.linspace(
+            ind = PySparkKdePlotBase.linspace(
                 min_val - 0.5 * sample_range,
                 max_val + 0.5 * sample_range,
                 ind,
@@ -524,7 +528,7 @@ class PySparkKdePlotBase:
     def compute_kde_col(
         input_col: Column,
         bw_method: Union[int, float],
-        ind: "np.ndarray",
+        ind: Sequence[float],
     ) -> Column:
         # refers to org.apache.spark.mllib.stat.KernelDensity
         assert bw_method is not None and isinstance(
diff --git a/python/pyspark/sql/plot/plotly.py 
b/python/pyspark/sql/plot/plotly.py
index ceae4b999aa8..263de4d25670 100644
--- a/python/pyspark/sql/plot/plotly.py
+++ b/python/pyspark/sql/plot/plotly.py
@@ -130,6 +130,7 @@ def plot_box(data: "DataFrame", **kwargs: Any) -> "Figure":
 
 
 def plot_kde(data: "DataFrame", **kwargs: Any) -> "Figure":
+    from pyspark.sql.utils import has_numpy
     from pyspark.sql.pandas.utils import require_minimum_pandas_version
 
     require_minimum_pandas_version()
@@ -144,6 +145,12 @@ def plot_kde(data: "DataFrame", **kwargs: Any) -> "Figure":
     colnames = process_column_param(kwargs.pop("column", None), data)
     ind = PySparkKdePlotBase.get_ind(data.select(*colnames), kwargs.pop("ind", 
None))
 
+    if has_numpy:
+        import numpy as np
+
+        if isinstance(ind, np.ndarray):
+            ind = [float(i) for i in ind]
+
     kde_cols = [
         PySparkKdePlotBase.compute_kde_col(
             input_col=data[col_name],
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py 
b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
index 362d1225416a..84a9c2aa0170 100644
--- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
+++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
@@ -22,9 +22,7 @@ from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
     have_plotly,
-    have_numpy,
     plotly_requirement_message,
-    numpy_requirement_message,
     have_pandas,
     pandas_requirement_message,
 )
@@ -392,7 +390,6 @@ class DataFramePlotPlotlyTestsMixin:
             },
         )
 
-    @unittest.skipIf(not have_numpy, numpy_requirement_message)
     def test_kde_plot(self):
         fig = self.sdf4.plot.kde(column="math_score", bw_method=0.3, ind=5)
         expected_fig_data1 = {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to