This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e1b2ac55b4b9 [SPARK-49767][PS][CONNECT] Refactor the internal function
invocation
e1b2ac55b4b9 is described below
commit e1b2ac55b4b9463824d3f23eb7fbac88ede843d9
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Sep 25 19:25:47 2024 +0900
[SPARK-49767][PS][CONNECT] Refactor the internal function invocation
### What changes were proposed in this pull request?
Refactor the internal function invocation
### Why are the changes needed?
By introducing a new helper function
`_invoke_internal_function_over_columns`, we no longer need to add dedicated
internal functions to `PythonSQLUtils`.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #48227 from zhengruifeng/py_fn.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/pandas/plot/core.py | 2 +-
python/pyspark/pandas/spark/functions.py | 175 +++------------------
python/pyspark/pandas/window.py | 3 +-
.../spark/sql/api/python/PythonSQLUtils.scala | 43 +----
.../apache/spark/sql/DataFrameSelfJoinSuite.scala | 3 +-
.../org/apache/spark/sql/DataFrameSuite.scala | 3 +-
6 files changed, 33 insertions(+), 196 deletions(-)
diff --git a/python/pyspark/pandas/plot/core.py
b/python/pyspark/pandas/plot/core.py
index 6f036b766924..7333fae1ad43 100644
--- a/python/pyspark/pandas/plot/core.py
+++ b/python/pyspark/pandas/plot/core.py
@@ -215,7 +215,7 @@ class HistogramPlotBase(NumericPlotBase):
# refers to
org.apache.spark.ml.feature.Bucketizer#binarySearchForBuckets
def binary_search_for_buckets(value: Column):
- index = SF.binary_search(F.lit(bins), value)
+ index = SF.array_binary_search(F.lit(bins), value)
bucket = F.when(index >= 0, index).otherwise(-index - 2)
unboundErrMsg = F.lit(f"value %s out of the bins bounds:
[{bins[0]}, {bins[-1]}]")
return (
diff --git a/python/pyspark/pandas/spark/functions.py
b/python/pyspark/pandas/spark/functions.py
index 4bcf07f6f650..4d95466a98e1 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -19,197 +19,72 @@ Additional Spark functions used in pandas-on-Spark.
"""
from pyspark.sql import Column, functions as F
from pyspark.sql.utils import is_remote
-from typing import Union
+from typing import Union, TYPE_CHECKING
+if TYPE_CHECKING:
+ from pyspark.sql._typing import ColumnOrName
-def product(col: Column, dropna: bool) -> Column:
+
+def _invoke_internal_function_over_columns(name: str, *cols: "ColumnOrName")
-> Column:
if is_remote():
- from pyspark.sql.connect.functions.builtin import
_invoke_function_over_columns, lit
+ from pyspark.sql.connect.functions.builtin import
_invoke_function_over_columns
- return _invoke_function_over_columns(
- "pandas_product",
- col,
- lit(dropna),
- )
+ return _invoke_function_over_columns(name, *cols)
else:
+ from pyspark.sql.classic.column import _to_seq, _to_java_column
from pyspark import SparkContext
sc = SparkContext._active_spark_context
- return Column(sc._jvm.PythonSQLUtils.pandasProduct(col._jc, dropna))
+ return Column(sc._jvm.PythonSQLUtils.internalFn(name, _to_seq(sc,
cols, _to_java_column)))
-def stddev(col: Column, ddof: int) -> Column:
- if is_remote():
- from pyspark.sql.connect.functions.builtin import
_invoke_function_over_columns, lit
-
- return _invoke_function_over_columns(
- "pandas_stddev",
- col,
- lit(ddof),
- )
+def product(col: Column, dropna: bool) -> Column:
+ return _invoke_internal_function_over_columns("pandas_product", col,
F.lit(dropna))
- else:
- from pyspark import SparkContext
- sc = SparkContext._active_spark_context
- return Column(sc._jvm.PythonSQLUtils.pandasStddev(col._jc, ddof))
+def stddev(col: Column, ddof: int) -> Column:
+ return _invoke_internal_function_over_columns("pandas_stddev", col,
F.lit(ddof))
def var(col: Column, ddof: int) -> Column:
- if is_remote():
- from pyspark.sql.connect.functions.builtin import
_invoke_function_over_columns, lit
-
- return _invoke_function_over_columns(
- "pandas_var",
- col,
- lit(ddof),
- )
-
- else:
- from pyspark import SparkContext
-
- sc = SparkContext._active_spark_context
- return Column(sc._jvm.PythonSQLUtils.pandasVariance(col._jc, ddof))
+ return _invoke_internal_function_over_columns("pandas_var", col,
F.lit(ddof))
def skew(col: Column) -> Column:
- if is_remote():
- from pyspark.sql.connect.functions.builtin import
_invoke_function_over_columns
-
- return _invoke_function_over_columns(
- "pandas_skew",
- col,
- )
-
- else:
- from pyspark import SparkContext
-
- sc = SparkContext._active_spark_context
- return Column(sc._jvm.PythonSQLUtils.pandasSkewness(col._jc))
+ return _invoke_internal_function_over_columns("pandas_skew", col)
def kurt(col: Column) -> Column:
- if is_remote():
- from pyspark.sql.connect.functions.builtin import
_invoke_function_over_columns
-
- return _invoke_function_over_columns(
- "pandas_kurt",
- col,
- )
-
- else:
- from pyspark import SparkContext
-
- sc = SparkContext._active_spark_context
- return Column(sc._jvm.PythonSQLUtils.pandasKurtosis(col._jc))
+ return _invoke_internal_function_over_columns("pandas_kurt", col)
def mode(col: Column, dropna: bool) -> Column:
- if is_remote():
- from pyspark.sql.connect.functions.builtin import
_invoke_function_over_columns, lit
-
- return _invoke_function_over_columns(
- "pandas_mode",
- col,
- lit(dropna),
- )
-
- else:
- from pyspark import SparkContext
-
- sc = SparkContext._active_spark_context
- return Column(sc._jvm.PythonSQLUtils.pandasMode(col._jc, dropna))
+ return _invoke_internal_function_over_columns("pandas_mode", col,
F.lit(dropna))
def covar(col1: Column, col2: Column, ddof: int) -> Column:
- if is_remote():
- from pyspark.sql.connect.functions.builtin import
_invoke_function_over_columns, lit
+ return _invoke_internal_function_over_columns("pandas_covar", col1, col2,
F.lit(ddof))
- return _invoke_function_over_columns(
- "pandas_covar",
- col1,
- col2,
- lit(ddof),
- )
- else:
- from pyspark import SparkContext
-
- sc = SparkContext._active_spark_context
- return Column(sc._jvm.PythonSQLUtils.pandasCovar(col1._jc, col2._jc,
ddof))
-
-
-def ewm(col: Column, alpha: float, ignore_na: bool) -> Column:
- if is_remote():
- from pyspark.sql.connect.functions.builtin import
_invoke_function_over_columns, lit
-
- return _invoke_function_over_columns(
- "ewm",
- col,
- lit(alpha),
- lit(ignore_na),
- )
-
- else:
- from pyspark import SparkContext
-
- sc = SparkContext._active_spark_context
- return Column(sc._jvm.PythonSQLUtils.ewm(col._jc, alpha, ignore_na))
+def ewm(col: Column, alpha: float, ignorena: bool) -> Column:
+ return _invoke_internal_function_over_columns("ewm", col, F.lit(alpha),
F.lit(ignorena))
def null_index(col: Column) -> Column:
- if is_remote():
- from pyspark.sql.connect.functions.builtin import
_invoke_function_over_columns
-
- return _invoke_function_over_columns(
- "null_index",
- col,
- )
-
- else:
- from pyspark import SparkContext
-
- sc = SparkContext._active_spark_context
- return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
+ return _invoke_internal_function_over_columns("null_index", col)
def distributed_sequence_id() -> Column:
- if is_remote():
- from pyspark.sql.connect.functions.builtin import _invoke_function
-
- return _invoke_function("distributed_sequence_id")
- else:
- from pyspark import SparkContext
-
- sc = SparkContext._active_spark_context
- return Column(sc._jvm.PythonSQLUtils.distributed_sequence_id())
+ return _invoke_internal_function_over_columns("distributed_sequence_id")
def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
- if is_remote():
- from pyspark.sql.connect.functions.builtin import
_invoke_function_over_columns
+ return _invoke_internal_function_over_columns("collect_top_k", col,
F.lit(num), F.lit(reverse))
- return _invoke_function_over_columns("collect_top_k", col, F.lit(num),
F.lit(reverse))
- else:
- from pyspark import SparkContext
-
- sc = SparkContext._active_spark_context
- return Column(sc._jvm.PythonSQLUtils.collect_top_k(col._jc, num,
reverse))
-
-
-def binary_search(col: Column, value: Column) -> Column:
- if is_remote():
- from pyspark.sql.connect.functions.builtin import
_invoke_function_over_columns
-
- return _invoke_function_over_columns("array_binary_search", col, value)
-
- else:
- from pyspark import SparkContext
-
- sc = SparkContext._active_spark_context
- return Column(sc._jvm.PythonSQLUtils.binary_search(col._jc, value._jc))
+def array_binary_search(col: Column, value: Column) -> Column:
+ return _invoke_internal_function_over_columns("array_binary_search", col,
value)
def make_interval(unit: str, e: Union[Column, int, float]) -> Column:
diff --git a/python/pyspark/pandas/window.py b/python/pyspark/pandas/window.py
index 0aaeb7df89be..fb5dd29169e9 100644
--- a/python/pyspark/pandas/window.py
+++ b/python/pyspark/pandas/window.py
@@ -2434,7 +2434,8 @@ class ExponentialMovingLike(Generic[FrameLike],
metaclass=ABCMeta):
if opt_count != 1:
raise ValueError("com, span, halflife, and alpha are mutually
exclusive")
- return unified_alpha
+ # convert possible numpy.float64 to float for lit function
+ return float(unified_alpha)
@abstractmethod
def _apply_as_series_or_frame(self, func: Callable[[Column], Column]) ->
FrameLike:
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index bc270e6ac64a..3504f6e76f79 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.execution.{ExplainMode, QueryExecution}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.execution.python.EvaluatePython
-import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
@@ -147,45 +146,6 @@ private[sql] object PythonSQLUtils extends Logging {
def castTimestampNTZToLong(c: Column): Column =
Column.internalFn("timestamp_ntz_to_long", c)
- def ewm(e: Column, alpha: Double, ignoreNA: Boolean): Column =
- Column.internalFn("ewm", e, lit(alpha), lit(ignoreNA))
-
- def nullIndex(e: Column): Column = Column.internalFn("null_index", e)
-
- def collect_top_k(e: Column, num: Int, reverse: Boolean): Column =
- Column.internalFn("collect_top_k", e, lit(num), lit(reverse))
-
- def binary_search(e: Column, value: Column): Column =
- Column.internalFn("array_binary_search", e, value)
-
- def pandasProduct(e: Column, ignoreNA: Boolean): Column =
- Column.internalFn("pandas_product", e, lit(ignoreNA))
-
- def pandasStddev(e: Column, ddof: Int): Column =
- Column.internalFn("pandas_stddev", e, lit(ddof))
-
- def pandasVariance(e: Column, ddof: Int): Column =
- Column.internalFn("pandas_var", e, lit(ddof))
-
- def pandasSkewness(e: Column): Column =
- Column.internalFn("pandas_skew", e)
-
- def pandasKurtosis(e: Column): Column =
- Column.internalFn("pandas_kurt", e)
-
- def pandasMode(e: Column, ignoreNA: Boolean): Column =
- Column.internalFn("pandas_mode", e, lit(ignoreNA))
-
- def pandasCovar(col1: Column, col2: Column, ddof: Int): Column =
- Column.internalFn("pandas_covar", col1, col2, lit(ddof))
-
- /**
- * A long column that increases one by one.
- * This is for 'distributed-sequence' default index in pandas API on Spark.
- */
- def distributed_sequence_id(): Column =
- Column.internalFn("distributed_sequence_id")
-
def unresolvedNamedLambdaVariable(name: String): Column =
Column(internal.UnresolvedNamedLambdaVariable.apply(name))
@@ -205,6 +165,9 @@ private[sql] object PythonSQLUtils extends Logging {
@scala.annotation.varargs
def fn(name: String, arguments: Column*): Column = Column.fn(name,
arguments: _*)
+
+ @scala.annotation.varargs
+ def internalFn(name: String, inputs: Column*): Column =
Column.internalFn(name, inputs: _*)
}
/**
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
index 1d7698df2f1b..f0ed2241fd28 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql
import org.apache.spark.api.python.PythonEvalType
-import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending,
AttributeReference, PythonUDF, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate,
ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan}
import org.apache.spark.sql.expressions.Window
@@ -405,7 +404,7 @@ class DataFrameSelfJoinSuite extends QueryTest with
SharedSparkSession {
assertAmbiguousSelfJoin(df12.join(df11, df11("x") === df12("y")))
// Test for AttachDistributedSequence
- val df13 = df1.select(distributed_sequence_id().alias("seq"), col("*"))
+ val df13 =
df1.select(Column.internalFn("distributed_sequence_id").alias("seq"), col("*"))
val df14 = df13.filter($"value" === "A2")
assertAmbiguousSelfJoin(df13.join(df14, df13("key1") === df14("key2")))
assertAmbiguousSelfJoin(df14.join(df13, df13("key1") === df14("key2")))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index e1774cab4a0d..2c0d9e29bb27 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -29,7 +29,6 @@ import org.scalatest.matchers.should.Matchers._
import org.apache.spark.SparkException
import org.apache.spark.api.python.PythonEvalType
-import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -2318,7 +2317,7 @@ class DataFrameSuite extends QueryTest
test("SPARK-36338: DataFrame.withSequenceColumn should append unique
sequence IDs") {
val ids = spark.range(10).repartition(5).select(
- distributed_sequence_id().alias("default_index"), col("id"))
+ Column.internalFn("distributed_sequence_id").alias("default_index"),
col("id"))
assert(ids.collect().map(_.getLong(0)).toSet === Range(0, 10).toSet)
assert(ids.take(5).map(_.getLong(0)).toSet === Range(0, 5).toSet)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]