EnricoMi commented on code in PR #39640:
URL: https://github.com/apache/spark/pull/39640#discussion_r1081026350


##########
sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala:
##########
@@ -171,6 +171,84 @@ class KeyValueGroupedDataset[K, V] private[sql](
     flatMapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder)
   }
 
+  /**
+   * (Scala-specific)
+   * Applies the given function to each group of data.  For each unique group, the function will
+   * be passed the group key and a sorted iterator that contains all of the elements in the group.
+   * The function can return an iterator containing elements of an arbitrary type which will be
+   * returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory.  However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+   * constraints of their cluster.
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator
+   * to be sorted according to the given sort expressions. That sorting does not add
+   * computational complexity.
+   *
+   * @see [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
+   * @since 3.4.0
+   */
+  def flatMapSortedGroups[U : Encoder]
+      (sortExprs: Column*)
+      (f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
+    val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
+      col.expr match {
+        case expr: SortOrder => expr
+        case expr: Expression => SortOrder(expr, Ascending)
+      }
+    }
+
+    Dataset[U](
+      sparkSession,
+      MapGroups(
+        f,
+        groupingAttributes,
+        dataAttributes,
+        sortOrder,
+        logicalPlan
+      )
+    )
+  }
+
+
+  /**
+   * (Java-specific)

Review Comment:
   A `def func(cols: Column*)` by itself is not Java-friendly: it takes a Scala `Seq`, so a Java caller needs something like this:
   
   ```Java
   func(JavaConverters.iterableAsScalaIterable(Arrays.asList(col("_1"), col("_2"))).toSeq())
   ```
   
   When a method is annotated with `@scala.annotation.varargs`, the compiler adds a Java-friendly version `func(Column... cols)`. However, this cannot be combined with implicit arguments (a context bound such as `[U : Encoder]` adds a trailing implicit parameter list) or with multiple parameter lists (`(other)(thisSortExprs)(otherSortExprs)(f)`), because the `Column*` has to be the last parameter in the last parameter list:
   
       A method annotated with @varargs must have a single repeated parameter in its last parameter list.
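   To see why the restriction exists, here is a plain-Java sketch (no Spark or Scala dependency) of roughly what the `@varargs` forwarder amounts to on the Java side. `VarargsForwarder`, `funcSeq` and `func` are hypothetical names. `String` stands in for `Column`, and `java.util.List` stands in for the Scala `Seq`; the real generated forwarder wraps the array into a Scala `Seq`.
   
   ```java
   import java.util.Arrays;
   import java.util.List;
   
   // Hypothetical stand-ins: String plays the role of Column, and java.util.List
   // plays the role of the Scala Seq that `def func(cols: Column*)` really takes.
   class VarargsForwarder {
       // The "Scala-side" method: from Java it takes one sequence argument.
       static String funcSeq(List<String> cols) {
           return String.join(",", cols);
       }
   
       // Roughly what @scala.annotation.varargs generates: a Java varargs overload
       // that wraps the array and forwards to the sequence-taking method.
       // Java only allows varargs as the very last parameter of the flattened
       // signature, so a trailing (implicit) parameter list rules this out.
       static String func(String... cols) {
           return funcSeq(Arrays.asList(cols));
       }
   }
   ```
   
   With the forwarder, a Java caller writes `func("_1", "_2")`. Without it, the caller has to build the collection itself, which is the `JavaConverters` boilerplate shown above.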
   
   I managed to improve both the Scala and the Java calls by using the Java function interfaces:
   
   Scala:
   ```Scala
       val aggregated = grouped.flatMapSortedGroups($"seq", expr("length(key)"))(
         (g, iter) => asJavaIterator(Iterator(g._1, iter.asScala.mkString(", ")))
       )(Encoders.STRING)
   
       val cogrouped = groupedLeft.cogroupSorted(groupedRight)(leftOrder: _*)(rightOrder: _*)(
         (key, left, right) => asJavaIterator(Iterator(
           key -> (left.asScala.map(_._2).mkString + "#" + right.asScala.map(_._2).mkString)
         ))
       )
   ```
   
   Java:
   ```Java
       Dataset<String> flatMapSorted = grouped.flatMapSortedGroups(
           JavaConverters.iterableAsScalaIterable(Collections.singletonList(ds.col("value"))).toSeq(),
           (FlatMapGroupsFunction<Integer, String, String>) (key, values) -> {
             return Collections.singletonList(...).iterator();
           },
           Encoders.STRING());
   
       Dataset<String> cogroupSorted = grouped.cogroupSorted(
         grouped2,
         JavaConverters.iterableAsScalaIterable(Collections.singletonList(ds.col("value"))).toSeq(),
         JavaConverters.iterableAsScalaIterable(Collections.singletonList(ds2.col("value").desc())).toSeq(),
         (CoGroupFunction<Integer, String, Integer, String>) (key, left, right) -> {
           return Collections.singletonList(...).iterator();
         },
         Encoders.STRING()
       );
   ```
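   The intended semantics of `cogroupSorted` can be modelled in memory with plain Java (no Spark dependency). Values from both sides are grouped by key, and each side is sorted by its own order: ascending on the left and descending on the right, mirroring the `.desc()` above. The function then receives one `java.util.Iterator` per side. This is a sketch only. `LocalCoGroupSorted` and `SortedCoGroupFunction` are hypothetical names, and the interface merely imitates the shape of Spark's `CoGroupFunction` (without `throws Exception`). Spark shuffles rather than building in-memory maps and does not guarantee the order in which keys are emitted; the `LinkedHashMap` here only makes the output deterministic.
   
   ```java
   import java.util.*;
   import java.util.function.Function;
   
   // Hypothetical interface shaped like Spark's CoGroupFunction: one Java iterator per side.
   interface SortedCoGroupFunction<K, L, R, U> {
       Iterator<U> call(K key, Iterator<L> left, Iterator<R> right);
   }
   
   class LocalCoGroupSorted {
       // In-memory model: group both sides by key, sort each side with its own
       // comparator, then call f once per key present on either side.
       static <K, L, R, U> List<U> cogroupSorted(
               List<L> left, Function<L, K> leftKey, Comparator<L> leftOrder,
               List<R> right, Function<R, K> rightKey, Comparator<R> rightOrder,
               SortedCoGroupFunction<K, L, R, U> f) {
           Map<K, List<L>> leftGroups = new LinkedHashMap<>();
           for (L l : left) {
               leftGroups.computeIfAbsent(leftKey.apply(l), k -> new ArrayList<>()).add(l);
           }
           Map<K, List<R>> rightGroups = new LinkedHashMap<>();
           for (R r : right) {
               rightGroups.computeIfAbsent(rightKey.apply(r), k -> new ArrayList<>()).add(r);
           }
           // Keys seen on either side; a key missing on one side yields an empty iterator there.
           Set<K> keys = new LinkedHashSet<>(leftGroups.keySet());
           keys.addAll(rightGroups.keySet());
   
           List<U> out = new ArrayList<>();
           for (K key : keys) {
               List<L> ls = new ArrayList<>(leftGroups.getOrDefault(key, Collections.emptyList()));
               List<R> rs = new ArrayList<>(rightGroups.getOrDefault(key, Collections.emptyList()));
               ls.sort(leftOrder);   // "thisSortExprs"
               rs.sort(rightOrder);  // "otherSortExprs"
               f.call(key, ls.iterator(), rs.iterator()).forEachRemaining(out::add);
           }
           return out;
       }
   }
   ```
   
   For example, cogrouping `["a3", "a1", "b2"]` with `["a9", "a7"]` by first character, and concatenating each side, yields `"a:a1a3#a9a7"` and `"b:b2#"`.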
   
   Note: the `.asScala` calls are required in Scala because we are using `FlatMapGroupsFunction` and `CoGroupFunction`, which use Java iterators. This is unintuitive and in contrast to the `flatMapGroups` and `cogroup` methods, whose Scala variants take Scala iterators:
   
   ```Scala
       // flatMapGroups
       val aggregated = grouped.flatMapGroups {
         (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString)
       }
   
       // flatMapSortedGroups
       val aggregated = grouped.flatMapSortedGroups($"seq", expr("length(key)"))(
         (g, iter) => asJavaIterator(Iterator(g._1, iter.asScala.mkString(", ")))
       )(Encoders.STRING)
   ```
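   Apart from the iterator types, `flatMapGroups` and `flatMapSortedGroups` differ in the order in which each group's values reach the function. The plain-Java sketch below models that in memory (no Spark dependency). It groups a list by key, sorts each group, and hands the function a `java.util.Iterator`. `LocalFlatMapSortedGroups` and `SortedGroupsFunction` are hypothetical names, and the interface only imitates the shape of Spark's `FlatMapGroupsFunction`. Spark groups via a shuffle rather than an in-memory map and does not promise any order of groups; the `LinkedHashMap` just keeps this sketch deterministic.
   
   ```java
   import java.util.*;
   import java.util.function.Function;
   
   // Hypothetical interface shaped like Spark's FlatMapGroupsFunction: values arrive as a Java iterator.
   interface SortedGroupsFunction<K, V, U> {
       Iterator<U> call(K key, Iterator<V> values);
   }
   
   class LocalFlatMapSortedGroups {
       // In-memory model: group by key, sort each group, then flat-map every group through f.
       static <K, V, U> List<U> flatMapSortedGroups(
               List<V> data, Function<V, K> keyOf, Comparator<V> order,
               SortedGroupsFunction<K, V, U> f) {
           Map<K, List<V>> groups = new LinkedHashMap<>();
           for (V v : data) {
               groups.computeIfAbsent(keyOf.apply(v), k -> new ArrayList<>()).add(v);
           }
           List<U> out = new ArrayList<>();
           for (Map.Entry<K, List<V>> e : groups.entrySet()) {
               List<V> values = e.getValue();
               values.sort(order);  // the step a plain flatMapGroups would not perform
               f.call(e.getKey(), values.iterator()).forEachRemaining(out::add);
           }
           return out;
       }
   }
   ```
   
   Grouping `["b2", "a3", "b1", "a1"]` by first character, with a function that emits the key followed by the concatenated values (like `Iterator(g._1, iter.mkString)` above), produces `["b", "b1b2", "a", "a1a3"]`.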
   


