Repository: spark
Updated Branches:
  refs/heads/master 70bcc9d5a -> f340b6b30


[SPARK-22997] Add additional defenses against use of freed MemoryBlocks

## What changes were proposed in this pull request?

This patch modifies Spark's `MemoryAllocator` implementations so that 
`free(MemoryBlock)` mutates the passed block to clear pointers (in the off-heap 
case) or null out references to backing `long[]` arrays (in the on-heap case). 
The goal of this change is to add an extra layer of defense against 
use-after-free bugs because currently it's hard to detect corruption caused by 
blind writes to freed memory blocks.

## How was this patch tested?

New unit tests in `PlatformUtilSuite` and `TaskMemoryManagerSuite`, including 
new tests for existing functionality because we did not have sufficient 
mutation coverage of the on-heap memory allocator's pooling logic.

Author: Josh Rosen <[email protected]>

Closes #20191 from 
JoshRosen/SPARK-22997-add-defenses-against-use-after-free-bugs-in-memory-allocator.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f340b6b3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f340b6b3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f340b6b3

Branch: refs/heads/master
Commit: f340b6b3066033d40b7e163fd5fb68e9820adfb1
Parents: 70bcc9d
Author: Josh Rosen <[email protected]>
Authored: Wed Jan 10 00:45:47 2018 -0800
Committer: Josh Rosen <[email protected]>
Committed: Wed Jan 10 00:45:47 2018 -0800

----------------------------------------------------------------------
 .../unsafe/memory/HeapMemoryAllocator.java      | 35 ++++++++++----
 .../apache/spark/unsafe/memory/MemoryBlock.java | 21 +++++++-
 .../unsafe/memory/UnsafeMemoryAllocator.java    | 11 +++++
 .../apache/spark/unsafe/PlatformUtilSuite.java  | 50 +++++++++++++++++++-
 .../apache/spark/memory/TaskMemoryManager.java  | 13 ++++-
 .../spark/memory/TaskMemoryManagerSuite.java    | 29 ++++++++++++
 6 files changed, 146 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f340b6b3/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
----------------------------------------------------------------------
diff --git 
a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
 
b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
index cc9cc42..3acfe36 100644
--- 
a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
+++ 
b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
@@ -31,8 +31,7 @@ import org.apache.spark.unsafe.Platform;
 public class HeapMemoryAllocator implements MemoryAllocator {
 
   @GuardedBy("this")
-  private final Map<Long, LinkedList<WeakReference<MemoryBlock>>> 
bufferPoolsBySize =
-    new HashMap<>();
+  private final Map<Long, LinkedList<WeakReference<long[]>>> bufferPoolsBySize 
= new HashMap<>();
 
   private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024;
 
@@ -49,13 +48,14 @@ public class HeapMemoryAllocator implements MemoryAllocator 
{
   public MemoryBlock allocate(long size) throws OutOfMemoryError {
     if (shouldPool(size)) {
       synchronized (this) {
-        final LinkedList<WeakReference<MemoryBlock>> pool = 
bufferPoolsBySize.get(size);
+        final LinkedList<WeakReference<long[]>> pool = 
bufferPoolsBySize.get(size);
         if (pool != null) {
           while (!pool.isEmpty()) {
-            final WeakReference<MemoryBlock> blockReference = pool.pop();
-            final MemoryBlock memory = blockReference.get();
-            if (memory != null) {
-              assert (memory.size() == size);
+            final WeakReference<long[]> arrayReference = pool.pop();
+            final long[] array = arrayReference.get();
+            if (array != null) {
+              assert (array.length * 8L >= size);
+              MemoryBlock memory = new MemoryBlock(array, 
Platform.LONG_ARRAY_OFFSET, size);
               if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) {
                 memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE);
               }
@@ -76,18 +76,35 @@ public class HeapMemoryAllocator implements MemoryAllocator 
{
 
   @Override
   public void free(MemoryBlock memory) {
+    assert (memory.obj != null) :
+      "baseObject was null; are you trying to use the on-heap allocator to 
free off-heap memory?";
+    assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
+      "page has already been freed";
+    assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER)
+            || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) :
+      "TMM-allocated pages must first be freed via TMM.freePage(), not 
directly in allocator free()";
+
     final long size = memory.size();
     if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) {
       memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE);
     }
+
+    // Mark the page as freed (so we can detect double-frees).
+    memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER;
+
+    // As an additional layer of defense against use-after-free bugs, we 
mutate the
+    // MemoryBlock to null out its reference to the long[] array.
+    long[] array = (long[]) memory.obj;
+    memory.setObjAndOffset(null, 0);
+
     if (shouldPool(size)) {
       synchronized (this) {
-        LinkedList<WeakReference<MemoryBlock>> pool = 
bufferPoolsBySize.get(size);
+        LinkedList<WeakReference<long[]>> pool = bufferPoolsBySize.get(size);
         if (pool == null) {
           pool = new LinkedList<>();
           bufferPoolsBySize.put(size, pool);
         }
-        pool.add(new WeakReference<>(memory));
+        pool.add(new WeakReference<>(array));
       }
     } else {
       // Do nothing

http://git-wip-us.apache.org/repos/asf/spark/blob/f340b6b3/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
----------------------------------------------------------------------
diff --git 
a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java 
b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
index cd1d378..c333857 100644
--- 
a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
+++ 
b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
@@ -26,6 +26,25 @@ import org.apache.spark.unsafe.Platform;
  */
 public class MemoryBlock extends MemoryLocation {
 
+  /** Special `pageNumber` value for pages which were not allocated by 
TaskMemoryManagers */
+  public static final int NO_PAGE_NUMBER = -1;
+
+  /**
+   * Special `pageNumber` value for marking pages that have been freed in the 
TaskMemoryManager.
+   * We set `pageNumber` to this value in TaskMemoryManager.freePage() so that 
MemoryAllocator
+   * can detect if pages which were allocated by TaskMemoryManager have been 
freed in the TMM
+   * before being passed to MemoryAllocator.free() (it is an error to allocate 
a page in
+   * TaskMemoryManager and then directly free it in a MemoryAllocator without 
going through
+   * the TMM freePage() call).
+   */
+  public static final int FREED_IN_TMM_PAGE_NUMBER = -2;
+
+  /**
+   * Special `pageNumber` value for pages that have been freed by the 
MemoryAllocator. This allows
+   * us to detect double-frees.
+   */
+  public static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3;
+
   private final long length;
 
   /**
@@ -33,7 +52,7 @@ public class MemoryBlock extends MemoryLocation {
    * TaskMemoryManager. This field is public so that it can be modified by the 
TaskMemoryManager,
    * which lives in a different package.
    */
-  public int pageNumber = -1;
+  public int pageNumber = NO_PAGE_NUMBER;
 
   public MemoryBlock(@Nullable Object obj, long offset, long length) {
     super(obj, offset);

http://git-wip-us.apache.org/repos/asf/spark/blob/f340b6b3/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
----------------------------------------------------------------------
diff --git 
a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
 
b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
index 55bcdf1..4368fb6 100644
--- 
a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
+++ 
b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
@@ -38,9 +38,20 @@ public class UnsafeMemoryAllocator implements 
MemoryAllocator {
   public void free(MemoryBlock memory) {
     assert (memory.obj == null) :
       "baseObject not null; are you trying to use the off-heap allocator to 
free on-heap memory?";
+    assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
+      "page has already been freed";
+    assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER)
+            || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) :
+      "TMM-allocated pages must be freed via TMM.freePage(), not directly in 
allocator free()";
+
     if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) {
       memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE);
     }
     Platform.freeMemory(memory.offset);
+    // As an additional layer of defense against use-after-free bugs, we 
mutate the
+    // MemoryBlock to reset its pointer.
+    memory.offset = 0;
+    // Mark the page as freed (so we can detect double-frees).
+    memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER;
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f340b6b3/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
----------------------------------------------------------------------
diff --git 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
index 4b14133..6285483 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
@@ -63,6 +63,52 @@ public class PlatformUtilSuite {
   }
 
   @Test
+  public void onHeapMemoryAllocatorPoolingReUsesLongArrays() {
+    MemoryBlock block1 = MemoryAllocator.HEAP.allocate(1024 * 1024);
+    Object baseObject1 = block1.getBaseObject();
+    MemoryAllocator.HEAP.free(block1);
+    MemoryBlock block2 = MemoryAllocator.HEAP.allocate(1024 * 1024);
+    Object baseObject2 = block2.getBaseObject();
+    Assert.assertSame(baseObject1, baseObject2);
+    MemoryAllocator.HEAP.free(block2);
+  }
+
+  @Test
+  public void freeingOnHeapMemoryBlockResetsBaseObjectAndOffset() {
+    MemoryBlock block = MemoryAllocator.HEAP.allocate(1024);
+    Assert.assertNotNull(block.getBaseObject());
+    MemoryAllocator.HEAP.free(block);
+    Assert.assertNull(block.getBaseObject());
+    Assert.assertEquals(0, block.getBaseOffset());
+    Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, 
block.pageNumber);
+  }
+
+  @Test
+  public void freeingOffHeapMemoryBlockResetsOffset() {
+    MemoryBlock block = MemoryAllocator.UNSAFE.allocate(1024);
+    Assert.assertNull(block.getBaseObject());
+    Assert.assertNotEquals(0, block.getBaseOffset());
+    MemoryAllocator.UNSAFE.free(block);
+    Assert.assertNull(block.getBaseObject());
+    Assert.assertEquals(0, block.getBaseOffset());
+    Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, 
block.pageNumber);
+  }
+
+  @Test(expected = AssertionError.class)
+  public void onHeapMemoryAllocatorThrowsAssertionErrorOnDoubleFree() {
+    MemoryBlock block = MemoryAllocator.HEAP.allocate(1024);
+    MemoryAllocator.HEAP.free(block);
+    MemoryAllocator.HEAP.free(block);
+  }
+
+  @Test(expected = AssertionError.class)
+  public void offHeapMemoryAllocatorThrowsAssertionErrorOnDoubleFree() {
+    MemoryBlock block = MemoryAllocator.UNSAFE.allocate(1024);
+    MemoryAllocator.UNSAFE.free(block);
+    MemoryAllocator.UNSAFE.free(block);
+  }
+
+  @Test
   public void memoryDebugFillEnabledInTest() {
     Assert.assertTrue(MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED);
     MemoryBlock onheap = MemoryAllocator.HEAP.allocate(1);
@@ -71,9 +117,11 @@ public class PlatformUtilSuite {
       MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE);
 
     MemoryBlock onheap1 = MemoryAllocator.HEAP.allocate(1024 * 1024);
+    Object onheap1BaseObject = onheap1.getBaseObject();
+    long onheap1BaseOffset = onheap1.getBaseOffset();
     MemoryAllocator.HEAP.free(onheap1);
     Assert.assertEquals(
-      Platform.getByte(onheap1.getBaseObject(), onheap1.getBaseOffset()),
+      Platform.getByte(onheap1BaseObject, onheap1BaseOffset),
       MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE);
     MemoryBlock onheap2 = MemoryAllocator.HEAP.allocate(1024 * 1024);
     Assert.assertEquals(

http://git-wip-us.apache.org/repos/asf/spark/blob/f340b6b3/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java 
b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index e8d3730..632d718 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -321,8 +321,12 @@ public class TaskMemoryManager {
    * Free a block of memory allocated via {@link 
TaskMemoryManager#allocatePage}.
    */
   public void freePage(MemoryBlock page, MemoryConsumer consumer) {
-    assert (page.pageNumber != -1) :
+    assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) :
       "Called freePage() on memory that wasn't allocated with allocatePage()";
+    assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
+      "Called freePage() on a memory block that has already been freed";
+    assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) :
+            "Called freePage() on a memory block that has already been freed";
     assert(allocatedPages.get(page.pageNumber));
     pageTable[page.pageNumber] = null;
     synchronized (this) {
@@ -332,6 +336,10 @@ public class TaskMemoryManager {
       logger.trace("Freed page number {} ({} bytes)", page.pageNumber, 
page.size());
     }
     long pageSize = page.size();
+    // Clear the page number before passing the block to the MemoryAllocator's 
free().
+    // Doing this allows the MemoryAllocator to detect when a 
TaskMemoryManager-managed
+    // page has been inappropriately directly freed without calling 
TMM.freePage().
+    page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER;
     memoryManager.tungstenMemoryAllocator().free(page);
     releaseExecutionMemory(pageSize, consumer);
   }
@@ -358,7 +366,7 @@ public class TaskMemoryManager {
 
   @VisibleForTesting
   public static long encodePageNumberAndOffset(int pageNumber, long 
offsetInPage) {
-    assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid 
page";
+    assert (pageNumber >= 0) : "encodePageNumberAndOffset called with invalid 
page";
     return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & 
MASK_LONG_LOWER_51_BITS);
   }
 
@@ -424,6 +432,7 @@ public class TaskMemoryManager {
       for (MemoryBlock page : pageTable) {
         if (page != null) {
           logger.debug("unreleased page: " + page + " in task " + 
taskAttemptId);
+          page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER;
           memoryManager.tungstenMemoryAllocator().free(page);
         }
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/f340b6b3/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
----------------------------------------------------------------------
diff --git 
a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java 
b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
index 46b0516..a0664b3 100644
--- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
+++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
@@ -21,6 +21,7 @@ import org.junit.Assert;
 import org.junit.Test;
 
 import org.apache.spark.SparkConf;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 
 public class TaskMemoryManagerSuite {
@@ -69,6 +70,34 @@ public class TaskMemoryManagerSuite {
   }
 
   @Test
+  public void freeingPageSetsPageNumberToSpecialConstant() {
+    final TaskMemoryManager manager = new TaskMemoryManager(
+      new TestMemoryManager(new 
SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
+    final MemoryConsumer c = new TestMemoryConsumer(manager, 
MemoryMode.ON_HEAP);
+    final MemoryBlock dataPage = manager.allocatePage(256, c);
+    c.freePage(dataPage);
+    Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, 
dataPage.pageNumber);
+  }
+
+  @Test(expected = AssertionError.class)
+  public void freeingPageDirectlyInAllocatorTriggersAssertionError() {
+    final TaskMemoryManager manager = new TaskMemoryManager(
+      new TestMemoryManager(new 
SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
+    final MemoryConsumer c = new TestMemoryConsumer(manager, 
MemoryMode.ON_HEAP);
+    final MemoryBlock dataPage = manager.allocatePage(256, c);
+    MemoryAllocator.HEAP.free(dataPage);
+  }
+
+  @Test(expected = AssertionError.class)
+  public void callingFreePageOnDirectlyAllocatedPageTriggersAssertionError() {
+    final TaskMemoryManager manager = new TaskMemoryManager(
+      new TestMemoryManager(new 
SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
+    final MemoryConsumer c = new TestMemoryConsumer(manager, 
MemoryMode.ON_HEAP);
+    final MemoryBlock dataPage = MemoryAllocator.HEAP.allocate(256);
+    manager.freePage(dataPage, c);
+  }
+
+  @Test
   public void cooperativeSpilling() {
     final TestMemoryManager memoryManager = new TestMemoryManager(new 
SparkConf());
     memoryManager.limit(100);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to