Github user marmbrus commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5725#discussion_r29303406
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala ---
    @@ -265,7 +283,49 @@ case class GeneratedAggregate(
     
             val resultProjection = resultProjectionBuilder()
             Iterator(resultProjection(buffer))
    +      } else if (unsafeEnabled && schemaSupportsUnsafe) {
    +        log.info("Using Unsafe-based aggregator")
    +        val aggregationMap = new UnsafeFixedWidthAggregationMap(
    +          newAggregationBuffer(EmptyRow),
    +          aggregationBufferSchema,
    +          groupKeySchema,
    +          TaskContext.get.taskMemoryManager(),
    +          1024 * 16, // initial capacity
    +          false // disable tracking of performance metrics
    +        )
    +
    +        while (iter.hasNext) {
    +          val currentRow: Row = iter.next()
    +          val groupKey: Row = groupProjection(currentRow)
    +          val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey)
    +          updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
    +        }
    +
    +        new Iterator[Row] {
    +          private[this] val mapIterator = aggregationMap.iterator()
    +          private[this] val resultProjection = resultProjectionBuilder()
    +
    +          def hasNext: Boolean = mapIterator.hasNext
    +
    +          def next(): Row = {
    +            val entry = mapIterator.next()
    +            val result = resultProjection(joinedRow(entry.key, entry.value))
    +            if (hasNext) {
    +              result
    +            } else {
    +              // This is the last element in the iterator, so let's free the buffer. Before we do,
    +              // though, we need to make a defensive copy of the result so that we don't return an
    +              // object that might contain dangling pointers to the freed memory
    +              val resultCopy = result.copy()
    +              aggregationMap.free()
    --- End diff --
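The "copy, then free on the last element" pattern in the diff can be illustrated in isolation. This is a minimal sketch, not Spark code: `FakeOffHeapMap`, `PointerRow`, `CopiedRow`, and `FreeOnExhaustIterator` are hypothetical stand-ins, and reading from the fake map after `free()` throws to mimic a dangling pointer. It assumes, as in the diff, that callers consume each row before asking for the next one.

```scala
// Hypothetical stand-in for an off-heap map: reads after free() throw,
// mimicking a dangling pointer into released memory.
final class FakeOffHeapMap(values: IndexedSeq[Int]) {
  private var freed = false
  def isFreed: Boolean = freed
  def size: Int = values.size
  def read(i: Int): Int = {
    require(!freed, "read after free")
    values(i)
  }
  def free(): Unit = freed = true
}

// A row either points into the map or owns a copied value.
sealed trait Row {
  def value: Int
  def copy(): Row
}

final class PointerRow(map: FakeOffHeapMap, idx: Int) extends Row {
  def value: Int = map.read(idx)                 // reads through the "pointer"
  def copy(): Row = new CopiedRow(map.read(idx)) // materializes while memory is live
}

final class CopiedRow(val value: Int) extends Row {
  def copy(): Row = new CopiedRow(value)
}

// Frees the map when the last row is produced, returning a defensive copy
// of that row so the caller never sees a pointer into freed memory.
final class FreeOnExhaustIterator(map: FakeOffHeapMap) extends Iterator[Row] {
  private var idx = 0
  def hasNext: Boolean = idx < map.size
  def next(): Row = {
    val result = new PointerRow(map, idx)
    idx += 1
    if (hasNext) result
    else {
      val resultCopy = result.copy() // copy before releasing the memory
      map.free()
      resultCopy
    }
  }
}
```

Consuming the iterator element by element (for example `new FreeOnExhaustIterator(map).map(_.value).toList`) reads every value safely and leaves the map freed; only full exhaustion triggers the `free()` call.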
    
    Checking my understanding here... Are we safe for `take`s that don't exhaust the iterator because of some registration that happens as a result of calling `TaskContext.get.taskMemoryManager()` above?
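To make the concern concrete: if `free()` only runs when the last element is produced, a partial `take` never reaches it, so something at task scope must reclaim the memory. The sketch below is hypothetical and does not describe Spark's actual `TaskMemoryManager`; `TaskAllocations`, `Allocation`, and `FreeOnExhaust` are illustrative names, and `cleanUpAll()` stands in for whatever end-of-task hook might perform the cleanup.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical allocation handle.
final class Allocation(val name: String) {
  private var freed = false
  def isFreed: Boolean = freed
  def free(): Unit = freed = true
}

// Hypothetical per-task registry: every allocation is recorded so that
// anything still live when the task ends can be released.
final class TaskAllocations {
  private val live = ArrayBuffer.empty[Allocation]
  def allocate(name: String): Allocation = {
    val a = new Allocation(name)
    live += a
    a
  }
  // Called by a (hypothetical) task runner on completion; returns the
  // number of allocations that had leaked and were freed here.
  def cleanUpAll(): Int = {
    val leaked = live.count(!_.isFreed)
    live.foreach(a => if (!a.isFreed) a.free())
    leaked
  }
}

// Frees its allocation only after producing its last element.
final class FreeOnExhaust(alloc: Allocation, n: Int) extends Iterator[Int] {
  private var i = 0
  def hasNext: Boolean = i < n
  def next(): Int = {
    val v = i
    i += 1
    if (!hasNext) alloc.free()
    v
  }
}
```

With three elements, `take(2)` stops before the final `next()`, so the allocation stays live until `cleanUpAll()` runs; without a task-scoped registry of this kind, it would leak.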

