Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/20806#discussion_r174358970 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala --- @@ -1658,6 +1659,43 @@ class Dataset[T] private[sql]( def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = groupByKey(func.call(_))(encoder) + + /** + * Aggregates the elements of this Dataset in a multi-level tree pattern. + * + * @param depth suggested depth of the tree (default: 2) + */ + private[spark] def treeAggregate[U : Encoder : ClassTag](zeroValue: U)( + seqOp: (U, T) => U, + combOp: (U, U) => U, + depth: Int = 2): U = { + require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") + val sparkContext = sparkSession.sparkContext + val copiedZeroValue = Utils.clone(zeroValue, sparkContext.env.closureSerializer.newInstance()) + if (rdd.partitions.length == 0) { + copiedZeroValue + } else { + val aggregatePartition = + (it: Iterator[T]) => it.aggregate(zeroValue)(seqOp, combOp) + var partiallyAggregated: Dataset[U] = mapPartitions(it => Iterator(aggregatePartition(it))) --- End diff -- Since this benchmark should be performed on a cluster, not a local machine, I will try to run a benchmark later once I have set one up.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org