This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2316fed65192 [SPARK-50170][PYTHON] Move
`_invoke_internal_function_over_columns` to `pyspark.sql.utils`
2316fed65192 is described below
commit 2316fed65192c32ca211e35399246029687f0007
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Oct 30 19:17:44 2024 +0800
[SPARK-50170][PYTHON] Move `_invoke_internal_function_over_columns` to
`pyspark.sql.utils`
### What changes were proposed in this pull request?
Move `_invoke_internal_function_over_columns` to `pyspark.sql.utils`
### Why are the changes needed?
Code deduplication: moving this helper into a shared module makes it easier to reuse.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
ci
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #48703 from zhengruifeng/mv_internal_fn.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/pandas/spark/functions.py | 21 ++-------------
python/pyspark/sql/plot/core.py | 44 ++++++++++----------------------
python/pyspark/sql/utils.py | 20 +++++++++++++++
3 files changed, 36 insertions(+), 49 deletions(-)
diff --git a/python/pyspark/pandas/spark/functions.py
b/python/pyspark/pandas/spark/functions.py
index 53146a163b1e..a6b8e79ca50f 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -18,25 +18,8 @@
Additional Spark functions used in pandas-on-Spark.
"""
from pyspark.sql import Column, functions as F
-from pyspark.sql.utils import is_remote
-from typing import Union, TYPE_CHECKING
-
-if TYPE_CHECKING:
- from pyspark.sql._typing import ColumnOrName
-
-
-def _invoke_internal_function_over_columns(name: str, *cols: "ColumnOrName")
-> Column:
- if is_remote():
- from pyspark.sql.connect.functions.builtin import
_invoke_function_over_columns
-
- return _invoke_function_over_columns(name, *cols)
-
- else:
- from pyspark.sql.classic.column import _to_seq, _to_java_column
- from pyspark import SparkContext
-
- sc = SparkContext._active_spark_context
- return Column(sc._jvm.PythonSQLUtils.internalFn(name, _to_seq(sc,
cols, _to_java_column)))
+from pyspark.sql.utils import _invoke_internal_function_over_columns
+from typing import Union
def timestamp_ntz_to_long(col: Column) -> Column:
diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py
index 328ebe348878..d63837ced8c8 100644
--- a/python/pyspark/sql/plot/core.py
+++ b/python/pyspark/sql/plot/core.py
@@ -27,13 +27,13 @@ from pyspark.errors import (
from pyspark.sql import Column, functions as F
from pyspark.sql.pandas.utils import require_minimum_numpy_version,
require_minimum_pandas_version
from pyspark.sql.types import NumericType
-from pyspark.sql.utils import is_remote, require_minimum_plotly_version
+from pyspark.sql.utils import require_minimum_plotly_version,
_invoke_internal_function_over_columns
+
from pandas.core.dtypes.inference import is_integer
if TYPE_CHECKING:
from pyspark.sql import DataFrame, Row
- from pyspark.sql._typing import ColumnOrName
import pandas as pd
import numpy as np
from plotly.graph_objs import Figure
@@ -564,6 +564,10 @@ class PySparkKdePlotBase:
class PySparkHistogramPlotBase:
+ @staticmethod
+ def array_binary_search(col: Column, value: Column) -> Column:
+ return _invoke_internal_function_over_columns("array_binary_search",
col, value)
+
@staticmethod
def get_bins(sdf: "DataFrame", bins: int) -> "np.ndarray":
require_minimum_numpy_version()
@@ -615,7 +619,7 @@ class PySparkHistogramPlotBase:
# determines which bucket a given value falls into, based on
predefined bin intervals
# refers to
org.apache.spark.ml.feature.Bucketizer#binarySearchForBuckets
def binary_search_for_buckets(value: Column) -> Column:
- index = array_binary_search(F.lit(bins), value)
+ index = PySparkHistogramPlotBase.array_binary_search(F.lit(bins),
value)
bucket = F.when(index >= 0, index).otherwise(-index - 2)
unboundErrMsg = F.lit(f"value %s out of the bins bounds:
[{bins[0]}, {bins[-1]}]")
return (
@@ -709,6 +713,12 @@ class PySparkHistogramPlotBase:
class PySparkBoxPlotBase:
+ @staticmethod
+ def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
+ return _invoke_internal_function_over_columns(
+ "collect_top_k", col, F.lit(num), F.lit(reverse)
+ )
+
@staticmethod
def compute_box(
sdf: "DataFrame", colnames: List[str], whis: float, precision: float,
showfliers: bool
@@ -763,7 +773,7 @@ class PySparkBoxPlotBase:
outlier,
F.struct(F.abs(value - med), value.alias("val")),
).otherwise(F.lit(None))
- topk = collect_top_k(pair, 1001, False)
+ topk = PySparkBoxPlotBase.collect_top_k(pair, 1001, False)
fliers = F.when(F.size(topk) > 0,
topk["val"]).otherwise(F.lit(None))
else:
fliers = F.lit(None)
@@ -782,29 +792,3 @@ class PySparkBoxPlotBase:
sdf_result =
sdf.join(sdf_stats.hint("broadcast")).select(*result_scols)
return sdf_result.first()
-
-
-def _invoke_internal_function_over_columns(name: str, *cols: "ColumnOrName")
-> Column:
- if is_remote():
- from pyspark.sql.connect.functions.builtin import
_invoke_function_over_columns
-
- return _invoke_function_over_columns(name, *cols)
-
- else:
- from pyspark.sql.classic.column import _to_seq, _to_java_column
- from pyspark import SparkContext
-
- sc = SparkContext._active_spark_context
- return Column(
- sc._jvm.PythonSQLUtils.internalFn( # type: ignore
- name, _to_seq(sc, cols, _to_java_column) # type: ignore
- )
- )
-
-
-def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
- return _invoke_internal_function_over_columns("collect_top_k", col,
F.lit(num), F.lit(reverse))
-
-
-def array_binary_search(col: Column, value: Column) -> Column:
- return _invoke_internal_function_over_columns("array_binary_search", col,
value)
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 5d9ec92cbc83..b961fae19151 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -58,8 +58,10 @@ if TYPE_CHECKING:
JVMView,
)
from pyspark import SparkContext
+ from pyspark.sql import Column
from pyspark.sql.session import SparkSession
from pyspark.sql.dataframe import DataFrame
+ from pyspark.sql._typing import ColumnOrName
from pyspark.pandas._typing import IndexOpsLike, SeriesOrIndex
has_numpy: bool = False
@@ -216,6 +218,24 @@ def enum_to_value(value: Any) -> Any:
return enum_to_value(value.value) if value is not None and
isinstance(value, Enum) else value
+def _invoke_internal_function_over_columns(name: str, *cols: "ColumnOrName")
-> "Column":
+ if is_remote():
+ from pyspark.sql.connect.functions.builtin import
_invoke_function_over_columns
+
+ return _invoke_function_over_columns(name, *cols)
+
+ else:
+ from pyspark.sql.classic.column import Column, _to_seq, _to_java_column
+ from pyspark import SparkContext
+
+ sc = SparkContext._active_spark_context
+ return Column(
+ sc._jvm.PythonSQLUtils.internalFn( # type: ignore
+ name, _to_seq(sc, cols, _to_java_column) # type: ignore
+ )
+ )
+
+
def is_timestamp_ntz_preferred() -> bool:
"""
Return a bool if TimestampNTZType is preferred according to the SQL
configuration set.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]