Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/18732#discussion_r142592031
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala ---
@@ -435,6 +435,33 @@ class RelationalGroupedDataset protected[sql](
df.logicalPlan.output,
df.logicalPlan))
}
+
+ private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = {
+ require(expr.vectorized, "Must pass a vectorized python udf")
+
+ val output = expr.dataType match {
+ case s: StructType => s.map {
+ case StructField(name, dataType, nullable, metadata) =>
+ AttributeReference(name, dataType, nullable, metadata)()
+ }
+ }
+
+ val groupingAttributes: Seq[Attribute] = groupingExprs.map {
+ case ne: NamedExpression => ne.toAttribute
+ }
+
+ val plan = FlatMapGroupsInPandas(
+ groupingAttributes,
+ expr,
+ output,
+ df.logicalPlan
+ )
--- End diff --
Minor nit: I'd write this as
```scala
val plan = FlatMapGroupsInPandas(groupingAttributes, expr, output,
df.logicalPlan)
```
or
```scala
val plan = FlatMapGroupsInPandas(
groupingAttributes, expr, output, df.logicalPlan)
```
if you wouldn't mind.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]