You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jo...@apache.org on 2018/01/10 08:46:16 UTC

spark git commit: [SPARK-22997] Add additional defenses against use of freed MemoryBlocks

Repository: spark
Updated Branches:
  refs/heads/master 70bcc9d5a -> f340b6b30


[SPARK-22997] Add additional defenses against use of freed MemoryBlocks

## What changes were proposed in this pull request?

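This patch modifies Spark's `MemoryAllocator` implementations so that `free(MemoryBlock)` mutates the passed block to clear pointers (in the off-heap case) or null out references to backing `long[]` arrays (in the on-heap case). It also marks freed blocks with special `pageNumber` sentinel values so that assertions can catch double-frees and pages that were allocated by `TaskMemoryManager` but freed directly in an allocator without going through `TaskMemoryManager.freePage()`. The goal of this change is to add an extra layer of defense against use-after-free bugs because currently it's hard to detect corruption caused by blind writes to freed memory blocks.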

## How was this patch tested?

New unit tests in `PlatformUtilSuite` and `TaskMemoryManagerSuite`, including new tests for existing functionality because we did not have sufficient mutation coverage of the on-heap memory allocator's pooling logic.

Author: Josh Rosen <jo...@databricks.com>

Closes #20191 from JoshRosen/SPARK-22997-add-defenses-against-use-after-free-bugs-in-memory-allocator.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f340b6b3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f340b6b3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f340b6b3

Branch: refs/heads/master
Commit: f340b6b3066033d40b7e163fd5fb68e9820adfb1
Parents: 70bcc9d
Author: Josh Rosen <jo...@databricks.com>
Authored: Wed Jan 10 00:45:47 2018 -0800
Committer: Josh Rosen <jo...@databricks.com>
Committed: Wed Jan 10 00:45:47 2018 -0800

----------------------------------------------------------------------
 .../unsafe/memory/HeapMemoryAllocator.java      | 35 ++++++++++----
 .../apache/spark/unsafe/memory/MemoryBlock.java | 21 +++++++-
 .../unsafe/memory/UnsafeMemoryAllocator.java    | 11 +++++
 .../apache/spark/unsafe/PlatformUtilSuite.java  | 50 +++++++++++++++++++-
 .../apache/spark/memory/TaskMemoryManager.java  | 13 ++++-
 .../spark/memory/TaskMemoryManagerSuite.java    | 29 ++++++++++++
 6 files changed, 146 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f340b6b3/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
index cc9cc42..3acfe36 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
@@ -31,8 +31,7 @@ import org.apache.spark.unsafe.Platform;
 public class HeapMemoryAllocator implements MemoryAllocator {
 
   @GuardedBy("this")
-  private final Map<Long, LinkedList<WeakReference<MemoryBlock>>> bufferPoolsBySize =
-    new HashMap<>();
+  private final Map<Long, LinkedList<WeakReference<long[]>>> bufferPoolsBySize = new HashMap<>();
 
   private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024;
 
@@ -49,13 +48,14 @@ public class HeapMemoryAllocator implements MemoryAllocator {
   public MemoryBlock allocate(long size) throws OutOfMemoryError {
     if (shouldPool(size)) {
       synchronized (this) {
-        final LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size);
+        final LinkedList<WeakReference<long[]>> pool = bufferPoolsBySize.get(size);
         if (pool != null) {
           while (!pool.isEmpty()) {
-            final WeakReference<MemoryBlock> blockReference = pool.pop();
-            final MemoryBlock memory = blockReference.get();
-            if (memory != null) {
-              assert (memory.size() == size);
+            final WeakReference<long[]> arrayReference = pool.pop();
+            final long[] array = arrayReference.get();
+            if (array != null) {
+              assert (array.length * 8L >= size);
+              MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size);
               if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) {
                 memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE);
               }
@@ -76,18 +76,35 @@ public class HeapMemoryAllocator implements MemoryAllocator {
 
   @Override
   public void free(MemoryBlock memory) {
+    assert (memory.obj != null) :
+      "baseObject was null; are you trying to use the on-heap allocator to free off-heap memory?";
+    assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
+      "page has already been freed";
+    assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER)
+            || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) :
+      "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator free()";
+
     final long size = memory.size();
     if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) {
       memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE);
     }
+
+    // Mark the page as freed (so we can detect double-frees).
+    memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER;
+
+    // As an additional layer of defense against use-after-free bugs, we mutate the
+    // MemoryBlock to null out its reference to the long[] array.
+    long[] array = (long[]) memory.obj;
+    memory.setObjAndOffset(null, 0);
+
     if (shouldPool(size)) {
       synchronized (this) {
-        LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size);
+        LinkedList<WeakReference<long[]>> pool = bufferPoolsBySize.get(size);
         if (pool == null) {
           pool = new LinkedList<>();
           bufferPoolsBySize.put(size, pool);
         }
-        pool.add(new WeakReference<>(memory));
+        pool.add(new WeakReference<>(array));
       }
     } else {
       // Do nothing

http://git-wip-us.apache.org/repos/asf/spark/blob/f340b6b3/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
index cd1d378..c333857 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
@@ -26,6 +26,25 @@ import org.apache.spark.unsafe.Platform;
  */
 public class MemoryBlock extends MemoryLocation {
 
+  /** Special `pageNumber` value for pages which were not allocated by TaskMemoryManagers */
+  public static final int NO_PAGE_NUMBER = -1;
+
+  /**
+   * Special `pageNumber` value for marking pages that have been freed in the TaskMemoryManager.
+   * We set `pageNumber` to this value in TaskMemoryManager.freePage() so that MemoryAllocator
+   * can detect if pages which were allocated by TaskMemoryManager have been freed in the TMM
+   * before being passed to MemoryAllocator.free() (it is an error to allocate a page in
+   * TaskMemoryManager and then directly free it in a MemoryAllocator without going through
+   * the TMM freePage() call).
+   */
+  public static final int FREED_IN_TMM_PAGE_NUMBER = -2;
+
+  /**
+   * Special `pageNumber` value for pages that have been freed by the MemoryAllocator. This allows
+   * us to detect double-frees.
+   */
+  public static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3;
+
   private final long length;
 
   /**
@@ -33,7 +52,7 @@ public class MemoryBlock extends MemoryLocation {
    * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager,
    * which lives in a different package.
    */
-  public int pageNumber = -1;
+  public int pageNumber = NO_PAGE_NUMBER;
 
   public MemoryBlock(@Nullable Object obj, long offset, long length) {
     super(obj, offset);

http://git-wip-us.apache.org/repos/asf/spark/blob/f340b6b3/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
index 55bcdf1..4368fb6 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
@@ -38,9 +38,20 @@ public class UnsafeMemoryAllocator implements MemoryAllocator {
   public void free(MemoryBlock memory) {
     assert (memory.obj == null) :
       "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?";
+    assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
+      "page has already been freed";
+    assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER)
+            || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) :
+      "TMM-allocated pages must be freed via TMM.freePage(), not directly in allocator free()";
+
     if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) {
       memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE);
     }
     Platform.freeMemory(memory.offset);
+    // As an additional layer of defense against use-after-free bugs, we mutate the
+    // MemoryBlock to reset its pointer.
+    memory.offset = 0;
+    // Mark the page as freed (so we can detect double-frees).
+    memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER;
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f340b6b3/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
index 4b14133..6285483 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
@@ -63,6 +63,52 @@ public class PlatformUtilSuite {
   }
 
   @Test
+  public void onHeapMemoryAllocatorPoolingReUsesLongArrays() {
+    MemoryBlock block1 = MemoryAllocator.HEAP.allocate(1024 * 1024);
+    Object baseObject1 = block1.getBaseObject();
+    MemoryAllocator.HEAP.free(block1);
+    MemoryBlock block2 = MemoryAllocator.HEAP.allocate(1024 * 1024);
+    Object baseObject2 = block2.getBaseObject();
+    Assert.assertSame(baseObject1, baseObject2);
+    MemoryAllocator.HEAP.free(block2);
+  }
+
+  @Test
+  public void freeingOnHeapMemoryBlockResetsBaseObjectAndOffset() {
+    MemoryBlock block = MemoryAllocator.HEAP.allocate(1024);
+    Assert.assertNotNull(block.getBaseObject());
+    MemoryAllocator.HEAP.free(block);
+    Assert.assertNull(block.getBaseObject());
+    Assert.assertEquals(0, block.getBaseOffset());
+    Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber);
+  }
+
+  @Test
+  public void freeingOffHeapMemoryBlockResetsOffset() {
+    MemoryBlock block = MemoryAllocator.UNSAFE.allocate(1024);
+    Assert.assertNull(block.getBaseObject());
+    Assert.assertNotEquals(0, block.getBaseOffset());
+    MemoryAllocator.UNSAFE.free(block);
+    Assert.assertNull(block.getBaseObject());
+    Assert.assertEquals(0, block.getBaseOffset());
+    Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber);
+  }
+
+  @Test(expected = AssertionError.class)
+  public void onHeapMemoryAllocatorThrowsAssertionErrorOnDoubleFree() {
+    MemoryBlock block = MemoryAllocator.HEAP.allocate(1024);
+    MemoryAllocator.HEAP.free(block);
+    MemoryAllocator.HEAP.free(block);
+  }
+
+  @Test(expected = AssertionError.class)
+  public void offHeapMemoryAllocatorThrowsAssertionErrorOnDoubleFree() {
+    MemoryBlock block = MemoryAllocator.UNSAFE.allocate(1024);
+    MemoryAllocator.UNSAFE.free(block);
+    MemoryAllocator.UNSAFE.free(block);
+  }
+
+  @Test
   public void memoryDebugFillEnabledInTest() {
     Assert.assertTrue(MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED);
     MemoryBlock onheap = MemoryAllocator.HEAP.allocate(1);
@@ -71,9 +117,11 @@ public class PlatformUtilSuite {
       MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE);
 
     MemoryBlock onheap1 = MemoryAllocator.HEAP.allocate(1024 * 1024);
+    Object onheap1BaseObject = onheap1.getBaseObject();
+    long onheap1BaseOffset = onheap1.getBaseOffset();
     MemoryAllocator.HEAP.free(onheap1);
     Assert.assertEquals(
-      Platform.getByte(onheap1.getBaseObject(), onheap1.getBaseOffset()),
+      Platform.getByte(onheap1BaseObject, onheap1BaseOffset),
       MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE);
     MemoryBlock onheap2 = MemoryAllocator.HEAP.allocate(1024 * 1024);
     Assert.assertEquals(

http://git-wip-us.apache.org/repos/asf/spark/blob/f340b6b3/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index e8d3730..632d718 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -321,8 +321,12 @@ public class TaskMemoryManager {
    * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}.
    */
   public void freePage(MemoryBlock page, MemoryConsumer consumer) {
-    assert (page.pageNumber != -1) :
+    assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) :
       "Called freePage() on memory that wasn't allocated with allocatePage()";
+    assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
+      "Called freePage() on a memory block that has already been freed";
+    assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) :
+            "Called freePage() on a memory block that has already been freed";
     assert(allocatedPages.get(page.pageNumber));
     pageTable[page.pageNumber] = null;
     synchronized (this) {
@@ -332,6 +336,10 @@ public class TaskMemoryManager {
       logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size());
     }
     long pageSize = page.size();
+    // Clear the page number before passing the block to the MemoryAllocator's free().
+    // Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed
+    // page has been inappropriately directly freed without calling TMM.freePage().
+    page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER;
     memoryManager.tungstenMemoryAllocator().free(page);
     releaseExecutionMemory(pageSize, consumer);
   }
@@ -358,7 +366,7 @@ public class TaskMemoryManager {
 
   @VisibleForTesting
   public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) {
-    assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
+    assert (pageNumber >= 0) : "encodePageNumberAndOffset called with invalid page";
     return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
   }
 
@@ -424,6 +432,7 @@ public class TaskMemoryManager {
       for (MemoryBlock page : pageTable) {
         if (page != null) {
           logger.debug("unreleased page: " + page + " in task " + taskAttemptId);
+          page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER;
           memoryManager.tungstenMemoryAllocator().free(page);
         }
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/f340b6b3/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
index 46b0516..a0664b3 100644
--- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
+++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
@@ -21,6 +21,7 @@ import org.junit.Assert;
 import org.junit.Test;
 
 import org.apache.spark.SparkConf;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 
 public class TaskMemoryManagerSuite {
@@ -69,6 +70,34 @@ public class TaskMemoryManagerSuite {
   }
 
   @Test
+  public void freeingPageSetsPageNumberToSpecialConstant() {
+    final TaskMemoryManager manager = new TaskMemoryManager(
+      new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
+    final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP);
+    final MemoryBlock dataPage = manager.allocatePage(256, c);
+    c.freePage(dataPage);
+    Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.pageNumber);
+  }
+
+  @Test(expected = AssertionError.class)
+  public void freeingPageDirectlyInAllocatorTriggersAssertionError() {
+    final TaskMemoryManager manager = new TaskMemoryManager(
+      new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
+    final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP);
+    final MemoryBlock dataPage = manager.allocatePage(256, c);
+    MemoryAllocator.HEAP.free(dataPage);
+  }
+
+  @Test(expected = AssertionError.class)
+  public void callingFreePageOnDirectlyAllocatedPageTriggersAssertionError() {
+    final TaskMemoryManager manager = new TaskMemoryManager(
+      new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
+    final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP);
+    final MemoryBlock dataPage = MemoryAllocator.HEAP.allocate(256);
+    manager.freePage(dataPage, c);
+  }
+
+  @Test
   public void cooperativeSpilling() {
     final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf());
     memoryManager.limit(100);


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org