This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e8c79634b2a1 [SPARK-50171][PYTHON] Make numpy optional for KDE plot
e8c79634b2a1 is described below
commit e8c79634b2a1571667f9f390a5051d0d77114f45
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Oct 30 10:52:55 2024 -0700
[SPARK-50171][PYTHON] Make numpy optional for KDE plot
### What changes were proposed in this pull request?
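Make NumPy an optional dependency for the KDE plot. `np.linspace` is replaced with a pure-Python `PySparkKdePlotBase.linspace` helper. In the plotly backend, a NumPy `ind` array is converted to a list of floats only when NumPy is available.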
### Why are the changes needed?
To support KDE plots even when NumPy is not installed.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
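Existing CI tests. The `have_numpy` skip on `test_kde_plot` was removed, so the KDE plot test now also runs without NumPy.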
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #48705 from zhengruifeng/kde_np.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
python/pyspark/sql/plot/core.py | 18 +++++++++++-------
python/pyspark/sql/plot/plotly.py | 7 +++++++
.../pyspark/sql/tests/plot/test_frame_plot_plotly.py | 3 ---
3 files changed, 18 insertions(+), 10 deletions(-)
diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py
index c20912dda90a..ab8f3463302b 100644
--- a/python/pyspark/sql/plot/core.py
+++ b/python/pyspark/sql/plot/core.py
@@ -17,7 +17,7 @@
import math
-from typing import Any, TYPE_CHECKING, List, Optional, Union
+from typing import Any, TYPE_CHECKING, List, Optional, Union, Sequence
from types import ModuleType
from pyspark.errors import (
PySparkRuntimeError,
@@ -489,10 +489,14 @@ class PySparkPlotAccessor:
class PySparkKdePlotBase:
@staticmethod
- def get_ind(sdf: "DataFrame", ind: Optional[Union["np.ndarray", int]]) ->
"np.ndarray":
- require_minimum_numpy_version()
- import numpy as np
+ def linspace(start, stop, num): # type: ignore[no-untyped-def]
+ if num == 1:
+ return [float(start)]
+ step = float(stop - start) / (num - 1)
+ return [start + step * i for i in range(num)]
+ @staticmethod
+ def get_ind(sdf: "DataFrame", ind: Optional[Union[Sequence[float], int]])
-> Sequence[float]:
def calc_min_max() -> "Row":
if len(sdf.columns) > 1:
min_col = F.least(*map(F.min, sdf)) # type: ignore
@@ -505,7 +509,7 @@ class PySparkKdePlotBase:
if ind is None:
min_val, max_val = calc_min_max()
sample_range = max_val - min_val
- ind = np.linspace(
+ ind = PySparkKdePlotBase.linspace(
min_val - 0.5 * sample_range,
max_val + 0.5 * sample_range,
1000,
@@ -513,7 +517,7 @@ class PySparkKdePlotBase:
elif is_integer(ind):
min_val, max_val = calc_min_max()
sample_range = max_val - min_val
- ind = np.linspace(
+ ind = PySparkKdePlotBase.linspace(
min_val - 0.5 * sample_range,
max_val + 0.5 * sample_range,
ind,
@@ -524,7 +528,7 @@ class PySparkKdePlotBase:
def compute_kde_col(
input_col: Column,
bw_method: Union[int, float],
- ind: "np.ndarray",
+ ind: Sequence[float],
) -> Column:
# refers to org.apache.spark.mllib.stat.KernelDensity
assert bw_method is not None and isinstance(
diff --git a/python/pyspark/sql/plot/plotly.py b/python/pyspark/sql/plot/plotly.py
index ceae4b999aa8..263de4d25670 100644
--- a/python/pyspark/sql/plot/plotly.py
+++ b/python/pyspark/sql/plot/plotly.py
@@ -130,6 +130,7 @@ def plot_box(data: "DataFrame", **kwargs: Any) -> "Figure":
def plot_kde(data: "DataFrame", **kwargs: Any) -> "Figure":
+ from pyspark.sql.utils import has_numpy
from pyspark.sql.pandas.utils import require_minimum_pandas_version
require_minimum_pandas_version()
@@ -144,6 +145,12 @@ def plot_kde(data: "DataFrame", **kwargs: Any) -> "Figure":
colnames = process_column_param(kwargs.pop("column", None), data)
ind = PySparkKdePlotBase.get_ind(data.select(*colnames), kwargs.pop("ind",
None))
+ if has_numpy:
+ import numpy as np
+
+ if isinstance(ind, np.ndarray):
+ ind = [float(i) for i in ind]
+
kde_cols = [
PySparkKdePlotBase.compute_kde_col(
input_col=data[col_name],
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
index 362d1225416a..84a9c2aa0170 100644
--- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
+++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
@@ -22,9 +22,7 @@ from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_plotly,
- have_numpy,
plotly_requirement_message,
- numpy_requirement_message,
have_pandas,
pandas_requirement_message,
)
@@ -392,7 +390,6 @@ class DataFramePlotPlotlyTestsMixin:
},
)
- @unittest.skipIf(not have_numpy, numpy_requirement_message)
def test_kde_plot(self):
fig = self.sdf4.plot.kde(column="math_score", bw_method=0.3, ind=5)
expected_fig_data1 = {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]