Github user ueshin commented on a diff in the pull request:
https://github.com/apache/spark/pull/20211#discussion_r160599447
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
---
@@ -80,27 +84,77 @@ case class FlatMapGroupsInPandasExec(
val sessionLocalTimeZone = conf.sessionLocalTimeZone
val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone
- inputRDD.mapPartitionsInternal { iter =>
- val grouped = if (groupingAttributes.isEmpty) {
- Iterator(iter)
- } else {
+ if (additionalGroupingAttributes.isEmpty) {
+ // Fast path if additional grouping attributes is empty
+
+ inputRDD.mapPartitionsInternal { iter =>
+ val grouped = if (groupingAttributes.isEmpty) {
+ Iterator(iter)
+ } else {
+ val groupedIter = GroupedIterator(iter, groupingAttributes,
child.output)
+ val dropGrouping =
+
UnsafeProjection.create(child.output.drop(groupingAttributes.length),
child.output)
+ groupedIter.map {
+ case (_, groupedRowIter) => groupedRowIter.map(dropGrouping)
+ }
+ }
+
+ val context = TaskContext.get()
+
+ val columnarBatchIter = new ArrowPythonRunner(
+ chainedFunc, bufferSize, reuseWorker,
+ PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, argOffsets, schema,
+ sessionLocalTimeZone, pandasRespectSessionTimeZone)
+ .compute(grouped, context.partitionId(), context)
+
+ columnarBatchIter
+ .flatMap(_.rowIterator.asScala)
+ .map(UnsafeProjection.create(output, output))
+ }
+ } else {
+ // If additionGroupingAttributes is not empty, join the grouping
attributes with
+ // the udf output to get the final result
+
+ inputRDD.mapPartitionsInternal { iter =>
+ assert(groupingAttributes.nonEmpty)
+
val groupedIter = GroupedIterator(iter, groupingAttributes,
child.output)
+
+ val context = TaskContext.get()
+
+ val queue = HybridRowQueue(context.taskMemoryManager(),
+ new File(Utils.getLocalDir(SparkEnv.get.conf)),
additionalGroupingAttributes.length)
+ context.addTaskCompletionListener { _ =>
+ queue.close()
+ }
+ val additionalGroupingProj = UnsafeProjection.create(
+ additionalGroupingAttributes, groupingAttributes)
val dropGrouping =
UnsafeProjection.create(child.output.drop(groupingAttributes.length),
child.output)
- groupedIter.map {
- case (_, groupedRowIter) => groupedRowIter.map(dropGrouping)
+ val grouped = groupedIter.map {
+ case (k, groupedRowIter) =>
+ val additionalGrouping = additionalGroupingProj(k)
+ queue.add(additionalGrouping)
+ (additionalGrouping, groupedRowIter.map(dropGrouping))
--- End diff --
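We can return only `groupedRowIter.map(dropGrouping)` here. `additionalGrouping` is already added to `queue` on the line above, so it doesn't need to be returned as part of the tuple.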
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]