EnricoMi commented on code in PR #40122:
URL: https://github.com/apache/spark/pull/40122#discussion_r1114504577
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -413,40 +414,35 @@ def applyInPandas(
>>> df2 = spark.createDataFrame(
... [(20000101, 1, "x"), (20000101, 2, "y")],
... ("time", "id", "v2"))
- >>> def asof_join(l, r):
- ... return pd.merge_asof(l, r, on="time", by="id")
- >>> df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
- ... asof_join, schema="time int, id int, v1 double, v2 string"
- ... ).show() # doctest: +SKIP
- +--------+---+---+---+
- | time| id| v1| v2|
- +--------+---+---+---+
- |20000101| 1|1.0| x|
- |20000102| 1|3.0| x|
- |20000101| 2|2.0| y|
- |20000102| 2|4.0| y|
- +--------+---+---+---+
-
- Alternatively, the user can define a function that takes three
arguments. In this case,
- the grouping key(s) will be passed as the first argument and the data
will be passed as the
- second and third arguments. The grouping key(s) will be passed as a
tuple of numpy data
- types, e.g., `numpy.int32` and `numpy.float64`. The data will still be
passed in as two
- `pandas.DataFrame` containing all columns from the original Spark
DataFrames.
-
>>> def asof_join(k, l, r):
... if k == (1,):
... return pd.merge_asof(l, r, on="time", by="id")
... else:
... return pd.DataFrame(columns=['time', 'id', 'v1', 'v2'])
>>> df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
- ... asof_join, "time int, id int, v1 double, v2 string").show() #
doctest: +SKIP
+ ... asof_join, "time int, id int, v1 double, v2 string").show() #
doctest: +SKIP
+--------+---+---+---+
| time| id| v1| v2|
+--------+---+---+---+
|20000101| 1|1.0| x|
|20000102| 1|3.0| x|
+--------+---+---+---+
-
+ >>> df3 = spark.createDataFrame(
+ ... [(20000101, 1, "x"), (20000101, 2, "y")],
+ ... ("time", "id", "v3"))
+ >>> def asof_join_multiple(_, *dfs):
Review Comment:
example should mention `key` argument, even if it is not used inside
```suggestion
>>> def asof_join_multiple(key, *dfs):
```
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -397,12 +397,13 @@ def applyInPandas(
Parameters
----------
func : function
- a Python native function that takes two `pandas.DataFrame`\\s, and
- outputs a `pandas.DataFrame`, or that takes one tuple (grouping
keys) and two
+ a Python native function that takes one tuple (grouping keys) and
two or more
``pandas.DataFrame``\\s, and outputs a ``pandas.DataFrame``.
schema : :class:`pyspark.sql.types.DataType` or str
the return type of the `func` in PySpark. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type
string.
+ pass_key : : bool
+ Used to pass key to the UDF when cogrouped over more than 2
dataframes.
Review Comment:
leaked from the other PR, I think...
```suggestion
```
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -368,23 +368,23 @@ class PandasCogroupedOps:
This API is experimental.
"""
- def __init__(self, gd1: "GroupedData", gd2: "GroupedData"):
+ def __init__(self, gd1: "GroupedData", *gds: "GroupedData"):
self._gd1 = gd1
- self._gd2 = gd2
+ self._gds = gds
Review Comment:
assert that `gds` is non-empty (at least one additional `GroupedData`), so a cogroup always involves two or more grouped DataFrames
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -463,12 +459,20 @@ def applyInPandas(
# The usage of the pandas_udf is internal so type checking is disabled.
udf = pandas_udf(
- func, returnType=schema,
functionType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF
+ func,
+ returnType=schema,
+ functionType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
) # type: ignore[call-overload]
- all_cols = self._extract_cols(self._gd1) +
self._extract_cols(self._gd2)
- udf_column = udf(*all_cols)
- jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd,
udf_column._jc.expr())
+ all_cols = self._extract_cols(self._gd1)
+ for gd in self._gds:
+ all_cols.extend(self._extract_cols(gd))
Review Comment:
Could be simplified to a single flattening comprehension. Note that `self._gds` is a tuple (it comes from `*gds`), so `[self._gd1] + self._gds` would raise a `TypeError`, and mapping `_extract_cols` directly would produce a list of lists — flatten instead:
```suggestion
        all_cols = [col for gd in (self._gd1, *self._gds) for col in self._extract_cols(gd)]
```
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala:
##########
@@ -227,7 +227,7 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))
- case oldVersion @ FlatMapCoGroupsInPandas(_, _, _, output, _, _)
+ case oldVersion@FlatMapCoGroupsInPandas(_, _, output, _)
Review Comment:
convention is spaces around `@`:
```suggestion
case oldVersion @ FlatMapCoGroupsInPandas(_, _, output, _)
```
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala:
##########
@@ -34,37 +34,32 @@ import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.Utils
-
/**
- * Python UDF Runner for cogrouped udfs. It sends Arrow batches from two
different DataFrames,
+ * Python UDF Runner for cogrouped udfs. It sends Arrow batches from different
DataFrames,
* groups them in Python, and receives them back in JVM as batches of a single
DataFrame.
*/
class CoGroupedArrowPythonRunner(
- funcs: Seq[ChainedPythonFunctions],
- evalType: Int,
- argOffsets: Array[Array[Int]],
- leftSchema: StructType,
- rightSchema: StructType,
- timeZoneId: String,
- conf: Map[String, String],
- val pythonMetrics: Map[String, SQLMetric])
- extends BasePythonRunner[
- (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](funcs,
evalType, argOffsets)
- with BasicPythonArrowOutput {
+ funcs: Seq[ChainedPythonFunctions],
+ evalType: Int,
+ argOffsets: Array[Array[Int]],
+ schemas: List[StructType],
+ timeZoneId: String,
+ conf: Map[String, String],
+ val pythonMetrics: Map[String,
SQLMetric])
Review Comment:
```suggestion
funcs: Seq[ChainedPythonFunctions],
evalType: Int,
argOffsets: Array[Array[Int]],
schemas: List[StructType],
timeZoneId: String,
conf: Map[String, String],
val pythonMetrics: Map[String, SQLMetric])
```
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala:
##########
@@ -1365,6 +1365,32 @@ trait QuaternaryLike[T <: TreeNode[T]] { self:
TreeNode[T] =>
protected def withNewChildrenInternal(newFirst: T, newSecond: T, newThird:
T, newFourth: T): T
}
+trait NaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
Review Comment:
what is the benefit of defining `NaryLike` over `TreeNode[T]`?
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala:
##########
@@ -304,7 +304,7 @@ trait FlatMapGroupsWithStateExecBase
// We apply the values for the key after applying the initial state.
callFunctionAndUpdateState(
stateManager.getState(store, keyUnsafeRow),
- valueRowIter,
+ coGroupedIterators.head,
Review Comment:
```suggestion
coGroupedIterators.head,
```
##########
sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala:
##########
@@ -582,48 +582,59 @@ class RelationalGroupedDataset protected[sql](
}
/**
- * Applies a vectorized python user-defined function to each cogrouped data.
- * The user-defined function defines a transformation:
- * `pandas.DataFrame`, `pandas.DataFrame` -> `pandas.DataFrame`.
- * For each group in the cogrouped data, all elements in the group are
passed as a
- * `pandas.DataFrame` and the results for all cogroups are combined into a
new [[DataFrame]].
+ * Applies a vectorized python user-defined function to each cogrouped data.
The user-defined
+ * function defines a transformation: `pandas.DataFrame`, `pandas.DataFrame`
->
+ * `pandas.DataFrame`. For each group in the cogrouped data, all elements in
the group are
+ * passed as a `pandas.DataFrame` and the results for all cogroups are
combined into a new
+ * [[DataFrame]].
*
* This function uses Apache Arrow as serialization format between Java
executors and Python
* workers.
*/
private[sql] def flatMapCoGroupsInPandas(
- r: RelationalGroupedDataset,
+ rs: Seq[RelationalGroupedDataset],
expr: PythonUDF): DataFrame = {
require(expr.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
"Must pass a cogrouped map udf")
- require(this.groupingExprs.length == r.groupingExprs.length,
- "Cogroup keys must have same size: " +
- s"${this.groupingExprs.length} != ${r.groupingExprs.length}")
+ val groupingExprLengthEquals = rs.map(_.groupingExprs.length).
+ forall(_ == this.groupingExprs.length)
+
+ require(groupingExprLengthEquals, s"Cogroup keys must have same size.")
Review Comment:
Could we keep the key sizes in the error message?
Something like `(this +: rs).map(_.groupingExprs.length).mkString(", ")`?
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -463,12 +459,20 @@ def applyInPandas(
# The usage of the pandas_udf is internal so type checking is disabled.
udf = pandas_udf(
- func, returnType=schema,
functionType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF
+ func,
+ returnType=schema,
+ functionType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
) # type: ignore[call-overload]
- all_cols = self._extract_cols(self._gd1) +
self._extract_cols(self._gd2)
- udf_column = udf(*all_cols)
- jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd,
udf_column._jc.expr())
+ all_cols = self._extract_cols(self._gd1)
+ for gd in self._gds:
+ all_cols.extend(self._extract_cols(gd))
+ udf_column_expr = udf(*all_cols)._jc.expr()
+ assert self._gd1.session.sparkContext._jvm is not None
+ all_jgds = self._gd1.session.sparkContext._jvm.PythonUtils.toSeq(
+ [gd._jgd for gd in self._gds]
+ )
+ jdf = self._gd1._jgd.flatMapCoGroupsInPandas(all_jgds, udf_column_expr)
Review Comment:
`all` is misleading, as this refers only to `self._gds`:
```suggestion
jgds = self._gd1.session.sparkContext._jvm.PythonUtils.toSeq(
[gd._jgd for gd in self._gds]
)
jdf = self._gd1._jgd.flatMapCoGroupsInPandas(jgds, udf_column_expr)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]