EnricoMi commented on code in PR #39640:
URL: https://github.com/apache/spark/pull/39640#discussion_r1081026350
##########
sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala:
##########
@@ -171,6 +171,84 @@ class KeyValueGroupedDataset[K, V] private[sql](
flatMapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder)
}
+ /**
+ * (Scala-specific)
+ * Applies the given function to each group of data. For each unique group, the function will
+ * be passed the group key and a sorted iterator that contains all of the elements in the group.
+ * The function can return an iterator containing elements of an arbitrary type which will be
+ * returned as a new [[Dataset]].
+ *
+ * This function does not support partial aggregation, and as a result requires shuffling all
+ * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+ * key, it is best to use the reduce function or an
+ * `org.apache.spark.sql.expressions#Aggregator`.
+ *
+ * Internally, the implementation will spill to disk if any given group is too large to fit into
+ * memory. However, users must take care to avoid materializing the whole iterator for a group
+ * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+ * constraints of their cluster.
+ *
+ * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator
+ * to be sorted according to the given sort expressions. That sorting does not add
+ * computational complexity.
+ *
+ * @see [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
+ * @since 3.4.0
+ */
+ def flatMapSortedGroups[U : Encoder]
+ (sortExprs: Column*)
+ (f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
+ val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
+ col.expr match {
+ case expr: SortOrder => expr
+ case expr: Expression => SortOrder(expr, Ascending)
+ }
+ }
+
+ Dataset[U](
+ sparkSession,
+ MapGroups(
+ f,
+ groupingAttributes,
+ dataAttributes,
+ sortOrder,
+ logicalPlan
+ )
+ )
+ }
+
+
+ /**
+ * (Java-specific)
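
For illustration only, the semantics described in the Scaladoc above (group by key, sort each group by the given expressions, hand the key plus a sorted iterator to `f`) can be modelled on local Scala collections. This is a minimal sketch assuming Scala 2.13 and no Spark dependency; `SortedGroupsModel` and its `sortKey` function (standing in for the sort `Column`s) are hypothetical names, not Spark's API or its distributed implementation.

```scala
// Plain-Scala model of the documented semantics; NOT Spark's implementation,
// which plans a MapGroups node with a sort order and runs distributed.
object SortedGroupsModel {
  // Groups `data` by key, sorts each group's values by `sortKey`, then hands the key and
  // a sorted iterator to `f`. The order in which groups are visited is unspecified.
  def flatMapSortedGroups[K, V, S: Ordering, U](data: Seq[(K, V)])(sortKey: V => S)(
      f: (K, Iterator[V]) => IterableOnce[U]): Seq[U] =
    data
      .groupBy(_._1) // one entry per unique key
      .toSeq
      .flatMap { case (key, pairs) =>
        val sorted = pairs.map(_._2).sortBy(sortKey) // ordering within a group only
        f(key, sorted.iterator)
      }
}

// Hypothetical usage: the values of each key arrive in ascending order.
val data = Seq(("a", 3), ("b", 1), ("a", 1), ("a", 2))
val out = SortedGroupsModel.flatMapSortedGroups(data)((v: Int) => v) { (k, it) =>
  Iterator(s"$k:${it.mkString(",")}")
}
```

Only the order inside each group is guaranteed here, which is why the result is compared after sorting.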
Review Comment:
A `def func(cols: Column*)` by itself is not Java-friendly: it takes a Scala `Seq`, so a Java caller has to write something like this:
```Java
func(JavaConverters.iterableAsScalaIterable(Arrays.asList(col("_1"), col("_2"))).toSeq());
```
When such a method is annotated with `@scala.annotation.varargs`, the compiler additionally generates a Java-friendly `func(Column... cols)`. But this cannot be combined with implicit arguments (`[U : Encoder]`, which adds a trailing implicit parameter list) or multiple parameter lists (`(other)(thisSortExprs)(otherSortExprs)(f)`), because the `Column*` has to be the last argument in the last parameter list. Otherwise compilation fails with:

> A method annotated with @varargs must have a single repeated parameter in its last parameter list.
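
To make that rule concrete, a minimal sketch assuming a hypothetical `Column` case class standing in for Spark's `org.apache.spark.sql.Column`, so it runs without Spark. The accepted method has its repeated parameter alone in its only parameter list; the rejected variant is left commented out because it would not compile.

```scala
import scala.annotation.varargs

// Hypothetical stand-in for org.apache.spark.sql.Column so the sketch runs without Spark.
final case class Column(name: String)

object VarargsDemo {
  // Compiles: the repeated parameter is alone in the last (and only) parameter list,
  // so scalac also emits a Java-callable `describe(Column... cols)` overload.
  @varargs
  def describe(cols: Column*): String = cols.map(_.name).mkString(",")

  // Would be rejected if uncommented: the context bound desugars to a trailing
  // `(implicit ev: Ordering[U])` list, so `cols` is no longer in the last list.
  // @varargs def describeSorted[U: Ordering](cols: Column*): String = ???
}

// Scala callers can pass columns directly or splat an existing Seq.
val direct = VarargsDemo.describe(Column("_1"), Column("_2"))
val splatted = VarargsDemo.describe(Seq(Column("_1"), Column("_2")): _*)
```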
Using the Java function interfaces, I managed to improve both the Scala and the Java calls:
Scala:
```Scala
val aggregated = grouped.flatMapSortedGroups($"seq", expr("length(key)"))(
  (g, iter) => asJavaIterator(Iterator(g._1, iter.asScala.mkString(", ")))
)(Encoders.STRING)
val cogrouped = groupedLeft.cogroupSorted(groupedRight)(leftOrder: _*)(rightOrder: _*)(
  (key, left, right) => asJavaIterator(Iterator(
    key -> (left.asScala.map(_._2).mkString + "#" + right.asScala.map(_._2).mkString)
  ))
)
```
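
As a rough model of what `cogroupSorted` is meant to guarantee (each side of a co-group sorted by its own expressions before `f` sees it), here is a local-collection sketch. It assumes Scala 2.13; `CoGroupSortedModel` and the `leftKey`/`rightKey` functions (standing in for `leftOrder`/`rightOrder`) are hypothetical and not Spark's implementation. The right side is sorted descending via a negated key, similar to the `.desc()` order used in the Java example.

```scala
// Plain-Scala model of the cogroupSorted semantics; NOT Spark's implementation.
object CoGroupSortedModel {
  // For every key present on either side, sorts that key's left values by `leftKey` and its
  // right values by `rightKey`, then passes both sorted iterators to `f`.
  def cogroupSorted[K, L, R, SL: Ordering, SR: Ordering, U](
      left: Seq[(K, L)], right: Seq[(K, R)])(leftKey: L => SL)(rightKey: R => SR)(
      f: (K, Iterator[L], Iterator[R]) => IterableOnce[U]): Seq[U] = {
    val leftGroups = left.groupBy(_._1)
    val rightGroups = right.groupBy(_._1)
    (leftGroups.keySet ++ rightGroups.keySet).toSeq.flatMap { key =>
      val ls = leftGroups.getOrElse(key, Seq.empty).map(_._2).sortBy(leftKey)
      val rs = rightGroups.getOrElse(key, Seq.empty).map(_._2).sortBy(rightKey)
      f(key, ls.iterator, rs.iterator) // a key missing on one side yields an empty iterator
    }
  }
}

// Hypothetical data: left ascending, right descending (negated sort key).
val left = Seq((1, "c"), (1, "a"), (2, "b"))
val right = Seq((1, 5), (1, 9), (3, 7))
val cogrouped = CoGroupSortedModel.cogroupSorted(left, right)((l: String) => l)((r: Int) => -r)(
  (key, ls, rs) => Iterator(key -> (ls.mkString + "#" + rs.mkString)))
```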
Java:
```Java
Dataset<String> flatMapSorted = grouped.flatMapSortedGroups(
JavaConverters.iterableAsScalaIterable(Collections.singletonList(ds.col("value"))).toSeq(),
(FlatMapGroupsFunction<Integer, String, String>) (key, values) -> {
return Collections.singletonList(...).iterator();
},
Encoders.STRING());
Dataset<String> cogroupSorted = grouped.cogroupSorted(
grouped2,
JavaConverters.iterableAsScalaIterable(Collections.singletonList(ds.col("value"))).toSeq(),
JavaConverters.iterableAsScalaIterable(Collections.singletonList(ds2.col("value").desc())).toSeq(),
    (CoGroupFunction<Integer, String, Integer, String>) (key, left, right) -> {
return Collections.singletonList(...).iterator();
},
Encoders.STRING()
);
```
Note: the `.asScala` and `asJavaIterator` conversions are required in Scala because `FlatMapGroupsFunction` and `CoGroupFunction` use Java iterators. This is unintuitive and in contrast to the `flatMapGroups` and `cogroup` methods, which take Scala iterators:
```Scala
// flatMapGroups
val aggregated = grouped.flatMapGroups {
  (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString)
}
// flatMapSortedGroups
val aggregated = grouped.flatMapSortedGroups($"seq", expr("length(key)"))(
  (g, iter) => asJavaIterator(Iterator(g._1, iter.asScala.mkString(", ")))
)(Encoders.STRING)
```
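
The iterator mismatch described in the note can be reproduced without Spark. This minimal sketch assumes Scala 2.13 (`scala.jdk.CollectionConverters`); `JavaStyleGroupFn` is a hypothetical trait with the same shape as Spark's `FlatMapGroupsFunction` (Java iterators in and out), and `IteratorBridge.fromJava` mirrors the `f.call(key, data.asJava).asScala` wrapping in the diff hunk above.

```scala
import scala.jdk.CollectionConverters._

// Hypothetical stand-in shaped like Spark's Java FlatMapGroupsFunction:
// it consumes and produces java.util.Iterator rather than scala.collection.Iterator.
trait JavaStyleGroupFn[K, V, R] {
  def call(key: K, values: java.util.Iterator[V]): java.util.Iterator[R]
}

object IteratorBridge {
  // Adapts a Java-style function to the Scala-style (K, Iterator[V]) => Iterator[R] shape.
  def fromJava[K, V, R](f: JavaStyleGroupFn[K, V, R]): (K, Iterator[V]) => Iterator[R] =
    (key, data) => f.call(key, data.asJava).asScala
}

// Written against the Java-style interface, a Scala lambda needs explicit conversions...
val javaStyle: JavaStyleGroupFn[String, Int, String] =
  (key, values) => Iterator(key, values.asScala.mkString(",")).asJava

// ...whereas a Scala-style signature takes and returns Scala iterators directly.
val scalaStyle: (String, Iterator[Int]) => Iterator[String] =
  (key, values) => Iterator(key, values.mkString(","))

val viaBridge = IteratorBridge.fromJava(javaStyle)("a", Iterator(1, 2, 3)).toList
val direct = scalaStyle("a", Iterator(1, 2, 3)).toList
```

Both paths produce the same elements; the difference is only where the `.asScala`/`.asJava` conversions live.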