You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jo...@apache.org on 2018/01/10 08:46:16 UTC
spark git commit: [SPARK-22997] Add additional defenses against use of freed MemoryBlocks
Repository: spark
Updated Branches:
refs/heads/master 70bcc9d5a -> f340b6b30
[SPARK-22997] Add additional defenses against use of freed MemoryBlocks
## What changes were proposed in this pull request?
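This patch modifies Spark's `MemoryAllocator` implementations so that `free(MemoryBlock)` mutates the passed block to clear pointers (in the off-heap case) or null out references to backing `long[]` arrays (in the on-heap case). It also sets the block's `pageNumber` to special sentinel values when it is freed, so that assertions can catch double-frees and pages that were allocated by `TaskMemoryManager` but freed directly in an allocator. The goal of this change is to add an extra layer of defense against use-after-free bugs, because currently it's hard to detect corruption caused by blind writes to freed memory blocks.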
## How was this patch tested?
New unit tests in `PlatformUtilSuite` and `TaskMemoryManagerSuite`, including new tests for existing functionality, because we did not have sufficient mutation coverage of the on-heap memory allocator's pooling logic.
Author: Josh Rosen <jo...@databricks.com>
Closes #20191 from JoshRosen/SPARK-22997-add-defenses-against-use-after-free-bugs-in-memory-allocator.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f340b6b3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f340b6b3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f340b6b3
Branch: refs/heads/master
Commit: f340b6b3066033d40b7e163fd5fb68e9820adfb1
Parents: 70bcc9d
Author: Josh Rosen <jo...@databricks.com>
Authored: Wed Jan 10 00:45:47 2018 -0800
Committer: Josh Rosen <jo...@databricks.com>
Committed: Wed Jan 10 00:45:47 2018 -0800
----------------------------------------------------------------------
.../unsafe/memory/HeapMemoryAllocator.java | 35 ++++++++++----
.../apache/spark/unsafe/memory/MemoryBlock.java | 21 +++++++-
.../unsafe/memory/UnsafeMemoryAllocator.java | 11 +++++
.../apache/spark/unsafe/PlatformUtilSuite.java | 50 +++++++++++++++++++-
.../apache/spark/memory/TaskMemoryManager.java | 13 ++++-
.../spark/memory/TaskMemoryManagerSuite.java | 29 ++++++++++++
6 files changed, 146 insertions(+), 13 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/f340b6b3/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
index cc9cc42..3acfe36 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
@@ -31,8 +31,7 @@ import org.apache.spark.unsafe.Platform;
public class HeapMemoryAllocator implements MemoryAllocator {
@GuardedBy("this")
- private final Map<Long, LinkedList<WeakReference<MemoryBlock>>> bufferPoolsBySize =
- new HashMap<>();
+ private final Map<Long, LinkedList<WeakReference<long[]>>> bufferPoolsBySize = new HashMap<>();
private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024;
@@ -49,13 +48,14 @@ public class HeapMemoryAllocator implements MemoryAllocator {
public MemoryBlock allocate(long size) throws OutOfMemoryError {
if (shouldPool(size)) {
synchronized (this) {
- final LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size);
+ final LinkedList<WeakReference<long[]>> pool = bufferPoolsBySize.get(size);
if (pool != null) {
while (!pool.isEmpty()) {
- final WeakReference<MemoryBlock> blockReference = pool.pop();
- final MemoryBlock memory = blockReference.get();
- if (memory != null) {
- assert (memory.size() == size);
+ final WeakReference<long[]> arrayReference = pool.pop();
+ final long[] array = arrayReference.get();
+ if (array != null) {
+ assert (array.length * 8L >= size);
+ MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size);
if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) {
memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE);
}
@@ -76,18 +76,35 @@ public class HeapMemoryAllocator implements MemoryAllocator {
@Override
public void free(MemoryBlock memory) {
+ assert (memory.obj != null) :
+ "baseObject was null; are you trying to use the on-heap allocator to free off-heap memory?";
+ assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
+ "page has already been freed";
+ assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER)
+ || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) :
+ "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator free()";
+
final long size = memory.size();
if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) {
memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE);
}
+
+ // Mark the page as freed (so we can detect double-frees).
+ memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER;
+
+ // As an additional layer of defense against use-after-free bugs, we mutate the
+ // MemoryBlock to null out its reference to the long[] array.
+ long[] array = (long[]) memory.obj;
+ memory.setObjAndOffset(null, 0);
+
if (shouldPool(size)) {
synchronized (this) {
- LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size);
+ LinkedList<WeakReference<long[]>> pool = bufferPoolsBySize.get(size);
if (pool == null) {
pool = new LinkedList<>();
bufferPoolsBySize.put(size, pool);
}
- pool.add(new WeakReference<>(memory));
+ pool.add(new WeakReference<>(array));
}
} else {
// Do nothing
http://git-wip-us.apache.org/repos/asf/spark/blob/f340b6b3/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
index cd1d378..c333857 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
@@ -26,6 +26,25 @@ import org.apache.spark.unsafe.Platform;
*/
public class MemoryBlock extends MemoryLocation {
+ /** Special `pageNumber` value for pages which were not allocated by TaskMemoryManagers */
+ public static final int NO_PAGE_NUMBER = -1;
+
+ /**
+ * Special `pageNumber` value for marking pages that have been freed in the TaskMemoryManager.
+ * We set `pageNumber` to this value in TaskMemoryManager.freePage() so that MemoryAllocator
+ * can detect if pages which were allocated by TaskMemoryManager have been freed in the TMM
+ * before being passed to MemoryAllocator.free() (it is an error to allocate a page in
+ * TaskMemoryManager and then directly free it in a MemoryAllocator without going through
+ * the TMM freePage() call).
+ */
+ public static final int FREED_IN_TMM_PAGE_NUMBER = -2;
+
+ /**
+ * Special `pageNumber` value for pages that have been freed by the MemoryAllocator. This allows
+ * us to detect double-frees.
+ */
+ public static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3;
+
private final long length;
/**
@@ -33,7 +52,7 @@ public class MemoryBlock extends MemoryLocation {
* TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager,
* which lives in a different package.
*/
- public int pageNumber = -1;
+ public int pageNumber = NO_PAGE_NUMBER;
public MemoryBlock(@Nullable Object obj, long offset, long length) {
super(obj, offset);
http://git-wip-us.apache.org/repos/asf/spark/blob/f340b6b3/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
index 55bcdf1..4368fb6 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
@@ -38,9 +38,20 @@ public class UnsafeMemoryAllocator implements MemoryAllocator {
public void free(MemoryBlock memory) {
assert (memory.obj == null) :
"baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?";
+ assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
+ "page has already been freed";
+ assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER)
+ || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) :
+ "TMM-allocated pages must be freed via TMM.freePage(), not directly in allocator free()";
+
if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) {
memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE);
}
Platform.freeMemory(memory.offset);
+ // As an additional layer of defense against use-after-free bugs, we mutate the
+ // MemoryBlock to reset its pointer.
+ memory.offset = 0;
+ // Mark the page as freed (so we can detect double-frees).
+ memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER;
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/f340b6b3/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
index 4b14133..6285483 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
@@ -63,6 +63,52 @@ public class PlatformUtilSuite {
}
@Test
+ public void onHeapMemoryAllocatorPoolingReUsesLongArrays() {
+ MemoryBlock block1 = MemoryAllocator.HEAP.allocate(1024 * 1024);
+ Object baseObject1 = block1.getBaseObject();
+ MemoryAllocator.HEAP.free(block1);
+ MemoryBlock block2 = MemoryAllocator.HEAP.allocate(1024 * 1024);
+ Object baseObject2 = block2.getBaseObject();
+ Assert.assertSame(baseObject1, baseObject2);
+ MemoryAllocator.HEAP.free(block2);
+ }
+
+ @Test
+ public void freeingOnHeapMemoryBlockResetsBaseObjectAndOffset() {
+ MemoryBlock block = MemoryAllocator.HEAP.allocate(1024);
+ Assert.assertNotNull(block.getBaseObject());
+ MemoryAllocator.HEAP.free(block);
+ Assert.assertNull(block.getBaseObject());
+ Assert.assertEquals(0, block.getBaseOffset());
+ Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber);
+ }
+
+ @Test
+ public void freeingOffHeapMemoryBlockResetsOffset() {
+ MemoryBlock block = MemoryAllocator.UNSAFE.allocate(1024);
+ Assert.assertNull(block.getBaseObject());
+ Assert.assertNotEquals(0, block.getBaseOffset());
+ MemoryAllocator.UNSAFE.free(block);
+ Assert.assertNull(block.getBaseObject());
+ Assert.assertEquals(0, block.getBaseOffset());
+ Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber);
+ }
+
+ @Test(expected = AssertionError.class)
+ public void onHeapMemoryAllocatorThrowsAssertionErrorOnDoubleFree() {
+ MemoryBlock block = MemoryAllocator.HEAP.allocate(1024);
+ MemoryAllocator.HEAP.free(block);
+ MemoryAllocator.HEAP.free(block);
+ }
+
+ @Test(expected = AssertionError.class)
+ public void offHeapMemoryAllocatorThrowsAssertionErrorOnDoubleFree() {
+ MemoryBlock block = MemoryAllocator.UNSAFE.allocate(1024);
+ MemoryAllocator.UNSAFE.free(block);
+ MemoryAllocator.UNSAFE.free(block);
+ }
+
+ @Test
public void memoryDebugFillEnabledInTest() {
Assert.assertTrue(MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED);
MemoryBlock onheap = MemoryAllocator.HEAP.allocate(1);
@@ -71,9 +117,11 @@ public class PlatformUtilSuite {
MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE);
MemoryBlock onheap1 = MemoryAllocator.HEAP.allocate(1024 * 1024);
+ Object onheap1BaseObject = onheap1.getBaseObject();
+ long onheap1BaseOffset = onheap1.getBaseOffset();
MemoryAllocator.HEAP.free(onheap1);
Assert.assertEquals(
- Platform.getByte(onheap1.getBaseObject(), onheap1.getBaseOffset()),
+ Platform.getByte(onheap1BaseObject, onheap1BaseOffset),
MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE);
MemoryBlock onheap2 = MemoryAllocator.HEAP.allocate(1024 * 1024);
Assert.assertEquals(
http://git-wip-us.apache.org/repos/asf/spark/blob/f340b6b3/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index e8d3730..632d718 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -321,8 +321,12 @@ public class TaskMemoryManager {
* Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}.
*/
public void freePage(MemoryBlock page, MemoryConsumer consumer) {
- assert (page.pageNumber != -1) :
+ assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) :
"Called freePage() on memory that wasn't allocated with allocatePage()";
+ assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
+ "Called freePage() on a memory block that has already been freed";
+ assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) :
+ "Called freePage() on a memory block that has already been freed";
assert(allocatedPages.get(page.pageNumber));
pageTable[page.pageNumber] = null;
synchronized (this) {
@@ -332,6 +336,10 @@ public class TaskMemoryManager {
logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size());
}
long pageSize = page.size();
+ // Clear the page number before passing the block to the MemoryAllocator's free().
+ // Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed
+ // page has been inappropriately directly freed without calling TMM.freePage().
+ page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER;
memoryManager.tungstenMemoryAllocator().free(page);
releaseExecutionMemory(pageSize, consumer);
}
@@ -358,7 +366,7 @@ public class TaskMemoryManager {
@VisibleForTesting
public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) {
- assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
+ assert (pageNumber >= 0) : "encodePageNumberAndOffset called with invalid page";
return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
}
@@ -424,6 +432,7 @@ public class TaskMemoryManager {
for (MemoryBlock page : pageTable) {
if (page != null) {
logger.debug("unreleased page: " + page + " in task " + taskAttemptId);
+ page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER;
memoryManager.tungstenMemoryAllocator().free(page);
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/f340b6b3/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
index 46b0516..a0664b3 100644
--- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
+++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
@@ -21,6 +21,7 @@ import org.junit.Assert;
import org.junit.Test;
import org.apache.spark.SparkConf;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
import org.apache.spark.unsafe.memory.MemoryBlock;
public class TaskMemoryManagerSuite {
@@ -69,6 +70,34 @@ public class TaskMemoryManagerSuite {
}
@Test
+ public void freeingPageSetsPageNumberToSpecialConstant() {
+ final TaskMemoryManager manager = new TaskMemoryManager(
+ new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
+ final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP);
+ final MemoryBlock dataPage = manager.allocatePage(256, c);
+ c.freePage(dataPage);
+ Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.pageNumber);
+ }
+
+ @Test(expected = AssertionError.class)
+ public void freeingPageDirectlyInAllocatorTriggersAssertionError() {
+ final TaskMemoryManager manager = new TaskMemoryManager(
+ new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
+ final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP);
+ final MemoryBlock dataPage = manager.allocatePage(256, c);
+ MemoryAllocator.HEAP.free(dataPage);
+ }
+
+ @Test(expected = AssertionError.class)
+ public void callingFreePageOnDirectlyAllocatedPageTriggersAssertionError() {
+ final TaskMemoryManager manager = new TaskMemoryManager(
+ new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
+ final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP);
+ final MemoryBlock dataPage = MemoryAllocator.HEAP.allocate(256);
+ manager.freePage(dataPage, c);
+ }
+
+ @Test
public void cooperativeSpilling() {
final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf());
memoryManager.limit(100);
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org