cloud-fan commented on a change in pull request #32242:
URL: https://github.com/apache/spark/pull/32242#discussion_r618141861
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -687,40 +725,59 @@ case class HashAggregateExec(
val thisPlan = ctx.addReferenceObj("plan", this)
- // Create a name for the iterator from the fast hash map, and the code to
create fast hash map.
- val (iterTermForFastHashMap, createFastHashMap) = if
(isFastHashMapEnabled) {
- // Generates the fast hash map class and creates the fast hash map term.
- val fastHashMapClassName = ctx.freshName("FastHashMap")
- if (isVectorizedHashMapEnabled) {
- val generatedMap = new VectorizedHashMapGenerator(ctx,
aggregateExpressions,
- fastHashMapClassName, groupingKeySchema, bufferSchema,
bitMaxCapacity).generate()
- ctx.addInnerClass(generatedMap)
-
- // Inline mutable state since not many aggregation operations in a task
- fastHashMapTerm = ctx.addMutableState(
- fastHashMapClassName, "vectorizedFastHashMap", forceInline = true)
- val iter = ctx.addMutableState(
- "java.util.Iterator<InternalRow>",
- "vectorizedFastHashMapIter",
- forceInline = true)
- val create = s"$fastHashMapTerm = new $fastHashMapClassName();"
- (iter, create)
- } else {
- val generatedMap = new RowBasedHashMapGenerator(ctx,
aggregateExpressions,
- fastHashMapClassName, groupingKeySchema, bufferSchema,
bitMaxCapacity).generate()
- ctx.addInnerClass(generatedMap)
-
- // Inline mutable state since not many aggregation operations in a task
- fastHashMapTerm = ctx.addMutableState(
- fastHashMapClassName, "fastHashMap", forceInline = true)
- val iter = ctx.addMutableState(
- "org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>",
- "fastHashMapIter", forceInline = true)
- val create = s"$fastHashMapTerm = new $fastHashMapClassName(" +
- s"$thisPlan.getTaskMemoryManager(),
$thisPlan.getEmptyAggregationBuffer());"
- (iter, create)
- }
- } else ("", "")
+ // Create a name for the iterator from the fast hash map, the code to
create
+ // and add hook to close fast hash map.
+ val (iterTermForFastHashMap, createFastHashMap, addHookToCloseFastHashMap)
=
+ if (isFastHashMapEnabled) {
+ // Generates the fast hash map class and creates the fast hash map
term.
+ val fastHashMapClassName = ctx.freshName("FastHashMap")
+ val (iter, create) = if (isVectorizedHashMapEnabled) {
+ val generatedMap = new VectorizedHashMapGenerator(ctx,
aggregateExpressions,
+ fastHashMapClassName, groupingKeySchema, bufferSchema,
bitMaxCapacity).generate()
+ ctx.addInnerClass(generatedMap)
+
+ // Inline mutable state since not many aggregation operations in a
task
+ fastHashMapTerm = ctx.addMutableState(
+ fastHashMapClassName, "vectorizedFastHashMap", forceInline = true)
+ val iter = ctx.addMutableState(
+ "java.util.Iterator<InternalRow>",
+ "vectorizedFastHashMapIter",
+ forceInline = true)
+ val create = s"$fastHashMapTerm = new $fastHashMapClassName();"
+ (iter, create)
+ } else {
+ val generatedMap = new RowBasedHashMapGenerator(ctx,
aggregateExpressions,
+ fastHashMapClassName, groupingKeySchema, bufferSchema,
bitMaxCapacity).generate()
+ ctx.addInnerClass(generatedMap)
+
+ // Inline mutable state since not many aggregation operations in a
task
+ fastHashMapTerm = ctx.addMutableState(
+ fastHashMapClassName, "fastHashMap", forceInline = true)
+ val iter = ctx.addMutableState(
+ "org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>",
+ "fastHashMapIter", forceInline = true)
+ val create = s"$fastHashMapTerm = new $fastHashMapClassName(" +
+ s"$thisPlan.getTaskContext().taskMemoryManager(), " +
+ s"$thisPlan.getEmptyAggregationBuffer());"
+ (iter, create)
+ }
+
+ // Generates the code to register a cleanup task with TaskContext to
ensure that memory
+ // is guaranteed to be freed at the end of the task. This is necessary
to avoid memory
+ // leaks in when the downstream operator does not fully consume the
aggregation map's
+ // output (e.g. aggregate followed by limit).
+ val hookToCloseFastHashMap =
+ s"""
+ |$thisPlan.getTaskContext().addTaskCompletionListener(
+ | new org.apache.spark.util.TaskCompletionListener() {
+ | @Override
+ | public void onTaskCompletion(org.apache.spark.TaskContext
context) {
+ | $fastHashMapTerm.close();
+ | }
+ |});
+ """.stripMargin
+ (iter, create, hookToCloseFastHashMap)
Review comment:
Alternatively, we could put this logic in the base `HashMapGenerator` as a method, and call
that method for both the vectorized and row-based fast hash maps.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]