GitHub user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20806#discussion_r174277864
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---
    @@ -1658,6 +1659,43 @@ class Dataset[T] private[sql](
       def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): 
KeyValueGroupedDataset[K, T] =
         groupByKey(func.call(_))(encoder)
     
    +
    +  /**
    +   * Aggregates the elements of this Dataset in a multi-level tree pattern.
    +   *
    +   * @param depth suggested depth of the tree (default: 2)
    +   */
    +  private[spark] def treeAggregate[U : Encoder : ClassTag](zeroValue: U)(
    +      seqOp: (U, T) => U,
    +      combOp: (U, U) => U,
    +      depth: Int = 2): U = {
    +    require(depth >= 1, s"Depth must be greater than or equal to 1 but got 
$depth.")
    +    val sparkContext = sparkSession.sparkContext
    +    val copiedZeroValue = Utils.clone(zeroValue, 
sparkContext.env.closureSerializer.newInstance())
    +    if (rdd.partitions.length == 0) {
    +      copiedZeroValue
    +    } else {
    +      val aggregatePartition =
    +        (it: Iterator[T]) => it.aggregate(zeroValue)(seqOp, combOp)
    +      var partiallyAggregated: Dataset[U] = mapPartitions(it => 
Iterator(aggregatePartition(it)))
    --- End diff --
    
    Why can't we call `rdd.treeAggregate` directly?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to