This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new d422aed4a2e8 [SPARK-46684][PYTHON][CONNECT][3.5] Fix CoGroup.applyInPandas/Arrow to pass arguments properly
d422aed4a2e8 is described below

commit d422aed4a2e82d671a592b096919015bddeb751f
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Fri Jan 12 13:19:12 2024 +0900

    [SPARK-46684][PYTHON][CONNECT][3.5] Fix CoGroup.applyInPandas/Arrow to pass arguments properly
    
    ### What changes were proposed in this pull request?
    
    This is a backport of apache/spark#44695.
    
    Fix `CoGroup.applyInPandas/Arrow` to pass arguments properly.
    
    ### Why are the changes needed?
    
    In Spark Connect, `CoGroup.applyInPandas/Arrow` does not pass the arguments through properly, so the data the UDF receives can be broken:
    
    ```py
    >>> import pandas as pd
    >>>
    >>> df1 = spark.createDataFrame(
    ...     [(1, 1.0, "a"), (2, 2.0, "b"), (1, 3.0, "c"), (2, 4.0, "d")], ("id", "v1", "v2")
    ... )
    >>> df2 = spark.createDataFrame([(1, "x"), (2, "y"), (1, "z")], ("id", "v3"))
    >>>
    >>> def summarize(left, right):
    ...     return pd.DataFrame(
    ...         {
    ...             "left_rows": [len(left)],
    ...             "left_columns": [len(left.columns)],
    ...             "right_rows": [len(right)],
    ...             "right_columns": [len(right.columns)],
    ...         }
    ...     )
    ...
    >>> df = (
    ...     df1.groupby("id")
    ...     .cogroup(df2.groupby("id"))
    ...     .applyInPandas(
    ...         summarize,
    ...         schema="left_rows long, left_columns long, right_rows long, right_columns long",
    ...     )
    ... )
    >>>
    >>> df.show()
    +---------+------------+----------+-------------+
    |left_rows|left_columns|right_rows|right_columns|
    +---------+------------+----------+-------------+
    |        2|           1|         2|            1|
    |        2|           1|         1|            1|
    +---------+------------+----------+-------------+
    ```
    
    The result should be:
    
    ```py
    +---------+------------+----------+-------------+
    |left_rows|left_columns|right_rows|right_columns|
    +---------+------------+----------+-------------+
    |        2|           3|         2|            2|
    |        2|           3|         1|            2|
    +---------+------------+----------+-------------+
    ```
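
    As a quick sanity check, here is a minimal sketch (illustrative only, reusing `pd`, `df1`, and `df2` from the repro above) that asserts each side of the cogroup arrives with all of its own columns; with this fix applied it passes:

    ```py
    >>> def check(left, right):
    ...     # Each side should keep all of its own columns:
    ...     # left has (id, v1, v2), right has (id, v3).
    ...     assert list(left.columns) == ["id", "v1", "v2"]
    ...     assert list(right.columns) == ["id", "v3"]
    ...     return pd.DataFrame({"ok": [True]})
    ...
    >>> df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
    ...     check, schema="ok boolean"
    ... ).show()
    ```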
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this is a bug fix.
    
    ### How was this patch tested?
    
    Added the related tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #44696 from ueshin/issues/SPARK-46684/3.5/cogroup.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  | 32 ++++++++++----------
 python/pyspark/sql/dataframe.py                    |  8 +++--
 .../sql/tests/pandas/test_pandas_cogrouped_map.py  | 35 ++++++++++++++++++++++
 3 files changed, 57 insertions(+), 18 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 50a55f5e6411..709e0811e5de 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -674,8 +674,6 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
         transformTypedCoGroupMap(rel, commonUdf)
 
       case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
-        val pythonUdf = transformPythonUDF(commonUdf)
-
         val inputCols =
           rel.getInputGroupingExpressionsList.asScala.toSeq.map(expr =>
             Column(transformExpression(expr)))
@@ -690,6 +688,10 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
           .ofRows(session, transformRelation(rel.getOther))
           .groupBy(otherCols: _*)
 
+        val pythonUdf = createUserDefinedPythonFunction(commonUdf)
+          .builder(input.df.logicalPlan.output ++ other.df.logicalPlan.output)
+          .asInstanceOf[PythonUDF]
+
         input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan
 
       case _ =>
@@ -1587,17 +1589,23 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
 
   private def transformPythonFuncExpression(
       fun: proto.CommonInlineUserDefinedFunction): Expression = {
+    createUserDefinedPythonFunction(fun)
+      .builder(fun.getArgumentsList.asScala.map(transformExpression).toSeq) match {
+      case udaf: PythonUDAF => udaf.toAggregateExpression()
+      case other => other
+    }
+  }
+
+  private def createUserDefinedPythonFunction(
+      fun: proto.CommonInlineUserDefinedFunction): UserDefinedPythonFunction = {
     val udf = fun.getPythonUdf
+    val function = transformPythonFunction(udf)
     UserDefinedPythonFunction(
       name = fun.getFunctionName,
-      func = transformPythonFunction(udf),
+      func = function,
       dataType = transformDataType(udf.getOutputType),
       pythonEvalType = udf.getEvalType,
       udfDeterministic = fun.getDeterministic)
-      .builder(fun.getArgumentsList.asScala.map(transformExpression).toSeq) match {
-      case udaf: PythonUDAF => udaf.toAggregateExpression()
-      case other => other
-    }
   }
 
   private def transformPythonFunction(fun: proto.PythonUDF): SimplePythonFunction = {
@@ -2584,15 +2592,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
   }
 
   private def handleRegisterPythonUDF(fun: proto.CommonInlineUserDefinedFunction): Unit = {
-    val udf = fun.getPythonUdf
-    val function = transformPythonFunction(udf)
-    val udpf = UserDefinedPythonFunction(
-      name = fun.getFunctionName,
-      func = function,
-      dataType = transformDataType(udf.getOutputType),
-      pythonEvalType = udf.getEvalType,
-      udfDeterministic = fun.getDeterministic)
-
+    val udpf = createUserDefinedPythonFunction(fun)
     session.udf.registerPython(fun.getFunctionName, udpf)
   }
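
The planner hunks above rebuild the `PythonUDF` against the combined output of both cogrouped inputs via the new `createUserDefinedPythonFunction` helper, instead of against the function's argument list, which is what restores the full column set on each side. As a further illustration (a hedged sketch, reusing `pd`, `df1`, and `df2` from the repro above), the binding also covers the three-argument form, where the UDF receives the grouping key and PySpark dispatches on the function's arity:

```py
>>> def summarize_with_key(key, left, right):
...     # key is a tuple of the grouping values, e.g. (1,) or (2,)
...     return pd.DataFrame(
...         {
...             "id": [key[0]],
...             "left_rows": [len(left)],
...             "right_rows": [len(right)],
...         }
...     )
...
>>> df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
...     summarize_with_key, schema="id long, left_rows long, right_rows long"
... ).show()
```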
 
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 5707ae2a31fe..7c382ab1c5a5 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -942,7 +942,11 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         age  | 16
         name | Bob
         """
+        print(self._show_string(n, truncate, vertical))
 
+    def _show_string(
+        self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False
+    ) -> str:
         if not isinstance(n, int) or isinstance(n, bool):
             raise PySparkTypeError(
                 error_class="NOT_INT",
@@ -956,7 +960,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             )
 
         if isinstance(truncate, bool) and truncate:
-            print(self._jdf.showString(n, 20, vertical))
+            return self._jdf.showString(n, 20, vertical)
         else:
             try:
                 int_truncate = int(truncate)
@@ -969,7 +973,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                     },
                 )
 
-            print(self._jdf.showString(n, int_truncate, vertical))
+            return self._jdf.showString(n, int_truncate, vertical)
 
     def __repr__(self) -> str:
         if not self._support_repr_html and self.sparkSession._jconf.isReplEagerEvalEnabled():
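
The `show()` refactor above moves the string rendering into an internal `_show_string` helper that returns the formatted table instead of printing it, which lets tests compare exact output (as the new test below does). A small usage sketch, assuming a `df` like the one built in the repro:

```py
>>> s = df._show_string(n=20, truncate=False)  # returns str, prints nothing
>>> isinstance(s, str)
True
```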
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index b867156e71a5..c3cd0f37b103 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -445,6 +445,41 @@ class CogroupedApplyInPandasTestsMixin:
         actual = df.orderBy("id", "day").take(days)
         self.assertEqual(actual, [Row(0, day, vals, vals) for day in range(days)])
 
+    def test_with_local_data(self):
+        df1 = self.spark.createDataFrame(
+            [(1, 1.0, "a"), (2, 2.0, "b"), (1, 3.0, "c"), (2, 4.0, "d")], ("id", "v1", "v2")
+        )
+        df2 = self.spark.createDataFrame([(1, "x"), (2, "y"), (1, "z")], ("id", "v3"))
+
+        def summarize(left, right):
+            return pd.DataFrame(
+                {
+                    "left_rows": [len(left)],
+                    "left_columns": [len(left.columns)],
+                    "right_rows": [len(right)],
+                    "right_columns": [len(right.columns)],
+                }
+            )
+
+        df = (
+            df1.groupby("id")
+            .cogroup(df2.groupby("id"))
+            .applyInPandas(
+                summarize,
+                schema="left_rows long, left_columns long, right_rows long, right_columns long",
+            )
+        )
+
+        self.assertEqual(
+            df._show_string(),
+            "+---------+------------+----------+-------------+\n"
+            "|left_rows|left_columns|right_rows|right_columns|\n"
+            "+---------+------------+----------+-------------+\n"
+            "|        2|           3|         2|            2|\n"
+            "|        2|           3|         1|            2|\n"
+            "+---------+------------+----------+-------------+\n",
+        )
+
     @staticmethod
     def _test_with_key(left, right, isLeft):
         def right_assign_key(key, lft, rgt):

