EnricoMi commented on code in PR #40122:
URL: https://github.com/apache/spark/pull/40122#discussion_r1114504577


##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -413,40 +414,35 @@ def applyInPandas(
         >>> df2 = spark.createDataFrame(
         ...     [(20000101, 1, "x"), (20000101, 2, "y")],
         ...     ("time", "id", "v2"))
-        >>> def asof_join(l, r):
-        ...     return pd.merge_asof(l, r, on="time", by="id")
-        >>> df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
-        ...     asof_join, schema="time int, id int, v1 double, v2 string"
-        ... ).show()  # doctest: +SKIP
-        +--------+---+---+---+
-        |    time| id| v1| v2|
-        +--------+---+---+---+
-        |20000101|  1|1.0|  x|
-        |20000102|  1|3.0|  x|
-        |20000101|  2|2.0|  y|
-        |20000102|  2|4.0|  y|
-        +--------+---+---+---+
-
-        Alternatively, the user can define a function that takes three 
arguments.  In this case,
-        the grouping key(s) will be passed as the first argument and the data 
will be passed as the
-        second and third arguments.  The grouping key(s) will be passed as a 
tuple of numpy data
-        types, e.g., `numpy.int32` and `numpy.float64`. The data will still be 
passed in as two
-        `pandas.DataFrame` containing all columns from the original Spark 
DataFrames.
-
         >>> def asof_join(k, l, r):
         ...     if k == (1,):
         ...         return pd.merge_asof(l, r, on="time", by="id")
         ...     else:
         ...         return pd.DataFrame(columns=['time', 'id', 'v1', 'v2'])
         >>> df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
-        ...     asof_join, "time int, id int, v1 double, v2 string").show()  # 
doctest: +SKIP
+        ...     asof_join, "time int, id int, v1 double, v2 string").show() # 
doctest: +SKIP
         +--------+---+---+---+
         |    time| id| v1| v2|
         +--------+---+---+---+
         |20000101|  1|1.0|  x|
         |20000102|  1|3.0|  x|
         +--------+---+---+---+
-
+        >>> df3 = spark.createDataFrame(
+        ...     [(20000101, 1, "x"), (20000101, 2, "y")],
+        ...     ("time", "id", "v3"))
+        >>> def asof_join_multiple(_, *dfs):

Review Comment:
   The example should name the `key` argument, even though it is not used inside the function:
   ```suggestion
           >>> def asof_join_multiple(key, *dfs):
   ```



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -397,12 +397,13 @@ def applyInPandas(
         Parameters
         ----------
         func : function
-            a Python native function that takes two `pandas.DataFrame`\\s, and
-            outputs a `pandas.DataFrame`, or that takes one tuple (grouping 
keys) and two
+            a Python native function that takes one tuple (grouping keys) and 
two or more
             ``pandas.DataFrame``\\s, and outputs a ``pandas.DataFrame``.
         schema : :class:`pyspark.sql.types.DataType` or str
             the return type of the `func` in PySpark. The value can be either a
             :class:`pyspark.sql.types.DataType` object or a DDL-formatted type 
string.
+        pass_key : : bool
+             Used to pass key to the UDF when cogrouped over more than 2 
dataframes.

Review Comment:
   This looks like it leaked in from the other PR, so I think it should be removed:
   ```suggestion
   ```



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -368,23 +368,23 @@ class PandasCogroupedOps:
     This API is experimental.
     """
 
-    def __init__(self, gd1: "GroupedData", gd2: "GroupedData"):
+    def __init__(self, gd1: "GroupedData", *gds: "GroupedData"):
         self._gd1 = gd1
-        self._gd2 = gd2
+        self._gds = gds

Review Comment:
   Please assert that `gds` is long enough, i.e. that at least one additional `GroupedData` is passed besides `gd1`.



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -463,12 +459,20 @@ def applyInPandas(
 
         # The usage of the pandas_udf is internal so type checking is disabled.
         udf = pandas_udf(
-            func, returnType=schema, 
functionType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF
+            func,
+            returnType=schema,
+            functionType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
         )  # type: ignore[call-overload]
 
-        all_cols = self._extract_cols(self._gd1) + 
self._extract_cols(self._gd2)
-        udf_column = udf(*all_cols)
-        jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, 
udf_column._jc.expr())
+        all_cols = self._extract_cols(self._gd1)
+        for gd in self._gds:
+            all_cols.extend(self._extract_cols(gd))

Review Comment:
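   Could be simplified to the following. It flattens the per-group column lists, and unpacks `self._gds` because it is a tuple (from `*gds`), so `[self._gd1] + self._gds` would raise a `TypeError`:
   ```suggestion
           all_cols = [col for gd in [self._gd1, *self._gds] for col in self._extract_cols(gd)]
   ```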



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala:
##########
@@ -227,7 +227,7 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
         newVersion.copyTagsFrom(oldVersion)
         Seq((oldVersion, newVersion))
 
-      case oldVersion @ FlatMapCoGroupsInPandas(_, _, _, output, _, _)
+      case oldVersion@FlatMapCoGroupsInPandas(_, _, output, _)

Review Comment:
   The convention is to put spaces around `@`:
   ```suggestion
         case oldVersion @ FlatMapCoGroupsInPandas(_, _, output, _)
   ```



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala:
##########
@@ -34,37 +34,32 @@ import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.Utils
 
-
 /**
- * Python UDF Runner for cogrouped udfs. It sends Arrow bathes from two 
different DataFrames,
+ * Python UDF Runner for cogrouped udfs. It sends Arrow bathes from different 
DataFrames,
  * groups them in Python, and receive it back in JVM as batches of single 
DataFrame.
  */
 class CoGroupedArrowPythonRunner(
-    funcs: Seq[ChainedPythonFunctions],
-    evalType: Int,
-    argOffsets: Array[Array[Int]],
-    leftSchema: StructType,
-    rightSchema: StructType,
-    timeZoneId: String,
-    conf: Map[String, String],
-    val pythonMetrics: Map[String, SQLMetric])
-  extends BasePythonRunner[
-    (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](funcs, 
evalType, argOffsets)
-  with BasicPythonArrowOutput {
+                                       funcs: Seq[ChainedPythonFunctions],
+                                       evalType: Int,
+                                       argOffsets: Array[Array[Int]],
+                                       schemas: List[StructType],
+                                       timeZoneId: String,
+                                       conf: Map[String, String],
+                                       val pythonMetrics: Map[String, 
SQLMetric])

Review Comment:
   ```suggestion
       funcs: Seq[ChainedPythonFunctions],
       evalType: Int,
       argOffsets: Array[Array[Int]],
       schemas: List[StructType],
       timeZoneId: String,
       conf: Map[String, String],
       val pythonMetrics: Map[String, SQLMetric])
   ```



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala:
##########
@@ -1365,6 +1365,32 @@ trait QuaternaryLike[T <: TreeNode[T]] { self: 
TreeNode[T] =>
   protected def withNewChildrenInternal(newFirst: T, newSecond: T, newThird: 
T, newFourth: T): T
 }
 
+trait NaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>

Review Comment:
   What is the benefit of defining `NaryLike` instead of just using `TreeNode[T]`?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala:
##########
@@ -304,7 +304,7 @@ trait FlatMapGroupsWithStateExecBase
           // We apply the values for the key after applying the initial state.
           callFunctionAndUpdateState(
             stateManager.getState(store, keyUnsafeRow),
-              valueRowIter,
+            coGroupedIterators.head,

Review Comment:
   ```suggestion
                 coGroupedIterators.head,
   ```



##########
sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala:
##########
@@ -582,48 +582,59 @@ class RelationalGroupedDataset protected[sql](
   }
 
   /**
-   * Applies a vectorized python user-defined function to each cogrouped data.
-   * The user-defined function defines a transformation:
-   * `pandas.DataFrame`, `pandas.DataFrame` -> `pandas.DataFrame`.
-   *  For each group in the cogrouped data, all elements in the group are 
passed as a
-   * `pandas.DataFrame` and the results for all cogroups are combined into a 
new [[DataFrame]].
+   * Applies a vectorized python user-defined function to each cogrouped data. 
The user-defined
+   * function defines a transformation: `pandas.DataFrame`, `pandas.DataFrame` 
->
+   * `pandas.DataFrame`. For each group in the cogrouped data, all elements in 
the group are
+   * passed as a `pandas.DataFrame` and the results for all cogroups are 
combined into a new
+   * [[DataFrame]].
    *
    * This function uses Apache Arrow as serialization format between Java 
executors and Python
    * workers.
    */
   private[sql] def flatMapCoGroupsInPandas(
-      r: RelationalGroupedDataset,
+      rs: Seq[RelationalGroupedDataset],
       expr: PythonUDF): DataFrame = {
     require(expr.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
       "Must pass a cogrouped map udf")
-    require(this.groupingExprs.length == r.groupingExprs.length,
-      "Cogroup keys must have same size: " +
-        s"${this.groupingExprs.length} != ${r.groupingExprs.length}")
+    val groupingExprLengthEquals = rs.map(_.groupingExprs.length).
+      forall(_ == this.groupingExprs.length)
+
+    require(groupingExprLengthEquals, s"Cogroup keys must have same size.")

Review Comment:
   Could we keep the key sizes in the error message?
   
   Something like `(this +: rs).map(_.groupingExprs.length).mkString(", ")`?



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -463,12 +459,20 @@ def applyInPandas(
 
         # The usage of the pandas_udf is internal so type checking is disabled.
         udf = pandas_udf(
-            func, returnType=schema, 
functionType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF
+            func,
+            returnType=schema,
+            functionType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
         )  # type: ignore[call-overload]
 
-        all_cols = self._extract_cols(self._gd1) + 
self._extract_cols(self._gd2)
-        udf_column = udf(*all_cols)
-        jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, 
udf_column._jc.expr())
+        all_cols = self._extract_cols(self._gd1)
+        for gd in self._gds:
+            all_cols.extend(self._extract_cols(gd))
+        udf_column_expr = udf(*all_cols)._jc.expr()
+        assert self._gd1.session.sparkContext._jvm is not None
+        all_jgds = self._gd1.session.sparkContext._jvm.PythonUtils.toSeq(
+            [gd._jgd for gd in self._gds]
+        )
+        jdf = self._gd1._jgd.flatMapCoGroupsInPandas(all_jgds, udf_column_expr)

Review Comment:
   The `all_` prefix in `all_jgds` is misleading: the sequence holds only the `self._gds` groups, not `self._gd1`:
   ```suggestion
           jgds = self._gd1.session.sparkContext._jvm.PythonUtils.toSeq(
               [gd._jgd for gd in self._gds]
           )
           jdf = self._gd1._jgd.flatMapCoGroupsInPandas(jgds, udf_column_expr)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to