This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new d422aed4a2e8 [SPARK-46684][PYTHON][CONNECT][3.5] Fix CoGroup.applyInPandas/Arrow to pass arguments properly d422aed4a2e8 is described below commit d422aed4a2e82d671a592b096919015bddeb751f Author: Takuya UESHIN <ues...@databricks.com> AuthorDate: Fri Jan 12 13:19:12 2024 +0900 [SPARK-46684][PYTHON][CONNECT][3.5] Fix CoGroup.applyInPandas/Arrow to pass arguments properly ### What changes were proposed in this pull request? This is a backport of apache/spark#44695. Fix `CoGroup.applyInPandas/Arrow` to pass arguments properly. ### Why are the changes needed? In Spark Connect, `CoGroup.applyInPandas/Arrow` doesn't pass the input columns to the UDF properly, so the UDF can receive incorrect arguments. In the example below, each side of the cogroup reports only 1 column instead of all of its columns: ```py >>> import pandas as pd >>> >>> df1 = spark.createDataFrame( ... [(1, 1.0, "a"), (2, 2.0, "b"), (1, 3.0, "c"), (2, 4.0, "d")], ("id", "v1", "v2") ... ) >>> df2 = spark.createDataFrame([(1, "x"), (2, "y"), (1, "z")], ("id", "v3")) >>> >>> def summarize(left, right): ... return pd.DataFrame( ... { ... "left_rows": [len(left)], ... "left_columns": [len(left.columns)], ... "right_rows": [len(right)], ... "right_columns": [len(right.columns)], ... } ... ) ... >>> df = ( ... df1.groupby("id") ... .cogroup(df2.groupby("id")) ... .applyInPandas( ... summarize, ... schema="left_rows long, left_columns long, right_rows long, right_columns long", ... ) ... ) >>> >>> df.show() +---------+------------+----------+-------------+ |left_rows|left_columns|right_rows|right_columns| +---------+------------+----------+-------------+ | 2| 1| 2| 1| | 2| 1| 1| 1| +---------+------------+----------+-------------+ ``` The result should be: ```py +---------+------------+----------+-------------+ |left_rows|left_columns|right_rows|right_columns| +---------+------------+----------+-------------+ | 2| 3| 2| 2| | 2| 3| 1| 2| +---------+------------+----------+-------------+ ``` ### Does this PR introduce _any_ user-facing change? 
This is a bug fix. ### How was this patch tested? Added the related tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #44696 from ueshin/issues/SPARK-46684/3.5/cogroup. Authored-by: Takuya UESHIN <ues...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../sql/connect/planner/SparkConnectPlanner.scala | 32 ++++++++++---------- python/pyspark/sql/dataframe.py | 8 +++-- .../sql/tests/pandas/test_pandas_cogrouped_map.py | 35 ++++++++++++++++++++++ 3 files changed, 57 insertions(+), 18 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 50a55f5e6411..709e0811e5de 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -674,8 +674,6 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { transformTypedCoGroupMap(rel, commonUdf) case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF => - val pythonUdf = transformPythonUDF(commonUdf) - val inputCols = rel.getInputGroupingExpressionsList.asScala.toSeq.map(expr => Column(transformExpression(expr))) @@ -690,6 +688,10 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { .ofRows(session, transformRelation(rel.getOther)) .groupBy(otherCols: _*) + val pythonUdf = createUserDefinedPythonFunction(commonUdf) + .builder(input.df.logicalPlan.output ++ other.df.logicalPlan.output) + .asInstanceOf[PythonUDF] + input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan case _ => @@ -1587,17 +1589,23 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { private def transformPythonFuncExpression( fun: 
proto.CommonInlineUserDefinedFunction): Expression = { + createUserDefinedPythonFunction(fun) + .builder(fun.getArgumentsList.asScala.map(transformExpression).toSeq) match { + case udaf: PythonUDAF => udaf.toAggregateExpression() + case other => other + } + } + + private def createUserDefinedPythonFunction( + fun: proto.CommonInlineUserDefinedFunction): UserDefinedPythonFunction = { val udf = fun.getPythonUdf + val function = transformPythonFunction(udf) UserDefinedPythonFunction( name = fun.getFunctionName, - func = transformPythonFunction(udf), + func = function, dataType = transformDataType(udf.getOutputType), pythonEvalType = udf.getEvalType, udfDeterministic = fun.getDeterministic) - .builder(fun.getArgumentsList.asScala.map(transformExpression).toSeq) match { - case udaf: PythonUDAF => udaf.toAggregateExpression() - case other => other - } } private def transformPythonFunction(fun: proto.PythonUDF): SimplePythonFunction = { @@ -2584,15 +2592,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { } private def handleRegisterPythonUDF(fun: proto.CommonInlineUserDefinedFunction): Unit = { - val udf = fun.getPythonUdf - val function = transformPythonFunction(udf) - val udpf = UserDefinedPythonFunction( - name = fun.getFunctionName, - func = function, - dataType = transformDataType(udf.getOutputType), - pythonEvalType = udf.getEvalType, - udfDeterministic = fun.getDeterministic) - + val udpf = createUserDefinedPythonFunction(fun) session.udf.registerPython(fun.getFunctionName, udpf) } diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5707ae2a31fe..7c382ab1c5a5 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -942,7 +942,11 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): age | 16 name | Bob """ + print(self._show_string(n, truncate, vertical)) + def _show_string( + self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False + ) -> str: 
if not isinstance(n, int) or isinstance(n, bool): raise PySparkTypeError( error_class="NOT_INT", @@ -956,7 +960,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): ) if isinstance(truncate, bool) and truncate: - print(self._jdf.showString(n, 20, vertical)) + return self._jdf.showString(n, 20, vertical) else: try: int_truncate = int(truncate) @@ -969,7 +973,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): }, ) - print(self._jdf.showString(n, int_truncate, vertical)) + return self._jdf.showString(n, int_truncate, vertical) def __repr__(self) -> str: if not self._support_repr_html and self.sparkSession._jconf.isReplEagerEvalEnabled(): diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py index b867156e71a5..c3cd0f37b103 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py @@ -445,6 +445,41 @@ class CogroupedApplyInPandasTestsMixin: actual = df.orderBy("id", "day").take(days) self.assertEqual(actual, [Row(0, day, vals, vals) for day in range(days)]) + def test_with_local_data(self): + df1 = self.spark.createDataFrame( + [(1, 1.0, "a"), (2, 2.0, "b"), (1, 3.0, "c"), (2, 4.0, "d")], ("id", "v1", "v2") + ) + df2 = self.spark.createDataFrame([(1, "x"), (2, "y"), (1, "z")], ("id", "v3")) + + def summarize(left, right): + return pd.DataFrame( + { + "left_rows": [len(left)], + "left_columns": [len(left.columns)], + "right_rows": [len(right)], + "right_columns": [len(right.columns)], + } + ) + + df = ( + df1.groupby("id") + .cogroup(df2.groupby("id")) + .applyInPandas( + summarize, + schema="left_rows long, left_columns long, right_rows long, right_columns long", + ) + ) + + self.assertEqual( + df._show_string(), + "+---------+------------+----------+-------------+\n" + "|left_rows|left_columns|right_rows|right_columns|\n" + "+---------+------------+----------+-------------+\n" + "| 
2| 3| 2| 2|\n" + "| 2| 3| 1| 2|\n" + "+---------+------------+----------+-------------+\n", + ) + @staticmethod def _test_with_key(left, right, isLeft): def right_assign_key(key, lft, rgt): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org