This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e1b2ac55b4b9 [SPARK-49767][PS][CONNECT] Refactor the internal function invocation
e1b2ac55b4b9 is described below

commit e1b2ac55b4b9463824d3f23eb7fbac88ede843d9
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Sep 25 19:25:47 2024 +0900

    [SPARK-49767][PS][CONNECT] Refactor the internal function invocation
    
    ### What changes were proposed in this pull request?
    Refactor the internal function invocation
    
    ### Why are the changes needed?
    By introducing a new helper function `_invoke_internal_function_over_columns`, we no longer need to add dedicated internal functions in `PythonSQLUtils`.
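    
    For illustration, a minimal usage sketch of the refactored helpers (not part of the diff below; it assumes a running Spark session and uses an example column named "x"):
    
        from pyspark.sql import SparkSession, functions as F
        from pyspark.pandas.spark import functions as SF
    
        spark = SparkSession.builder.getOrCreate()
        df = spark.range(10).withColumnRenamed("id", "x")
    
        # Both helpers now route through _invoke_internal_function_over_columns,
        # which picks the Spark Connect or classic JVM code path at call time.
        df.select(SF.skew(F.col("x")), SF.stddev(F.col("x"), ddof=1)).show()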
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #48227 from zhengruifeng/py_fn.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/pandas/plot/core.py                 |   2 +-
 python/pyspark/pandas/spark/functions.py           | 175 +++------------------
 python/pyspark/pandas/window.py                    |   3 +-
 .../spark/sql/api/python/PythonSQLUtils.scala      |  43 +----
 .../apache/spark/sql/DataFrameSelfJoinSuite.scala  |   3 +-
 .../org/apache/spark/sql/DataFrameSuite.scala      |   3 +-
 6 files changed, 33 insertions(+), 196 deletions(-)

diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py
index 6f036b766924..7333fae1ad43 100644
--- a/python/pyspark/pandas/plot/core.py
+++ b/python/pyspark/pandas/plot/core.py
@@ -215,7 +215,7 @@ class HistogramPlotBase(NumericPlotBase):
 
         # refers to org.apache.spark.ml.feature.Bucketizer#binarySearchForBuckets
         def binary_search_for_buckets(value: Column):
-            index = SF.binary_search(F.lit(bins), value)
+            index = SF.array_binary_search(F.lit(bins), value)
             bucket = F.when(index >= 0, index).otherwise(-index - 2)
             unboundErrMsg = F.lit(f"value %s out of the bins bounds: [{bins[0]}, {bins[-1]}]")
             return (
diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py
index 4bcf07f6f650..4d95466a98e1 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -19,197 +19,72 @@ Additional Spark functions used in pandas-on-Spark.
 """
 from pyspark.sql import Column, functions as F
 from pyspark.sql.utils import is_remote
-from typing import Union
+from typing import Union, TYPE_CHECKING
 
+if TYPE_CHECKING:
+    from pyspark.sql._typing import ColumnOrName
 
-def product(col: Column, dropna: bool) -> Column:
+
+def _invoke_internal_function_over_columns(name: str, *cols: "ColumnOrName") -> Column:
     if is_remote():
-        from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit
+        from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
 
-        return _invoke_function_over_columns(
-            "pandas_product",
-            col,
-            lit(dropna),
-        )
+        return _invoke_function_over_columns(name, *cols)
 
     else:
+        from pyspark.sql.classic.column import _to_seq, _to_java_column
         from pyspark import SparkContext
 
         sc = SparkContext._active_spark_context
-        return Column(sc._jvm.PythonSQLUtils.pandasProduct(col._jc, dropna))
+        return Column(sc._jvm.PythonSQLUtils.internalFn(name, _to_seq(sc, cols, _to_java_column)))
 
 
-def stddev(col: Column, ddof: int) -> Column:
-    if is_remote():
-        from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit
-
-        return _invoke_function_over_columns(
-            "pandas_stddev",
-            col,
-            lit(ddof),
-        )
+def product(col: Column, dropna: bool) -> Column:
+    return _invoke_internal_function_over_columns("pandas_product", col, 
F.lit(dropna))
 
-    else:
-        from pyspark import SparkContext
 
-        sc = SparkContext._active_spark_context
-        return Column(sc._jvm.PythonSQLUtils.pandasStddev(col._jc, ddof))
+def stddev(col: Column, ddof: int) -> Column:
+    return _invoke_internal_function_over_columns("pandas_stddev", col, 
F.lit(ddof))
 
 
 def var(col: Column, ddof: int) -> Column:
-    if is_remote():
-        from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit
-
-        return _invoke_function_over_columns(
-            "pandas_var",
-            col,
-            lit(ddof),
-        )
-
-    else:
-        from pyspark import SparkContext
-
-        sc = SparkContext._active_spark_context
-        return Column(sc._jvm.PythonSQLUtils.pandasVariance(col._jc, ddof))
+    return _invoke_internal_function_over_columns("pandas_var", col, 
F.lit(ddof))
 
 
 def skew(col: Column) -> Column:
-    if is_remote():
-        from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
-
-        return _invoke_function_over_columns(
-            "pandas_skew",
-            col,
-        )
-
-    else:
-        from pyspark import SparkContext
-
-        sc = SparkContext._active_spark_context
-        return Column(sc._jvm.PythonSQLUtils.pandasSkewness(col._jc))
+    return _invoke_internal_function_over_columns("pandas_skew", col)
 
 
 def kurt(col: Column) -> Column:
-    if is_remote():
-        from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
-
-        return _invoke_function_over_columns(
-            "pandas_kurt",
-            col,
-        )
-
-    else:
-        from pyspark import SparkContext
-
-        sc = SparkContext._active_spark_context
-        return Column(sc._jvm.PythonSQLUtils.pandasKurtosis(col._jc))
+    return _invoke_internal_function_over_columns("pandas_kurt", col)
 
 
 def mode(col: Column, dropna: bool) -> Column:
-    if is_remote():
-        from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit
-
-        return _invoke_function_over_columns(
-            "pandas_mode",
-            col,
-            lit(dropna),
-        )
-
-    else:
-        from pyspark import SparkContext
-
-        sc = SparkContext._active_spark_context
-        return Column(sc._jvm.PythonSQLUtils.pandasMode(col._jc, dropna))
+    return _invoke_internal_function_over_columns("pandas_mode", col, 
F.lit(dropna))
 
 
 def covar(col1: Column, col2: Column, ddof: int) -> Column:
-    if is_remote():
-        from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit
+    return _invoke_internal_function_over_columns("pandas_covar", col1, col2, F.lit(ddof))
 
-        return _invoke_function_over_columns(
-            "pandas_covar",
-            col1,
-            col2,
-            lit(ddof),
-        )
 
-    else:
-        from pyspark import SparkContext
-
-        sc = SparkContext._active_spark_context
-        return Column(sc._jvm.PythonSQLUtils.pandasCovar(col1._jc, col2._jc, ddof))
-
-
-def ewm(col: Column, alpha: float, ignore_na: bool) -> Column:
-    if is_remote():
-        from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit
-
-        return _invoke_function_over_columns(
-            "ewm",
-            col,
-            lit(alpha),
-            lit(ignore_na),
-        )
-
-    else:
-        from pyspark import SparkContext
-
-        sc = SparkContext._active_spark_context
-        return Column(sc._jvm.PythonSQLUtils.ewm(col._jc, alpha, ignore_na))
+def ewm(col: Column, alpha: float, ignorena: bool) -> Column:
+    return _invoke_internal_function_over_columns("ewm", col, F.lit(alpha), 
F.lit(ignorena))
 
 
 def null_index(col: Column) -> Column:
-    if is_remote():
-        from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
-
-        return _invoke_function_over_columns(
-            "null_index",
-            col,
-        )
-
-    else:
-        from pyspark import SparkContext
-
-        sc = SparkContext._active_spark_context
-        return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
+    return _invoke_internal_function_over_columns("null_index", col)
 
 
 def distributed_sequence_id() -> Column:
-    if is_remote():
-        from pyspark.sql.connect.functions.builtin import _invoke_function
-
-        return _invoke_function("distributed_sequence_id")
-    else:
-        from pyspark import SparkContext
-
-        sc = SparkContext._active_spark_context
-        return Column(sc._jvm.PythonSQLUtils.distributed_sequence_id())
+    return _invoke_internal_function_over_columns("distributed_sequence_id")
 
 
 def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
-    if is_remote():
-        from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
+    return _invoke_internal_function_over_columns("collect_top_k", col, F.lit(num), F.lit(reverse))
 
-        return _invoke_function_over_columns("collect_top_k", col, F.lit(num), 
F.lit(reverse))
 
-    else:
-        from pyspark import SparkContext
-
-        sc = SparkContext._active_spark_context
-        return Column(sc._jvm.PythonSQLUtils.collect_top_k(col._jc, num, reverse))
-
-
-def binary_search(col: Column, value: Column) -> Column:
-    if is_remote():
-        from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
-
-        return _invoke_function_over_columns("array_binary_search", col, value)
-
-    else:
-        from pyspark import SparkContext
-
-        sc = SparkContext._active_spark_context
-        return Column(sc._jvm.PythonSQLUtils.binary_search(col._jc, value._jc))
+def array_binary_search(col: Column, value: Column) -> Column:
+    return _invoke_internal_function_over_columns("array_binary_search", col, 
value)
 
 
 def make_interval(unit: str, e: Union[Column, int, float]) -> Column:
diff --git a/python/pyspark/pandas/window.py b/python/pyspark/pandas/window.py
index 0aaeb7df89be..fb5dd29169e9 100644
--- a/python/pyspark/pandas/window.py
+++ b/python/pyspark/pandas/window.py
@@ -2434,7 +2434,8 @@ class ExponentialMovingLike(Generic[FrameLike], metaclass=ABCMeta):
         if opt_count != 1:
             raise ValueError("com, span, halflife, and alpha are mutually exclusive")
 
-        return unified_alpha
+        # convert possible numpy.float64 to float for lit function
+        return float(unified_alpha)
 
     @abstractmethod
     def _apply_as_series_or_frame(self, func: Callable[[Column], Column]) -> FrameLike:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index bc270e6ac64a..3504f6e76f79 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.execution.{ExplainMode, QueryExecution}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.python.EvaluatePython
-import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
@@ -147,45 +146,6 @@ private[sql] object PythonSQLUtils extends Logging {
   def castTimestampNTZToLong(c: Column): Column =
     Column.internalFn("timestamp_ntz_to_long", c)
 
-  def ewm(e: Column, alpha: Double, ignoreNA: Boolean): Column =
-    Column.internalFn("ewm", e, lit(alpha), lit(ignoreNA))
-
-  def nullIndex(e: Column): Column = Column.internalFn("null_index", e)
-
-  def collect_top_k(e: Column, num: Int, reverse: Boolean): Column =
-    Column.internalFn("collect_top_k", e, lit(num), lit(reverse))
-
-  def binary_search(e: Column, value: Column): Column =
-    Column.internalFn("array_binary_search", e, value)
-
-  def pandasProduct(e: Column, ignoreNA: Boolean): Column =
-    Column.internalFn("pandas_product", e, lit(ignoreNA))
-
-  def pandasStddev(e: Column, ddof: Int): Column =
-    Column.internalFn("pandas_stddev", e, lit(ddof))
-
-  def pandasVariance(e: Column, ddof: Int): Column =
-    Column.internalFn("pandas_var", e, lit(ddof))
-
-  def pandasSkewness(e: Column): Column =
-    Column.internalFn("pandas_skew", e)
-
-  def pandasKurtosis(e: Column): Column =
-    Column.internalFn("pandas_kurt", e)
-
-  def pandasMode(e: Column, ignoreNA: Boolean): Column =
-    Column.internalFn("pandas_mode", e, lit(ignoreNA))
-
-  def pandasCovar(col1: Column, col2: Column, ddof: Int): Column =
-    Column.internalFn("pandas_covar", col1, col2, lit(ddof))
-
-  /**
-   * A long column that increases one by one.
-   * This is for 'distributed-sequence' default index in pandas API on Spark.
-   */
-  def distributed_sequence_id(): Column =
-    Column.internalFn("distributed_sequence_id")
-
   def unresolvedNamedLambdaVariable(name: String): Column =
     Column(internal.UnresolvedNamedLambdaVariable.apply(name))
 
@@ -205,6 +165,9 @@ private[sql] object PythonSQLUtils extends Logging {
 
   @scala.annotation.varargs
   def fn(name: String, arguments: Column*): Column = Column.fn(name, arguments: _*)
+
+  @scala.annotation.varargs
+  def internalFn(name: String, inputs: Column*): Column = Column.internalFn(name, inputs: _*)
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
index 1d7698df2f1b..f0ed2241fd28 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql
 
 import org.apache.spark.api.python.PythonEvalType
-import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id
 import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, PythonUDF, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan}
 import org.apache.spark.sql.expressions.Window
@@ -405,7 +404,7 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
     assertAmbiguousSelfJoin(df12.join(df11, df11("x") === df12("y")))
 
     // Test for AttachDistributedSequence
-    val df13 = df1.select(distributed_sequence_id().alias("seq"), col("*"))
+    val df13 = df1.select(Column.internalFn("distributed_sequence_id").alias("seq"), col("*"))
     val df14 = df13.filter($"value" === "A2")
     assertAmbiguousSelfJoin(df13.join(df14, df13("key1") === df14("key2")))
     assertAmbiguousSelfJoin(df14.join(df13, df13("key1") === df14("key2")))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index e1774cab4a0d..2c0d9e29bb27 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -29,7 +29,6 @@ import org.scalatest.matchers.should.Matchers._
 
 import org.apache.spark.SparkException
 import org.apache.spark.api.python.PythonEvalType
-import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id
 import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -2318,7 +2317,7 @@ class DataFrameSuite extends QueryTest
 
   test("SPARK-36338: DataFrame.withSequenceColumn should append unique 
sequence IDs") {
     val ids = spark.range(10).repartition(5).select(
-      distributed_sequence_id().alias("default_index"), col("id"))
+      Column.internalFn("distributed_sequence_id").alias("default_index"), col("id"))
     assert(ids.collect().map(_.getLong(0)).toSet === Range(0, 10).toSet)
     assert(ids.take(5).map(_.getLong(0)).toSet === Range(0, 5).toSet)
   }

