Repository: spark
Updated Branches:
  refs/heads/branch-2.2 6641aa620 -> 124789b62


[SPARK-25144][SQL][TEST][BRANCH-2.2] Free aggregate map when task ends

## What changes were proposed in this pull request?

[SPARK-25144](https://issues.apache.org/jira/browse/SPARK-25144) reports a managed 
memory leak on Apache Spark 2.0.2 through 2.3.2-RC5.

```scala
scala> case class Foo(bar: Option[String])
scala> val ds = List(Foo(Some("bar"))).toDS
scala> val result = ds.flatMap(_.bar).distinct
scala> result.rdd.isEmpty
18/08/19 23:01:54 WARN Executor: Managed memory leak detected; size = 8650752 bytes, TID = 125
res0: Boolean = false
```
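
The warning appears because `rdd.isEmpty` only needs the first row, so the hash 
aggregate's output iterator is never fully consumed and the code path that frees the 
aggregation map never runs. A rough sketch of that failure mode (the 
`AggOutputIterator` below is a hypothetical simplification, not Spark's actual 
iterator):

```scala
// Hypothetical simplification: a structure whose backing memory is released
// only when the consumer drains the iterator to the end.
class AggOutputIterator[T](rows: Iterator[T], freeMap: () => Unit) extends Iterator[T] {
  override def hasNext: Boolean = {
    val more = rows.hasNext
    if (!more) freeMap() // cleanup only happens once the output is exhausted
    more
  }
  override def next(): T = rows.next()
}

// A consumer such as `limit` or `rdd.isEmpty` stops after the first row,
// so freeMap() is never invoked and the task-managed memory leaks.
```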

This is a backport of cloud-fan's https://github.com/apache/spark/pull/21738, 
which is a single commit among the three commits of SPARK-21743. In addition, I 
added a test case to prevent regressions in branch-2.3 and branch-2.2. Although 
SPARK-21743 was reverted due to a regression, this subpatch can go to branch-2.3 
and branch-2.2. This will be merged as cloud-fan's commit.
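
The core of the change is to hand `UnsafeFixedWidthAggregationMap` the whole 
`TaskContext` instead of just its `TaskMemoryManager`, so the map can register a 
task-completion listener that frees its `BytesToBytesMap` when the task ends, even 
if a downstream operator stops consuming early. A minimal Scala sketch of that 
pattern (the `LeakFreeMap` class is hypothetical, not the real implementation):

```scala
import org.apache.spark.TaskContext

// Hypothetical stand-in for a task-scoped, manually managed structure such as
// the BytesToBytesMap inside UnsafeFixedWidthAggregationMap.
class LeakFreeMap(taskContext: TaskContext) {
  private var freed = false

  // Register cleanup at construction time: free() now runs when the task
  // completes, not only when the consumer drains the map's output.
  taskContext.addTaskCompletionListener { _ =>
    free()
  }

  def free(): Unit = {
    if (!freed) {
      freed = true
      // release the unsafe/off-heap pages here
    }
  }
}
```

In the patch itself, the listener is registered in `UnsafeFixedWidthAggregationMap`'s 
constructor, and `HashAggregateExec` and `TungstenAggregationIterator` now pass 
`TaskContext.get()` instead of `TaskContext.get().taskMemoryManager()`.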

## How was this patch tested?

Pass the Jenkins tests with a newly added test case.

Closes #22156 from dongjoon-hyun/SPARK-25144-2.2.

Authored-by: Wenchen Fan <wenc...@databricks.com>
Signed-off-by: hyukjinkwon <gurwls...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/124789b6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/124789b6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/124789b6

Branch: refs/heads/branch-2.2
Commit: 124789b62583c6e5c7d427207394c572b6911579
Parents: 6641aa6
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Tue Aug 21 09:07:27 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Tue Aug 21 09:07:27 2018 +0800

----------------------------------------------------------------------
 .../UnsafeFixedWidthAggregationMap.java         | 17 +++++++++++-----
 .../execution/aggregate/HashAggregateExec.scala |  2 +-
 .../aggregate/TungstenAggregationIterator.scala |  2 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala    |  8 ++++++++
 .../UnsafeFixedWidthAggregationMapSuite.scala   | 21 ++++++++++++--------
 5 files changed, 35 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/124789b6/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index cd521c5..4299cc8 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution;
 import java.io.IOException;
 
 import org.apache.spark.SparkEnv;
-import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.TaskContext;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -84,7 +84,7 @@ public final class UnsafeFixedWidthAggregationMap {
    * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
    * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
    * @param groupingKeySchema the schema of the grouping key, used for row conversion.
-   * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures.
+   * @param taskContext the current task context.
    * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
    * @param pageSizeBytes the data page size, in bytes; limits the maximum record size.
    * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
@@ -93,7 +93,7 @@ public final class UnsafeFixedWidthAggregationMap {
       InternalRow emptyAggregationBuffer,
       StructType aggregationBufferSchema,
       StructType groupingKeySchema,
-      TaskMemoryManager taskMemoryManager,
+      TaskContext taskContext,
       int initialCapacity,
       long pageSizeBytes,
       boolean enablePerfMetrics) {
@@ -101,13 +101,20 @@ public final class UnsafeFixedWidthAggregationMap {
     this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length());
     this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
     this.groupingKeySchema = groupingKeySchema;
-    this.map =
-      new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics);
+    this.map = new BytesToBytesMap(
+      taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes, enablePerfMetrics);
     this.enablePerfMetrics = enablePerfMetrics;
 
     // Initialize the buffer for aggregation value
     final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema);
     this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
+
+    // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
+    // the end of the task. This is necessary to avoid memory leaks when the downstream operator
+    // does not fully consume the aggregation map's output (e.g. aggregate followed by limit).
+    taskContext.addTaskCompletionListener(context -> {
+      free();
+    });
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/124789b6/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 68c8e6c..8e0e27f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -310,7 +310,7 @@ case class HashAggregateExec(
       initialBuffer,
       bufferSchema,
       groupingKeySchema,
-      TaskContext.get().taskMemoryManager(),
+      TaskContext.get(),
       1024 * 16, // initial capacity
       TaskContext.get().taskMemoryManager().pageSizeBytes,
       false // disable tracking of performance metrics

http://git-wip-us.apache.org/repos/asf/spark/blob/124789b6/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 2988161..670c33d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -160,7 +160,7 @@ class TungstenAggregationIterator(
     initialAggregationBuffer,
     StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)),
     StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
-    TaskContext.get().taskMemoryManager(),
+    TaskContext.get(),
     1024 * 16, // initial capacity
     TaskContext.get().taskMemoryManager().pageSizeBytes,
     false // disable tracking of performance metrics

http://git-wip-us.apache.org/repos/asf/spark/blob/124789b6/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 95d8c86..d2b17a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2702,4 +2702,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-25144 'distinct' causes memory leak") {
+    val ds = List(Foo(Some("bar"))).toDS
+    val result = ds.flatMap(_.bar).distinct
+    result.rdd.isEmpty
+  }
 }
+
+case class Foo(bar: Option[String])

http://git-wip-us.apache.org/repos/asf/spark/blob/124789b6/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 6c222a0..3b77657 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -23,6 +23,7 @@ import scala.collection.mutable
 import scala.util.{Random, Try}
 import scala.util.control.NonFatal
 
+import org.mockito.Mockito._
 import org.scalatest.Matchers
 
 import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl}
@@ -53,6 +54,8 @@ class UnsafeFixedWidthAggregationMapSuite
   private var memoryManager: TestMemoryManager = null
   private var taskMemoryManager: TaskMemoryManager = null
 
+  private var taskContext: TaskContext = null
+
   def testWithMemoryLeakDetection(name: String)(f: => Unit) {
     def cleanup(): Unit = {
       if (taskMemoryManager != null) {
@@ -66,6 +69,8 @@ class UnsafeFixedWidthAggregationMapSuite
       val conf = new SparkConf().set("spark.memory.offHeap.enabled", "false")
       memoryManager = new TestMemoryManager(conf)
       taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
+      taskContext = mock(classOf[TaskContext])
+      when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)
 
       TaskContext.setTaskContext(new TaskContextImpl(
         stageId = 0,
@@ -110,7 +115,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       1024, // initial capacity,
       PAGE_SIZE_BYTES,
       false // disable perf metrics
@@ -124,7 +129,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       1024, // initial capacity
       PAGE_SIZE_BYTES,
       false // disable perf metrics
@@ -151,7 +156,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       128, // initial capacity
       PAGE_SIZE_BYTES,
       false // disable perf metrics
@@ -177,7 +182,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       128, // initial capacity
       PAGE_SIZE_BYTES,
       false // disable perf metrics
@@ -225,7 +230,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       128, // initial capacity
       PAGE_SIZE_BYTES,
       false // disable perf metrics
@@ -266,7 +271,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       StructType(Nil),
       StructType(Nil),
-      taskMemoryManager,
+      taskContext,
       128, // initial capacity
       PAGE_SIZE_BYTES,
       false // disable perf metrics
@@ -311,7 +316,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       128, // initial capacity
       pageSize,
       false // disable perf metrics
@@ -349,7 +354,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       128, // initial capacity
       pageSize,
       false // disable perf metrics


