Repository: spark Updated Branches: refs/heads/master 481f07929 -> 86664338f
[SPARK-17528][SQL][FOLLOWUP] remove unnecessary data copy in object hash aggregate ## What changes were proposed in this pull request? In #18483, we fixed the data copy bug when saving into `InternalRow`, and removed all workarounds for this bug in the aggregate code path. However, the object hash aggregate was missed; this PR fixes it. This patch is also a prerequisite for #17419, which shows that the DataFrame version is slower than the RDD version because of this issue. ## How was this patch tested? Existing tests. Author: Wenchen Fan <wenc...@databricks.com> Closes #18712 from cloud-fan/minor. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/86664338 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/86664338 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/86664338 Branch: refs/heads/master Commit: 86664338f25f58b2f59db93b68cd57de671a4c0b Parents: 481f079 Author: Wenchen Fan <wenc...@databricks.com> Authored: Mon Jul 24 10:18:28 2017 -0700 Committer: Cheng Lian <l...@databricks.com> Committed: Mon Jul 24 10:18:28 2017 -0700 ---------------------------------------------------------------------- .../aggregate/ObjectAggregationIterator.scala | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/86664338/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index 6e47f9d..eef2c4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -70,10 +70,6 @@ class ObjectAggregationIterator( generateProcessRow(newExpressions, newFunctions, newInputAttributes) } - // A safe projection used to do deep clone of input rows to prevent false sharing. - private[this] val safeProjection: Projection = - FromUnsafeProjection(outputAttributes.map(_.dataType)) - /** * Start processing input rows. */ @@ -151,12 +147,11 @@ class ObjectAggregationIterator( val groupingKey = groupingProjection.apply(null) val buffer: InternalRow = getAggregationBufferByKey(hashMap, groupingKey) while (inputRows.hasNext) { - val newInput = safeProjection(inputRows.next()) - processRow(buffer, newInput) + processRow(buffer, inputRows.next()) } } else { while (inputRows.hasNext && !sortBased) { - val newInput = safeProjection(inputRows.next()) + val newInput = inputRows.next() val groupingKey = groupingProjection.apply(newInput) val buffer: InternalRow = getAggregationBufferByKey(hashMap, groupingKey) processRow(buffer, newInput) @@ -266,9 +261,7 @@ class SortBasedAggregator( // Firstly, update the aggregation buffer with input rows. while (hasNextInput && groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0) { - // Since `inputIterator.getValue` is an `UnsafeRow` whose underlying buffer will be - // overwritten when `inputIterator` steps forward, we need to do a deep copy here. - processRow(result.aggregationBuffer, inputIterator.getValue.copy()) + processRow(result.aggregationBuffer, inputIterator.getValue) hasNextInput = inputIterator.next() } @@ -277,12 +270,7 @@ class SortBasedAggregator( // be called after calling processRow. 
while (hasNextAggBuffer && groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0) { - mergeAggregationBuffers( - result.aggregationBuffer, - // Since `inputIterator.getValue` is an `UnsafeRow` whose underlying buffer will be - // overwritten when `inputIterator` steps forward, we need to do a deep copy here. - initialAggBufferIterator.getValue.copy() - ) + mergeAggregationBuffers(result.aggregationBuffer, initialAggBufferIterator.getValue) hasNextAggBuffer = initialAggBufferIterator.next() } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org