Posted to commits@spark.apache.org by da...@apache.org on 2015/10/30 07:39:11 UTC

[3/3] spark git commit: [SPARK-10342] [SPARK-10309] [SPARK-10474] [SPARK-10929] [SQL] Cooperative memory management

[SPARK-10342] [SPARK-10309] [SPARK-10474] [SPARK-10929] [SQL] Cooperative memory management

This PR introduces a mechanism to call spill() on SQL operators that support spilling (for example, BytesToBytesMap, UnsafeExternalSorter and ShuffleExternalSorter) when there is not enough memory for execution. The preserved first page is not needed anymore, so it is removed.
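
To illustrate the mechanism, here is a minimal Java sketch of cooperative memory management. It is a toy model under stated assumptions, not the code in this PR: `ToyPool` and `ToyConsumer` are hypothetical names, memory is modeled as a plain byte counter, and "spilling" only hands bytes back to the pool. It shows the order of operations: grant what is free, ask other consumers to spill, and only then ask the requester to spill itself.

```java
import java.util.LinkedHashSet;
import java.util.Set;

// Hypothetical toy model (not Spark's classes): a fixed-size pool shared by
// all consumers of one task.
class ToyPool {
  private long free;
  private final Set<ToyConsumer> consumers = new LinkedHashSet<>();

  ToyPool(long capacity) { this.free = capacity; }

  synchronized long free() { return free; }

  // Grant up to `required` bytes. If the pool is short, ask the other
  // consumers to spill first, then the requester itself.
  synchronized long acquire(long required, ToyConsumer requester) {
    long got = take(required);
    if (got < required) {
      for (ToyConsumer c : consumers) {
        if (c != requester && c.used() > 0) {
          c.spill(required - got);
          got += take(required - got);
          if (got >= required) {
            break;
          }
        }
      }
    }
    if (got < required) {
      requester.spill(required - got);
      got += take(required - got);
    }
    consumers.add(requester);
    return got;
  }

  synchronized void release(long size) { free += size; }

  private long take(long wanted) {
    long granted = Math.min(wanted, free);
    free -= granted;
    return granted;
  }
}

// Hypothetical consumer: tracks how much it holds and how much it has spilled.
class ToyConsumer {
  private final ToyPool pool;
  private long used = 0;
  private long spilled = 0;

  ToyConsumer(ToyPool pool) { this.pool = pool; }

  long used() { return used; }
  long spilled() { return spilled; }

  // Returns true only if the full request was granted.
  boolean reserve(long size) {
    long got = pool.acquire(size, this);
    if (got < size) {
      pool.release(got);
      return false;
    }
    used += got;
    return true;
  }

  // "Spill": pretend to write up to `size` bytes to disk and return them to the pool.
  // Note that it never acquires memory itself.
  long spill(long size) {
    long released = Math.min(size, used);
    used -= released;
    spilled += released;
    pool.release(released);
    return released;
  }
}
```

With a 100-byte pool, a consumer holding 80 bytes is asked to spill 30 when a second consumer requests 50, so both end up holding 50 bytes instead of the second request failing.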

Other Spillable objects in Spark core (ExternalSorter and AppendOnlyMap) are not included in this PR, but they could benefit from this mechanism (by triggering spilling in other consumers).

PrepareRDD may not be needed anymore and could be removed in a follow-up PR.

The following script fails with OOM before this PR; with this PR it finishes in 150 seconds with a 2G heap (it also works in the 1.5 branch, with similar duration).

```python
sqlContext.setConf("spark.sql.shuffle.partitions", "1")
df = sqlContext.range(1<<25).selectExpr("id", "repeat(id, 2) as s")
df2 = df.select(df.id.alias('id2'), df.s.alias('s2'))
j = df.join(df2, df.id==df2.id2).groupBy(df.id).max("id", "id2")
j.explain()
print(j.count())
```

For thread safety, here is what I've got:

1) Without calling spill(), the operators should only be used by a single thread, so there are no safety problems.

2) spill() could be triggered in two cases: by the operator itself, or by other operators. We can check trigger == this in spill(); if that holds, it's still in the same thread, so there are no safety problems.

3) If it's triggered by other operators (right now cache will not trigger spill()), we only spill the data to disk when it's in the scanning stage (building is finished), so the in-memory sorter or memory pages are read-only; we only need to synchronize the iterator and change it.

4) During scanning, the iterator only uses one record in one page at a time. We can't free this page, because the downstream is currently using it (through UnsafeRow or other objects). In BytesToBytesMap, we just skip the current page and dump all others to disk. In UnsafeExternalSorter, we keep the page that is used by the current record (having the same baseObject) and free it when loading the next record. In ShuffleExternalSorter, spill() will not be triggered during scanning.

5) In order to avoid deadlock, we don't call acquireMemory during spill (so we reuse the pointer array in InMemorySorter).
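
The rules above can be sketched in Java as follows. This is a simplified illustration under stated assumptions, not any of the operators in this PR: `ToyScanningConsumer` is a hypothetical class, pages are modeled as byte counts in a deque whose head stands for the page currently being read, and spilling just drops pages instead of writing them to disk.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Hypothetical toy consumer (not Spark's API). Pages are modeled as sizes in
// bytes; the head of the deque is the page holding the current record.
class ToyScanningConsumer {
  private final Deque<Long> pages = new ArrayDeque<>();
  private boolean scanning = false;

  void addPage(long size) { pages.addLast(size); }
  void startScan() { scanning = true; }
  int pagesInMemory() { return pages.size(); }

  // Rules 2-4: a foreign trigger may only spill while scanning, and the page
  // in use by the current record is never freed during a scan.
  // Rule 5: this method never acquires memory.
  synchronized long spill(long size, ToyScanningConsumer trigger) {
    if (trigger != this && !scanning) {
      // Still building: only the owning thread may touch these structures.
      return 0L;
    }
    // While scanning, keep the head page because downstream may still
    // reference the current record stored in it.
    int mustKeep = scanning ? 1 : 0;
    long released = 0L;
    while (pages.size() > mustKeep && released < size) {
      released += pages.removeLast();
    }
    return released;
  }
}
```

A spill requested by another consumer returns 0 while this consumer is still building; once scanning has started, the same request frees every page except the one in use.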

Author: Davies Liu <da...@databricks.com>

Closes #9241 from davies/force_spill.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/56419cf1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/56419cf1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/56419cf1

Branch: refs/heads/master
Commit: 56419cf11f769c80f391b45dc41b3c7101cc5ff4
Parents: d89be0b
Author: Davies Liu <da...@databricks.com>
Authored: Thu Oct 29 23:38:06 2015 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Thu Oct 29 23:38:06 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/memory/MemoryConsumer.java | 128 ++++++
 .../apache/spark/memory/TaskMemoryManager.java  | 138 ++++--
 .../shuffle/sort/ShuffleExternalSorter.java     | 210 +++------
 .../shuffle/sort/ShuffleInMemorySorter.java     |  50 ++-
 .../spark/shuffle/sort/UnsafeShuffleWriter.java |   6 -
 .../spark/unsafe/map/BytesToBytesMap.java       | 430 +++++++++++--------
 .../unsafe/sort/UnsafeExternalSorter.java       | 426 +++++++++---------
 .../unsafe/sort/UnsafeInMemorySorter.java       |  60 ++-
 .../unsafe/sort/UnsafeSorterSpillReader.java    |   6 +-
 .../unsafe/sort/UnsafeSorterSpillWriter.java    |   2 +-
 .../org/apache/spark/memory/MemoryManager.scala |   9 +-
 .../spark/util/collection/Spillable.scala       |   4 +-
 .../spark/memory/TaskMemoryManagerSuite.java    |  77 +++-
 .../shuffle/sort/PackedRecordPointerSuite.java  |  30 +-
 .../sort/ShuffleInMemorySorterSuite.java        |   6 +-
 .../shuffle/sort/UnsafeShuffleWriterSuite.java  |  38 +-
 .../map/AbstractBytesToBytesMapSuite.java       | 149 ++++++-
 .../unsafe/sort/UnsafeExternalSorterSuite.java  |  97 +++--
 .../unsafe/sort/UnsafeInMemorySorterSuite.java  |  20 +-
 .../scala/org/apache/spark/FailureSuite.scala   |   4 +-
 .../memory/GrantEverythingMemoryManager.scala   |  54 ---
 .../spark/memory/MemoryManagerSuite.scala       |  60 +--
 .../apache/spark/memory/TestMemoryManager.scala |  70 +++
 .../sql/execution/UnsafeExternalRowSorter.java  |   7 +-
 .../UnsafeFixedWidthAggregationMap.java         |   2 +-
 .../sql/execution/UnsafeKVExternalSorter.java   |  19 +-
 .../UnsafeFixedWidthAggregationMapSuite.scala   |  22 +-
 .../execution/UnsafeKVExternalSorterSuite.scala |   6 +-
 .../TungstenAggregationIteratorSuite.scala      |  54 ---
 .../unsafe/memory/HeapMemoryAllocator.java      |   9 +-
 .../unsafe/memory/UnsafeMemoryAllocator.java    |   3 -
 31 files changed, 1316 insertions(+), 880 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
new file mode 100644
index 0000000..008799c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory;
+
+
+import java.io.IOException;
+
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+
+/**
+ * An memory consumer of TaskMemoryManager, which support spilling.
+ */
+public abstract class MemoryConsumer {
+
+  private final TaskMemoryManager taskMemoryManager;
+  private final long pageSize;
+  private long used;
+
+  protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) {
+    this.taskMemoryManager = taskMemoryManager;
+    this.pageSize = pageSize;
+    this.used = 0;
+  }
+
+  protected MemoryConsumer(TaskMemoryManager taskMemoryManager) {
+    this(taskMemoryManager, taskMemoryManager.pageSizeBytes());
+  }
+
+  /**
+   * Returns the size of used memory in bytes.
+   */
+  long getUsed() {
+    return used;
+  }
+
+  /**
+   * Force spill during building.
+   *
+   * For testing.
+   */
+  public void spill() throws IOException {
+    spill(Long.MAX_VALUE, this);
+  }
+
+  /**
+   * Spill some data to disk to release memory, which will be called by TaskMemoryManager
+   * when there is not enough memory for the task.
+   *
+   * This should be implemented by subclass.
+   *
+   * Note: In order to avoid possible deadlock, should not call acquireMemory() from spill().
+   *
+   * @param size the amount of memory should be released
+   * @param trigger the MemoryConsumer that trigger this spilling
+   * @return the amount of released memory in bytes
+   * @throws IOException
+   */
+  public abstract long spill(long size, MemoryConsumer trigger) throws IOException;
+
+  /**
+   * Acquire `size` bytes memory.
+   *
+   * If there is not enough memory, throws OutOfMemoryError.
+   */
+  protected void acquireMemory(long size) {
+    long got = taskMemoryManager.acquireExecutionMemory(size, this);
+    if (got < size) {
+      taskMemoryManager.releaseExecutionMemory(got, this);
+      taskMemoryManager.showMemoryUsage();
+      throw new OutOfMemoryError("Could not acquire " + size + " bytes of memory, got " + got);
+    }
+    used += got;
+  }
+
+  /**
+   * Release `size` bytes memory.
+   */
+  protected void releaseMemory(long size) {
+    used -= size;
+    taskMemoryManager.releaseExecutionMemory(size, this);
+  }
+
+  /**
+   * Allocate a memory block with at least `required` bytes.
+   *
+   * Throws IOException if there is not enough memory.
+   *
+   * @throws OutOfMemoryError
+   */
+  protected MemoryBlock allocatePage(long required) {
+    MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize, required), this);
+    if (page == null || page.size() < required) {
+      long got = 0;
+      if (page != null) {
+        got = page.size();
+        freePage(page);
+      }
+      taskMemoryManager.showMemoryUsage();
+      throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got);
+    }
+    used += page.size();
+    return page;
+  }
+
+  /**
+   * Free a memory block.
+   */
+  protected void freePage(MemoryBlock page) {
+    used -= page.size();
+    taskMemoryManager.freePage(page, this);
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index 7b31c90..4230575 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -17,13 +17,18 @@
 
 package org.apache.spark.memory;
 
-import java.util.*;
+import javax.annotation.concurrent.GuardedBy;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.BitSet;
+import java.util.HashSet;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.util.Utils;
 
 /**
  * Manages the memory allocated by an individual task.
@@ -101,29 +106,104 @@ public class TaskMemoryManager {
   private final boolean inHeap;
 
   /**
+   * The size of memory granted to each consumer.
+   */
+  @GuardedBy("this")
+  private final HashSet<MemoryConsumer> consumers;
+
+  /**
    * Construct a new TaskMemoryManager.
    */
   public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) {
     this.inHeap = memoryManager.tungstenMemoryIsAllocatedInHeap();
     this.memoryManager = memoryManager;
     this.taskAttemptId = taskAttemptId;
+    this.consumers = new HashSet<>();
   }
 
   /**
-   * Acquire N bytes of memory for execution, evicting cached blocks if necessary.
+   * Acquire N bytes of memory for a consumer. If there is no enough memory, it will call
+   * spill() of consumers to release more memory.
+   *
    * @return number of bytes successfully granted (<= N).
    */
-  public long acquireExecutionMemory(long size) {
-    return memoryManager.acquireExecutionMemory(size, taskAttemptId);
+  public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
+    assert(required >= 0);
+    synchronized (this) {
+      long got = memoryManager.acquireExecutionMemory(required, taskAttemptId);
+
+      // try to release memory from other consumers first, then we can reduce the frequency of
+      // spilling, avoid to have too many spilled files.
+      if (got < required) {
+        // Call spill() on other consumers to release memory
+        for (MemoryConsumer c: consumers) {
+          if (c != null && c != consumer && c.getUsed() > 0) {
+            try {
+              long released = c.spill(required - got, consumer);
+              if (released > 0) {
+                logger.info("Task {} released {} from {} for {}", taskAttemptId,
+                  Utils.bytesToString(released), c, consumer);
+                got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId);
+                if (got >= required) {
+                  break;
+                }
+              }
+            } catch (IOException e) {
+              logger.error("error while calling spill() on " + c, e);
+              throw new OutOfMemoryError("error while calling spill() on " + c + " : "
+                + e.getMessage());
+            }
+          }
+        }
+      }
+
+      // call spill() on itself
+      if (got < required && consumer != null) {
+        try {
+          long released = consumer.spill(required - got, consumer);
+          if (released > 0) {
+            logger.info("Task {} released {} from itself ({})", taskAttemptId,
+              Utils.bytesToString(released), consumer);
+            got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId);
+          }
+        } catch (IOException e) {
+          logger.error("error while calling spill() on " + consumer, e);
+          throw new OutOfMemoryError("error while calling spill() on " + consumer + " : "
+            + e.getMessage());
+        }
+      }
+
+      consumers.add(consumer);
+      logger.debug("Task {} acquire {} for {}", taskAttemptId, Utils.bytesToString(got), consumer);
+      return got;
+    }
   }
 
   /**
-   * Release N bytes of execution memory.
+   * Release N bytes of execution memory for a MemoryConsumer.
    */
-  public void releaseExecutionMemory(long size) {
+  public void releaseExecutionMemory(long size, MemoryConsumer consumer) {
+    logger.debug("Task {} release {} from {}", taskAttemptId, Utils.bytesToString(size), consumer);
     memoryManager.releaseExecutionMemory(size, taskAttemptId);
   }
 
+  /**
+   * Dump the memory usage of all consumers.
+   */
+  public void showMemoryUsage() {
+    logger.info("Memory used in task " + taskAttemptId);
+    synchronized (this) {
+      for (MemoryConsumer c: consumers) {
+        if (c.getUsed() > 0) {
+          logger.info("Acquired by " + c + ": " + Utils.bytesToString(c.getUsed()));
+        }
+      }
+    }
+  }
+
+  /**
+   * Return the page size in bytes.
+   */
   public long pageSizeBytes() {
     return memoryManager.pageSizeBytes();
   }
@@ -134,42 +214,40 @@ public class TaskMemoryManager {
    *
    * Returns `null` if there was not enough memory to allocate the page.
    */
-  public MemoryBlock allocatePage(long size) {
+  public MemoryBlock allocatePage(long size, MemoryConsumer consumer) {
     if (size > MAXIMUM_PAGE_SIZE_BYTES) {
       throw new IllegalArgumentException(
         "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes");
     }
 
+    long acquired = acquireExecutionMemory(size, consumer);
+    if (acquired <= 0) {
+      return null;
+    }
+
     final int pageNumber;
     synchronized (this) {
       pageNumber = allocatedPages.nextClearBit(0);
       if (pageNumber >= PAGE_TABLE_SIZE) {
+        releaseExecutionMemory(acquired, consumer);
         throw new IllegalStateException(
           "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages");
       }
       allocatedPages.set(pageNumber);
     }
-    final long acquiredExecutionMemory = acquireExecutionMemory(size);
-    if (acquiredExecutionMemory != size) {
-      releaseExecutionMemory(acquiredExecutionMemory);
-      synchronized (this) {
-        allocatedPages.clear(pageNumber);
-      }
-      return null;
-    }
-    final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(size);
+    final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(acquired);
     page.pageNumber = pageNumber;
     pageTable[pageNumber] = page;
     if (logger.isTraceEnabled()) {
-      logger.trace("Allocate page number {} ({} bytes)", pageNumber, size);
+      logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired);
     }
     return page;
   }
 
   /**
-   * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}.
+   * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}.
    */
-  public void freePage(MemoryBlock page) {
+  public void freePage(MemoryBlock page, MemoryConsumer consumer) {
     assert (page.pageNumber != -1) :
       "Called freePage() on memory that wasn't allocated with allocatePage()";
     assert(allocatedPages.get(page.pageNumber));
@@ -182,14 +260,14 @@ public class TaskMemoryManager {
     }
     long pageSize = page.size();
     memoryManager.tungstenMemoryAllocator().free(page);
-    releaseExecutionMemory(pageSize);
+    releaseExecutionMemory(pageSize, consumer);
   }
 
   /**
    * Given a memory page and offset within that page, encode this address into a 64-bit long.
    * This address will remain valid as long as the corresponding page has not been freed.
    *
-   * @param page a data page allocated by {@link TaskMemoryManager#allocatePage(long)}/
+   * @param page a data page allocated by {@link TaskMemoryManager#allocatePage}/
    * @param offsetInPage an offset in this page which incorporates the base offset. In other words,
    *                     this should be the value that you would pass as the base offset into an
    *                     UNSAFE call (e.g. page.baseOffset() + something).
@@ -261,17 +339,17 @@ public class TaskMemoryManager {
    * value can be used to detect memory leaks.
    */
   public long cleanUpAllAllocatedMemory() {
-    long freedBytes = 0;
-    for (MemoryBlock page : pageTable) {
-      if (page != null) {
-        freedBytes += page.size();
-        freePage(page);
+    synchronized (this) {
+      Arrays.fill(pageTable, null);
+      for (MemoryConsumer c: consumers) {
+        if (c != null && c.getUsed() > 0) {
+          // In case of failed task, it's normal to see leaked memory
+          logger.warn("leak " + Utils.bytesToString(c.getUsed()) + " memory from " + c);
+        }
       }
+      consumers.clear();
     }
-
-    freedBytes += memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId);
-
-    return freedBytes;
+    return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId);
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index f43236f..400d852 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -31,15 +31,15 @@ import org.slf4j.LoggerFactory;
 import org.apache.spark.SparkConf;
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.serializer.DummySerializerInstance;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.DiskBlockObjectWriter;
 import org.apache.spark.storage.TempShuffleBlockId;
 import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.Utils;
 
 /**
@@ -58,23 +58,18 @@ import org.apache.spark.util.Utils;
  * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
  * specialized merge procedure that avoids extra serialization/deserialization.
  */
-final class ShuffleExternalSorter {
+final class ShuffleExternalSorter extends MemoryConsumer {
 
   private final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class);
 
   @VisibleForTesting
   static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
 
-  private final int initialSize;
   private final int numPartitions;
-  private final int pageSizeBytes;
-  @VisibleForTesting
-  final int maxRecordSizeBytes;
   private final TaskMemoryManager taskMemoryManager;
   private final BlockManager blockManager;
   private final TaskContext taskContext;
   private final ShuffleWriteMetrics writeMetrics;
-  private long numRecordsInsertedSinceLastSpill = 0;
 
   /** Force this sorter to spill when there are this many elements in memory. For testing only */
   private final long numElementsForSpillThreshold;
@@ -98,8 +93,7 @@ final class ShuffleExternalSorter {
   // These variables are reset after spilling:
   @Nullable private ShuffleInMemorySorter inMemSorter;
   @Nullable private MemoryBlock currentPage = null;
-  private long currentPagePosition = -1;
-  private long freeSpaceInCurrentPage = 0;
+  private long pageCursor = -1;
 
   public ShuffleExternalSorter(
       TaskMemoryManager memoryManager,
@@ -108,42 +102,21 @@ final class ShuffleExternalSorter {
       int initialSize,
       int numPartitions,
       SparkConf conf,
-      ShuffleWriteMetrics writeMetrics) throws IOException {
+      ShuffleWriteMetrics writeMetrics) {
+    super(memoryManager, (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES,
+      memoryManager.pageSizeBytes()));
     this.taskMemoryManager = memoryManager;
     this.blockManager = blockManager;
     this.taskContext = taskContext;
-    this.initialSize = initialSize;
-    this.peakMemoryUsedBytes = initialSize;
     this.numPartitions = numPartitions;
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
     this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
     this.numElementsForSpillThreshold =
       conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
-    this.pageSizeBytes = (int) Math.min(
-      PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, taskMemoryManager.pageSizeBytes());
-    this.maxRecordSizeBytes = pageSizeBytes - 4;
     this.writeMetrics = writeMetrics;
-    initializeForWriting();
-
-    // preserve first page to ensure that we have at least one page to work with. Otherwise,
-    // other operators in the same task may starve this sorter (SPARK-9709).
-    acquireNewPageIfNecessary(pageSizeBytes);
-  }
-
-  /**
-   * Allocates new sort data structures. Called when creating the sorter and after each spill.
-   */
-  private void initializeForWriting() throws IOException {
-    // TODO: move this sizing calculation logic into a static method of sorter:
-    final long memoryRequested = initialSize * 8L;
-    final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryRequested);
-    if (memoryAcquired != memoryRequested) {
-      taskMemoryManager.releaseExecutionMemory(memoryAcquired);
-      throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
-    }
-
+    acquireMemory(initialSize * 8L);
     this.inMemSorter = new ShuffleInMemorySorter(initialSize);
-    numRecordsInsertedSinceLastSpill = 0;
+    this.peakMemoryUsedBytes = getMemoryUsage();
   }
 
   /**
@@ -242,6 +215,8 @@ final class ShuffleExternalSorter {
       }
     }
 
+    inMemSorter.reset();
+
     if (!isLastFile) {  // i.e. this is a spill file
       // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records
       // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter
@@ -266,9 +241,12 @@ final class ShuffleExternalSorter {
   /**
    * Sort and spill the current records in response to memory pressure.
    */
-  @VisibleForTesting
-  void spill() throws IOException {
-    assert(inMemSorter != null);
+  @Override
+  public long spill(long size, MemoryConsumer trigger) throws IOException {
+    if (trigger != this || inMemSorter == null || inMemSorter.numRecords() == 0) {
+      return 0L;
+    }
+
     logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
       Thread.currentThread().getId(),
       Utils.bytesToString(getMemoryUsage()),
@@ -276,13 +254,9 @@ final class ShuffleExternalSorter {
       spills.size() > 1 ? " times" : " time");
 
     writeSortedFile(false);
-    final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage();
-    inMemSorter = null;
-    taskMemoryManager.releaseExecutionMemory(inMemSorterMemoryUsage);
     final long spillSize = freeMemory();
     taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
-
-    initializeForWriting();
+    return spillSize;
   }
 
   private long getMemoryUsage() {
@@ -312,18 +286,12 @@ final class ShuffleExternalSorter {
     updatePeakMemoryUsed();
     long memoryFreed = 0;
     for (MemoryBlock block : allocatedPages) {
-      taskMemoryManager.freePage(block);
       memoryFreed += block.size();
-    }
-    if (inMemSorter != null) {
-      long sorterMemoryUsage = inMemSorter.getMemoryUsage();
-      inMemSorter = null;
-      taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage);
+      freePage(block);
     }
     allocatedPages.clear();
     currentPage = null;
-    currentPagePosition = -1;
-    freeSpaceInCurrentPage = 0;
+    pageCursor = 0;
     return memoryFreed;
   }
 
@@ -332,16 +300,16 @@ final class ShuffleExternalSorter {
    */
   public void cleanupResources() {
     freeMemory();
+    if (inMemSorter != null) {
+      long sorterMemoryUsage = inMemSorter.getMemoryUsage();
+      inMemSorter = null;
+      releaseMemory(sorterMemoryUsage);
+    }
     for (SpillInfo spill : spills) {
       if (spill.file.exists() && !spill.file.delete()) {
         logger.error("Unable to delete spill file {}", spill.file.getPath());
       }
     }
-    if (inMemSorter != null) {
-      long sorterMemoryUsage = inMemSorter.getMemoryUsage();
-      inMemSorter = null;
-      taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage);
-    }
   }
 
   /**
@@ -352,16 +320,27 @@ final class ShuffleExternalSorter {
   private void growPointerArrayIfNecessary() throws IOException {
     assert(inMemSorter != null);
     if (!inMemSorter.hasSpaceForAnotherRecord()) {
-      logger.debug("Attempting to expand sort pointer array");
-      final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage();
-      final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
-      final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryToGrowPointerArray);
-      if (memoryAcquired < memoryToGrowPointerArray) {
-        taskMemoryManager.releaseExecutionMemory(memoryAcquired);
-        spill();
+      long used = inMemSorter.getMemoryUsage();
+      long needed = used + inMemSorter.getMemoryToExpand();
+      try {
+        acquireMemory(needed);  // could trigger spilling
+      } catch (OutOfMemoryError e) {
+        // should have trigger spilling
+        assert(inMemSorter.hasSpaceForAnotherRecord());
+        return;
+      }
+      // check if spilling is triggered or not
+      if (inMemSorter.hasSpaceForAnotherRecord()) {
+        releaseMemory(needed);
       } else {
-        inMemSorter.expandPointerArray();
-        taskMemoryManager.releaseExecutionMemory(oldPointerArrayMemoryUsage);
+        try {
+          inMemSorter.expandPointerArray();
+          releaseMemory(used);
+        } catch (OutOfMemoryError oom) {
+          // Just in case that JVM had run out of memory
+          releaseMemory(needed);
+          spill();
+        }
       }
     }
   }
@@ -370,96 +349,46 @@ final class ShuffleExternalSorter {
    * Allocates more memory in order to insert an additional record. This will request additional
    * memory from the memory manager and spill if the requested memory can not be obtained.
    *
-   * @param requiredSpace the required space in the data page, in bytes, including space for storing
+   * @param required the required space in the data page, in bytes, including space for storing
    *                      the record size. This must be less than or equal to the page size (records
    *                      that exceed the page size are handled via a different code path which uses
    *                      special overflow pages).
    */
-  private void acquireNewPageIfNecessary(int requiredSpace) throws IOException {
-    growPointerArrayIfNecessary();
-    if (requiredSpace > freeSpaceInCurrentPage) {
-      logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
-        freeSpaceInCurrentPage);
-      // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
-      // without using the free space at the end of the current page. We should also do this for
-      // BytesToBytesMap.
-      if (requiredSpace > pageSizeBytes) {
-        throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
-          pageSizeBytes + ")");
-      } else {
-        currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
-        if (currentPage == null) {
-          spill();
-          currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
-          if (currentPage == null) {
-            throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
-          }
-        }
-        currentPagePosition = currentPage.getBaseOffset();
-        freeSpaceInCurrentPage = pageSizeBytes;
-        allocatedPages.add(currentPage);
-      }
+  private void acquireNewPageIfNecessary(int required) {
+    if (currentPage == null ||
+      pageCursor + required > currentPage.getBaseOffset() + currentPage.size() ) {
+      // TODO: try to find space in previous pages
+      currentPage = allocatePage(required);
+      pageCursor = currentPage.getBaseOffset();
+      allocatedPages.add(currentPage);
     }
   }
 
   /**
    * Write a record to the shuffle sorter.
    */
-  public void insertRecord(
-      Object recordBaseObject,
-      long recordBaseOffset,
-      int lengthInBytes,
-      int partitionId) throws IOException {
+  public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId)
+    throws IOException {
 
-    if (numRecordsInsertedSinceLastSpill > numElementsForSpillThreshold) {
+    // for tests
+    assert(inMemSorter != null);
+    if (inMemSorter.numRecords() > numElementsForSpillThreshold) {
       spill();
     }
 
     growPointerArrayIfNecessary();
     // Need 4 bytes to store the record length.
-    final int totalSpaceRequired = lengthInBytes + 4;
-
-    // --- Figure out where to insert the new record ----------------------------------------------
-
-    final MemoryBlock dataPage;
-    long dataPagePosition;
-    boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
-    if (useOverflowPage) {
-      long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
-      // The record is larger than the page size, so allocate a special overflow page just to hold
-      // that record.
-      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
-      if (overflowPage == null) {
-        spill();
-        overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
-        if (overflowPage == null) {
-          throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
-        }
-      }
-      allocatedPages.add(overflowPage);
-      dataPage = overflowPage;
-      dataPagePosition = overflowPage.getBaseOffset();
-    } else {
-      // The record is small enough to fit in a regular data page, but the current page might not
-      // have enough space to hold it (or no pages have been allocated yet).
-      acquireNewPageIfNecessary(totalSpaceRequired);
-      dataPage = currentPage;
-      dataPagePosition = currentPagePosition;
-      // Update bookkeeping information
-      freeSpaceInCurrentPage -= totalSpaceRequired;
-      currentPagePosition += totalSpaceRequired;
-    }
-    final Object dataPageBaseObject = dataPage.getBaseObject();
-
-    final long recordAddress =
-      taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
-    Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes);
-    dataPagePosition += 4;
-    Platform.copyMemory(
-      recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);
-    assert(inMemSorter != null);
+    final int required = length + 4;
+    acquireNewPageIfNecessary(required);
+
+    assert(currentPage != null);
+    final Object base = currentPage.getBaseObject();
+    final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
+    Platform.putInt(base, pageCursor, length);
+    pageCursor += 4;
+    Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
+    pageCursor += length;
     inMemSorter.insertRecord(recordAddress, partitionId);
-    numRecordsInsertedSinceLastSpill += 1;
   }
 
   /**
@@ -475,6 +404,9 @@ final class ShuffleExternalSorter {
         // Do not count the final file towards the spill count.
         writeSortedFile(true);
         freeMemory();
+        long sorterMemoryUsage = inMemSorter.getMemoryUsage();
+        inMemSorter = null;
+        releaseMemory(sorterMemoryUsage);
       }
       return spills.toArray(new SpillInfo[spills.size()]);
     } catch (IOException e) {
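
Note on the rewritten insertRecord() above: each record is now appended to the current page as a 4-byte length followed by the record bytes, and the separate overflow-page branch is gone. The sketch below is not Spark code. It models a page with a plain java.nio.ByteBuffer, and the class name PageAppendSketch and its methods are hypothetical. It also assumes that allocatePage(required) returns a page of at least max(pageSize, required) bytes, which this part of the diff does not show.

```java
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for the sorter's page handling; not Spark code.
final class PageAppendSketch {
  private final int pageSize;
  private final List<ByteBuffer> allocatedPages = new ArrayList<>();
  private ByteBuffer currentPage = null;

  PageAppendSketch(int pageSize) {
    this.pageSize = pageSize;
  }

  // Open a new page when there is no current page or the record does not fit.
  // Assumption: an oversized record simply gets a page large enough to hold it.
  private void acquireNewPageIfNecessary(int required) {
    if (currentPage == null || currentPage.remaining() < required) {
      currentPage = ByteBuffer.allocate(Math.max(pageSize, required));
      allocatedPages.add(currentPage);
    }
  }

  // Appends (4-byte length)(record bytes) and returns {pageIndex, offsetInPage}.
  int[] insertRecord(byte[] record) {
    final int required = record.length + 4;  // 4 bytes store the record length
    acquireNewPageIfNecessary(required);
    final int offset = currentPage.position();
    currentPage.putInt(record.length);
    currentPage.put(record);
    return new int[] {allocatedPages.size() - 1, offset};
  }

  int numPages() {
    return allocatedPages.size();
  }

  public static void main(String[] args) {
    PageAppendSketch sorter = new PageAppendSketch(16);
    int[] a = sorter.insertRecord(new byte[8]);   // fits in the first 16-byte page
    int[] b = sorter.insertRecord(new byte[8]);   // only 4 bytes left, so a new page is opened
    int[] c = sorter.insertRecord(new byte[40]);  // larger than pageSize, gets its own page
    System.out.println("a: page " + a[0] + ", offset " + a[1]);
    System.out.println("b: page " + b[0] + ", offset " + b[1]);
    System.out.println("c: page " + c[0] + ", pages allocated: " + sorter.numPages());
  }
}
```

With a single code path for normal and oversized records, the caller no longer tracks free space separately; the cursor position alone decides whether a new page is needed.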

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
index a8dee6c..e630575 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
@@ -37,33 +37,51 @@ final class ShuffleInMemorySorter {
    * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating
    * records.
    */
-  private long[] pointerArray;
+  private long[] array;
 
   /**
    * The position in the pointer array where new records can be inserted.
    */
-  private int pointerArrayInsertPosition = 0;
+  private int pos = 0;
 
   public ShuffleInMemorySorter(int initialSize) {
     assert (initialSize > 0);
-    this.pointerArray = new long[initialSize];
-    this.sorter = new Sorter<PackedRecordPointer, long[]>(ShuffleSortDataFormat.INSTANCE);
+    this.array = new long[initialSize];
+    this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE);
   }
 
-  public void expandPointerArray() {
-    final long[] oldArray = pointerArray;
+  public int numRecords() {
+    return pos;
+  }
+
+  public void reset() {
+    pos = 0;
+  }
+
+  private int newLength() {
     // Guard against overflow:
-    final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
-    pointerArray = new long[newLength];
-    System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
+    return array.length <= Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE;
+  }
+
+  /**
+   * Returns the number of bytes of memory needed to expand the pointer array.
+   */
+  public long getMemoryToExpand() {
+    return ((long) (newLength() - array.length)) * 8;
+  }
+
+  public void expandPointerArray() {
+    final long[] oldArray = array;
+    array = new long[newLength()];
+    System.arraycopy(oldArray, 0, array, 0, oldArray.length);
   }
 
   public boolean hasSpaceForAnotherRecord() {
-    return pointerArrayInsertPosition + 1 < pointerArray.length;
+    return pos < array.length;
   }
 
   public long getMemoryUsage() {
-    return pointerArray.length * 8L;
+    return array.length * 8L;
   }
 
   /**
@@ -78,15 +96,15 @@ final class ShuffleInMemorySorter {
    */
   public void insertRecord(long recordPointer, int partitionId) {
     if (!hasSpaceForAnotherRecord()) {
-      if (pointerArray.length == Integer.MAX_VALUE) {
+      if (array.length == Integer.MAX_VALUE) {
         throw new IllegalStateException("Sort pointer array has reached maximum size");
       } else {
         expandPointerArray();
       }
     }
-    pointerArray[pointerArrayInsertPosition] =
+    array[pos] =
         PackedRecordPointer.packPointer(recordPointer, partitionId);
-    pointerArrayInsertPosition++;
+    pos++;
   }
 
   /**
@@ -118,7 +136,7 @@ final class ShuffleInMemorySorter {
    * Return an iterator over record pointers in sorted order.
    */
   public ShuffleSorterIterator getSortedIterator() {
-    sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR);
-    return new ShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
+    sorter.sort(array, 0, pos, SORT_COMPARATOR);
+    return new ShuffleSorterIterator(pos, array);
   }
 }
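
Note on the growth logic in ShuffleInMemorySorter above: newLength() doubles the pointer array unless doubling would exceed Integer.MAX_VALUE, and getMemoryToExpand() reports the extra bytes (8 per long slot) that the caller has to reserve before expanding. The sketch below restates that arithmetic as static methods so it can be run in isolation. The class name PointerArrayGrowthSketch and the method memoryToExpand are hypothetical, and the array length is passed in as a parameter instead of being read from a field.

```java
// Hypothetical restatement of the sizing arithmetic; not Spark code.
final class PointerArrayGrowthSketch {

  // Double the length, capping at Integer.MAX_VALUE instead of overflowing.
  static int newLength(int currentLength) {
    return currentLength <= Integer.MAX_VALUE / 2 ? currentLength * 2 : Integer.MAX_VALUE;
  }

  // Extra bytes needed to grow: one 8-byte long per added slot.
  // The cast to long happens before the multiplication so the product cannot overflow.
  static long memoryToExpand(int currentLength) {
    return ((long) (newLength(currentLength) - currentLength)) * 8;
  }

  public static void main(String[] args) {
    System.out.println(newLength(4));             // small array: doubles
    System.out.println(newLength(1 << 30));       // doubling would overflow: capped
    System.out.println(memoryToExpand(4));        // 4 added slots * 8 bytes
    System.out.println(memoryToExpand(1 << 30));  // needs a long to represent
  }
}
```

Exposing the required byte count separately from the expansion itself lets the owner of the sorter acquire that memory first, and spill if it cannot, before the new array is allocated.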

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index f6c5c94..e19b378 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -127,12 +127,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     open();
   }
 
-  @VisibleForTesting
-  public int maxRecordSizeBytes() {
-    assert(sorter != null);
-    return sorter.maxRecordSizeBytes;
-  }
-
   private void updatePeakMemoryUsed() {
     // sorter can be null if this writer is closed
     if (sorter != null) {

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index f035bda..e36709c 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -18,14 +18,20 @@
 package org.apache.spark.unsafe.map;
 
 import javax.annotation.Nullable;
+import java.io.File;
+import java.io.IOException;
 import java.util.Iterator;
 import java.util.LinkedList;
-import java.util.List;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.spark.SparkEnv;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.storage.BlockManager;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.array.LongArray;
@@ -33,7 +39,8 @@ import org.apache.spark.unsafe.bitset.BitSet;
 import org.apache.spark.unsafe.hash.Murmur3_x86_32;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.unsafe.memory.MemoryLocation;
-import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter;
 
 /**
  * An append-only hash map where keys and values are contiguous regions of bytes.
@@ -54,7 +61,7 @@ import org.apache.spark.memory.TaskMemoryManager;
  * is consistent with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter},
  * so we can pass records from this map directly into the sorter to sort records in place.
  */
-public final class BytesToBytesMap {
+public final class BytesToBytesMap extends MemoryConsumer {
 
   private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class);
 
@@ -62,27 +69,22 @@ public final class BytesToBytesMap {
 
   private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;
 
-  /**
-   * Special record length that is placed after the last record in a data page.
-   */
-  private static final int END_OF_PAGE_MARKER = -1;
-
   private final TaskMemoryManager taskMemoryManager;
 
   /**
    * A linked list for tracking all allocated data pages so that we can free all of our memory.
    */
-  private final List<MemoryBlock> dataPages = new LinkedList<MemoryBlock>();
+  private final LinkedList<MemoryBlock> dataPages = new LinkedList<>();
 
   /**
    * The data page that will be used to store keys and values for new hashtable entries. When this
    * page becomes full, a new page will be allocated and this pointer will change to point to that
    * new page.
    */
-  private MemoryBlock currentDataPage = null;
+  private MemoryBlock currentPage = null;
 
   /**
-   * Offset into `currentDataPage` that points to the location where new data can be inserted into
+   * Offset into `currentPage` that points to the location where new data can be inserted into
    * the page. This does not incorporate the page's base offset.
    */
   private long pageCursor = 0;
@@ -117,6 +119,11 @@ public final class BytesToBytesMap {
   // absolute memory addresses.
 
   /**
+   * Whether or not the longArray can grow. We will not insert more elements if it's false.
+   */
+  private boolean canGrowArray = true;
+
+  /**
    * A {@link BitSet} used to track location of the map where the key is set.
    * Size of the bitset should be half of the size of the long array.
    */
@@ -164,13 +171,20 @@ public final class BytesToBytesMap {
 
   private long peakMemoryUsedBytes = 0L;
 
+  private final BlockManager blockManager;
+  private volatile MapIterator destructiveIterator = null;
+  private LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
+
   public BytesToBytesMap(
       TaskMemoryManager taskMemoryManager,
+      BlockManager blockManager,
       int initialCapacity,
       double loadFactor,
       long pageSizeBytes,
       boolean enablePerfMetrics) {
+    super(taskMemoryManager, pageSizeBytes);
     this.taskMemoryManager = taskMemoryManager;
+    this.blockManager = blockManager;
     this.loadFactor = loadFactor;
     this.loc = new Location();
     this.pageSizeBytes = pageSizeBytes;
@@ -187,18 +201,13 @@ public final class BytesToBytesMap {
         TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES);
     }
     allocate(initialCapacity);
-
-    // Acquire a new page as soon as we construct the map to ensure that we have at least
-    // one page to work with. Otherwise, other operators in the same task may starve this
-    // map (SPARK-9747).
-    acquireNewPage();
   }
 
   public BytesToBytesMap(
       TaskMemoryManager taskMemoryManager,
       int initialCapacity,
       long pageSizeBytes) {
-    this(taskMemoryManager, initialCapacity, 0.70, pageSizeBytes, false);
+    this(taskMemoryManager, initialCapacity, pageSizeBytes, false);
   }
 
   public BytesToBytesMap(
@@ -208,6 +217,7 @@ public final class BytesToBytesMap {
       boolean enablePerfMetrics) {
     this(
       taskMemoryManager,
+      SparkEnv.get() != null ? SparkEnv.get().blockManager() :  null,
       initialCapacity,
       0.70,
       pageSizeBytes,
@@ -219,61 +229,153 @@ public final class BytesToBytesMap {
    */
   public int numElements() { return numElements; }
 
-  public static final class BytesToBytesMapIterator implements Iterator<Location> {
+  public final class MapIterator implements Iterator<Location> {
 
-    private final int numRecords;
-    private final Iterator<MemoryBlock> dataPagesIterator;
+    private int numRecords;
     private final Location loc;
 
     private MemoryBlock currentPage = null;
-    private int currentRecordNumber = 0;
+    private int recordsInPage = 0;
     private Object pageBaseObject;
     private long offsetInPage;
 
     // If this iterator destructive or not. When it is true, it frees each page as it moves onto
     // next one.
     private boolean destructive = false;
-    private BytesToBytesMap bmap;
+    private UnsafeSorterSpillReader reader = null;
 
-    private BytesToBytesMapIterator(
-        int numRecords, Iterator<MemoryBlock> dataPagesIterator, Location loc,
-        boolean destructive, BytesToBytesMap bmap) {
+    private MapIterator(int numRecords, Location loc, boolean destructive) {
       this.numRecords = numRecords;
-      this.dataPagesIterator = dataPagesIterator;
       this.loc = loc;
       this.destructive = destructive;
-      this.bmap = bmap;
-      if (dataPagesIterator.hasNext()) {
-        advanceToNextPage();
+      if (destructive) {
+        destructiveIterator = this;
       }
     }
 
     private void advanceToNextPage() {
-      if (destructive && currentPage != null) {
-        dataPagesIterator.remove();
-        this.bmap.taskMemoryManager.freePage(currentPage);
+      synchronized (this) {
+        int nextIdx = dataPages.indexOf(currentPage) + 1;
+        if (destructive && currentPage != null) {
+          dataPages.remove(currentPage);
+          freePage(currentPage);
+          nextIdx --;
+        }
+        if (dataPages.size() > nextIdx) {
+          currentPage = dataPages.get(nextIdx);
+          pageBaseObject = currentPage.getBaseObject();
+          offsetInPage = currentPage.getBaseOffset();
+          recordsInPage = Platform.getInt(pageBaseObject, offsetInPage);
+          offsetInPage += 4;
+        } else {
+          currentPage = null;
+          if (reader != null) {
+            // remove the spill file from disk
+            File file = spillWriters.removeFirst().getFile();
+            if (file != null && file.exists()) {
+              if (!file.delete()) {
+                logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
+              }
+            }
+          }
+          try {
+            reader = spillWriters.getFirst().getReader(blockManager);
+            recordsInPage = -1;
+          } catch (IOException e) {
+            // Scala iterator does not handle exception
+            Platform.throwException(e);
+          }
+        }
       }
-      currentPage = dataPagesIterator.next();
-      pageBaseObject = currentPage.getBaseObject();
-      offsetInPage = currentPage.getBaseOffset();
     }
 
     @Override
     public boolean hasNext() {
-      return currentRecordNumber != numRecords;
+      if (numRecords == 0) {
+        if (reader != null) {
+          // remove the spill file from disk
+          File file = spillWriters.removeFirst().getFile();
+          if (file != null && file.exists()) {
+            if (!file.delete()) {
+              logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
+            }
+          }
+        }
+      }
+      return numRecords > 0;
     }
 
     @Override
     public Location next() {
-      int totalLength = Platform.getInt(pageBaseObject, offsetInPage);
-      if (totalLength == END_OF_PAGE_MARKER) {
+      if (recordsInPage == 0) {
         advanceToNextPage();
-        totalLength = Platform.getInt(pageBaseObject, offsetInPage);
       }
-      loc.with(currentPage, offsetInPage);
-      offsetInPage += 4 + totalLength;
-      currentRecordNumber++;
-      return loc;
+      numRecords--;
+      if (currentPage != null) {
+        int totalLength = Platform.getInt(pageBaseObject, offsetInPage);
+        loc.with(currentPage, offsetInPage);
+        offsetInPage += 4 + totalLength;
+        recordsInPage --;
+        return loc;
+      } else {
+        assert(reader != null);
+        if (!reader.hasNext()) {
+          advanceToNextPage();
+        }
+        try {
+          reader.loadNext();
+        } catch (IOException e) {
+          // Scala iterator does not handle exception
+          Platform.throwException(e);
+        }
+        loc.with(reader.getBaseObject(), reader.getBaseOffset(), reader.getRecordLength());
+        return loc;
+      }
+    }
+
+    public long spill(long numBytes) throws IOException {
+      synchronized (this) {
+        if (!destructive || dataPages.size() == 1) {
+          return 0L;
+        }
+
+        // TODO: use existing ShuffleWriteMetrics
+        ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
+
+        long released = 0L;
+        while (dataPages.size() > 0) {
+          MemoryBlock block = dataPages.getLast();
+          // The current page is still in use by the iterator, so it cannot be released
+          if (block == currentPage) {
+            break;
+          }
+
+          Object base = block.getBaseObject();
+          long offset = block.getBaseOffset();
+          int numRecords = Platform.getInt(base, offset);
+          offset += 4;
+          final UnsafeSorterSpillWriter writer =
+            new UnsafeSorterSpillWriter(blockManager, 32 * 1024, writeMetrics, numRecords);
+          while (numRecords > 0) {
+            int length = Platform.getInt(base, offset);
+            writer.write(base, offset + 4, length, 0);
+            offset += 4 + length;
+            numRecords--;
+          }
+          writer.close();
+          spillWriters.add(writer);
+
+          dataPages.removeLast();
+          released += block.size();
+          freePage(block);
+
+          if (released >= numBytes) {
+            break;
+          }
+        }
+
+        return released;
+      }
     }
 
     @Override
@@ -290,8 +392,8 @@ public final class BytesToBytesMap {
    * If any other lookups or operations are performed on this map while iterating over it, including
    * `lookup()`, the behavior of the returned iterator is undefined.
    */
-  public BytesToBytesMapIterator iterator() {
-    return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, false, this);
+  public MapIterator iterator() {
+    return new MapIterator(numElements, loc, false);
   }
 
   /**
@@ -304,8 +406,8 @@ public final class BytesToBytesMap {
    * If any other lookups or operations are performed on this map while iterating over it, including
    * `lookup()`, the behavior of the returned iterator is undefined.
    */
-  public BytesToBytesMapIterator destructiveIterator() {
-    return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, true, this);
+  public MapIterator destructiveIterator() {
+    return new MapIterator(numElements, loc, true);
   }
 
   /**
@@ -314,11 +416,8 @@ public final class BytesToBytesMap {
    *
    * This function always return the same {@link Location} instance to avoid object allocation.
    */
-  public Location lookup(
-      Object keyBaseObject,
-      long keyBaseOffset,
-      int keyRowLengthBytes) {
-    safeLookup(keyBaseObject, keyBaseOffset, keyRowLengthBytes, loc);
+  public Location lookup(Object keyBase, long keyOffset, int keyLength) {
+    safeLookup(keyBase, keyOffset, keyLength, loc);
     return loc;
   }
 
@@ -327,18 +426,14 @@ public final class BytesToBytesMap {
    *
    * This is a thread-safe version of `lookup`, could be used by multiple threads.
    */
-  public void safeLookup(
-      Object keyBaseObject,
-      long keyBaseOffset,
-      int keyRowLengthBytes,
-      Location loc) {
+  public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc) {
     assert(bitset != null);
     assert(longArray != null);
 
     if (enablePerfMetrics) {
       numKeyLookups++;
     }
-    final int hashcode = HASHER.hashUnsafeWords(keyBaseObject, keyBaseOffset, keyRowLengthBytes);
+    final int hashcode = HASHER.hashUnsafeWords(keyBase, keyOffset, keyLength);
     int pos = hashcode & mask;
     int step = 1;
     while (true) {
@@ -354,16 +449,16 @@ public final class BytesToBytesMap {
         if ((int) (stored) == hashcode) {
           // Full hash code matches.  Let's compare the keys for equality.
           loc.with(pos, hashcode, true);
-          if (loc.getKeyLength() == keyRowLengthBytes) {
+          if (loc.getKeyLength() == keyLength) {
             final MemoryLocation keyAddress = loc.getKeyAddress();
-            final Object storedKeyBaseObject = keyAddress.getBaseObject();
-            final long storedKeyBaseOffset = keyAddress.getBaseOffset();
+            final Object storedkeyBase = keyAddress.getBaseObject();
+            final long storedkeyOffset = keyAddress.getBaseOffset();
             final boolean areEqual = ByteArrayMethods.arrayEquals(
-              keyBaseObject,
-              keyBaseOffset,
-              storedKeyBaseObject,
-              storedKeyBaseOffset,
-              keyRowLengthBytes
+              keyBase,
+              keyOffset,
+              storedkeyBase,
+              storedkeyOffset,
+              keyLength
             );
             if (areEqual) {
               return;
@@ -410,18 +505,18 @@ public final class BytesToBytesMap {
         taskMemoryManager.getOffsetInPage(fullKeyAddress));
     }
 
-    private void updateAddressesAndSizes(final Object page, final long offsetInPage) {
-      long position = offsetInPage;
-      final int totalLength = Platform.getInt(page, position);
+    private void updateAddressesAndSizes(final Object base, final long offset) {
+      long position = offset;
+      final int totalLength = Platform.getInt(base, position);
       position += 4;
-      keyLength = Platform.getInt(page, position);
+      keyLength = Platform.getInt(base, position);
       position += 4;
       valueLength = totalLength - keyLength - 4;
 
-      keyMemoryLocation.setObjAndOffset(page, position);
+      keyMemoryLocation.setObjAndOffset(base, position);
 
       position += keyLength;
-      valueMemoryLocation.setObjAndOffset(page, position);
+      valueMemoryLocation.setObjAndOffset(base, position);
     }
 
     private Location with(int pos, int keyHashcode, boolean isDefined) {
@@ -444,6 +539,19 @@ public final class BytesToBytesMap {
     }
 
     /**
+     * This is only used for spilling
+     */
+    private Location with(Object base, long offset, int length) {
+      this.isDefined = true;
+      this.memoryPage = null;
+      keyLength = Platform.getInt(base, offset);
+      valueLength = length - 4 - keyLength;
+      keyMemoryLocation.setObjAndOffset(base, offset + 4);
+      valueMemoryLocation.setObjAndOffset(base, offset + 4 + keyLength);
+      return this;
+    }
+
+    /**
      * Returns the memory page that contains the current record.
      * This is only valid if this is returned by {@link BytesToBytesMap#iterator()}.
      */
@@ -517,9 +625,9 @@ public final class BytesToBytesMap {
      * As an example usage, here's the proper way to store a new key:
      * </p>
      * <pre>
-     *   Location loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes);
+     *   Location loc = map.lookup(keyBase, keyOffset, keyLength);
      *   if (!loc.isDefined()) {
-     *     if (!loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)) {
+     *     if (!loc.putNewKey(keyBase, keyOffset, keyLength, ...)) {
      *       // handle failure to grow map (by spilling, for example)
      *     }
      *   }
@@ -531,113 +639,59 @@ public final class BytesToBytesMap {
      * @return true if the put() was successful and false if the put() failed because memory could
      *         not be acquired.
      */
-    public boolean putNewKey(
-        Object keyBaseObject,
-        long keyBaseOffset,
-        int keyLengthBytes,
-        Object valueBaseObject,
-        long valueBaseOffset,
-        int valueLengthBytes) {
+    public boolean putNewKey(Object keyBase, long keyOffset, int keyLength,
+        Object valueBase, long valueOffset, int valueLength) {
       assert (!isDefined) : "Can only set value once for a key";
-      assert (keyLengthBytes % 8 == 0);
-      assert (valueLengthBytes % 8 == 0);
+      assert (keyLength % 8 == 0);
+      assert (valueLength % 8 == 0);
       assert(bitset != null);
       assert(longArray != null);
 
-      if (numElements == MAX_CAPACITY) {
-        throw new IllegalStateException("BytesToBytesMap has reached maximum capacity");
+      if (numElements == MAX_CAPACITY || !canGrowArray) {
+        return false;
       }
 
       // Here, we'll copy the data into our data pages. Because we only store a relative offset from
       // the key address instead of storing the absolute address of the value, the key and value
       // must be stored in the same memory page.
       // (8 byte key length) (key) (value)
-      final long requiredSize = 8 + keyLengthBytes + valueLengthBytes;
-
-      // --- Figure out where to insert the new record ---------------------------------------------
-
-      final MemoryBlock dataPage;
-      final Object dataPageBaseObject;
-      final long dataPageInsertOffset;
-      boolean useOverflowPage = requiredSize > pageSizeBytes - 8;
-      if (useOverflowPage) {
-        // The record is larger than the page size, so allocate a special overflow page just to hold
-        // that record.
-        final long overflowPageSize = requiredSize + 8;
-        MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
-        if (overflowPage == null) {
-          logger.debug("Failed to acquire {} bytes of memory", overflowPageSize);
+      final long recordLength = 8 + keyLength + valueLength;
+      if (currentPage == null || currentPage.size() - pageCursor < recordLength) {
+        if (!acquireNewPage(recordLength + 4L)) {
           return false;
         }
-        dataPages.add(overflowPage);
-        dataPage = overflowPage;
-        dataPageBaseObject = overflowPage.getBaseObject();
-        dataPageInsertOffset = overflowPage.getBaseOffset();
-      } else if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) {
-        // The record can fit in a data page, but either we have not allocated any pages yet or
-        // the current page does not have enough space.
-        if (currentDataPage != null) {
-          // There wasn't enough space in the current page, so write an end-of-page marker:
-          final Object pageBaseObject = currentDataPage.getBaseObject();
-          final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
-          Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
-        }
-        if (!acquireNewPage()) {
-          return false;
-        }
-        dataPage = currentDataPage;
-        dataPageBaseObject = currentDataPage.getBaseObject();
-        dataPageInsertOffset = currentDataPage.getBaseOffset();
-      } else {
-        // There is enough space in the current data page.
-        dataPage = currentDataPage;
-        dataPageBaseObject = currentDataPage.getBaseObject();
-        dataPageInsertOffset = currentDataPage.getBaseOffset() + pageCursor;
       }
 
       // --- Append the key and value data to the current data page --------------------------------
-
-      long insertCursor = dataPageInsertOffset;
-
-      // Compute all of our offsets up-front:
-      final long recordOffset = insertCursor;
-      insertCursor += 4;
-      final long keyLengthOffset = insertCursor;
-      insertCursor += 4;
-      final long keyDataOffsetInPage = insertCursor;
-      insertCursor += keyLengthBytes;
-      final long valueDataOffsetInPage = insertCursor;
-      insertCursor += valueLengthBytes; // word used to store the value size
-
-      Platform.putInt(dataPageBaseObject, recordOffset,
-        keyLengthBytes + valueLengthBytes + 4);
-      Platform.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes);
-      // Copy the key
-      Platform.copyMemory(
-        keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes);
-      // Copy the value
-      Platform.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject,
-        valueDataOffsetInPage, valueLengthBytes);
-
-      // --- Update bookeeping data structures -----------------------------------------------------
-
-      if (useOverflowPage) {
-        // Store the end-of-page marker at the end of the data page
-        Platform.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER);
-      } else {
-        pageCursor += requiredSize;
-      }
-
+      final Object base = currentPage.getBaseObject();
+      long offset = currentPage.getBaseOffset() + pageCursor;
+      final long recordOffset = offset;
+      Platform.putInt(base, offset, keyLength + valueLength + 4);
+      Platform.putInt(base, offset + 4, keyLength);
+      offset += 8;
+      Platform.copyMemory(keyBase, keyOffset, base, offset, keyLength);
+      offset += keyLength;
+      Platform.copyMemory(valueBase, valueOffset, base, offset, valueLength);
+
+      // --- Update bookkeeping data structures -----------------------------------------------------
+      offset = currentPage.getBaseOffset();
+      Platform.putInt(base, offset, Platform.getInt(base, offset) + 1);
+      pageCursor += recordLength;
       numElements++;
       bitset.set(pos);
       final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
-        dataPage, recordOffset);
+        currentPage, recordOffset);
       longArray.set(pos * 2, storedKeyAddress);
       longArray.set(pos * 2 + 1, keyHashcode);
       updateAddressesAndSizes(storedKeyAddress);
       isDefined = true;
+
       if (numElements > growthThreshold && longArray.size() < MAX_CAPACITY) {
-        growAndRehash();
+        try {
+          growAndRehash();
+        } catch (OutOfMemoryError oom) {
+          canGrowArray = false;
+        }
       }
       return true;
     }
@@ -647,18 +701,26 @@ public final class BytesToBytesMap {
    * Acquire a new page from the memory manager.
    * @return whether there is enough space to allocate the new page.
    */
-  private boolean acquireNewPage() {
-    MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
-    if (newPage == null) {
-      logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
+  private boolean acquireNewPage(long required) {
+    try {
+      currentPage = allocatePage(required);
+    } catch (OutOfMemoryError e) {
       return false;
     }
-    dataPages.add(newPage);
-    pageCursor = 0;
-    currentDataPage = newPage;
+    dataPages.add(currentPage);
+    Platform.putInt(currentPage.getBaseObject(), currentPage.getBaseOffset(), 0);
+    pageCursor = 4;
     return true;
   }
 
+  @Override
+  public long spill(long size, MemoryConsumer trigger) throws IOException {
+    if (trigger != this && destructiveIterator != null) {
+      return destructiveIterator.spill(size);
+    }
+    return 0L;
+  }
+
   /**
    * Allocate new data structures for this map. When calling this outside of the constructor,
    * make sure to keep references to the old data structures so that you can free them.
@@ -670,6 +732,7 @@ public final class BytesToBytesMap {
     // The capacity needs to be divisible by 64 so that our bit set can be sized properly
     capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64);
     assert (capacity <= MAX_CAPACITY);
+    acquireMemory(capacity * 16);
     longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2]));
     bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64]));
 
@@ -678,22 +741,42 @@ public final class BytesToBytesMap {
   }
 
   /**
+   * Free the memory used by longArray.
+   */
+  public void freeArray() {
+    updatePeakMemoryUsed();
+    if (longArray != null) {
+      long used = longArray.memoryBlock().size();
+      longArray = null;
+      releaseMemory(used);
+      bitset = null;
+    }
+  }
+
+  /**
    * Free all allocated memory associated with this map, including the storage for keys and values
    * as well as the hash map array itself.
    *
    * This method is idempotent and can be called multiple times.
    */
   public void free() {
-    updatePeakMemoryUsed();
-    longArray = null;
-    bitset = null;
+    freeArray();
     Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator();
     while (dataPagesIterator.hasNext()) {
       MemoryBlock dataPage = dataPagesIterator.next();
       dataPagesIterator.remove();
-      taskMemoryManager.freePage(dataPage);
+      freePage(dataPage);
     }
     assert(dataPages.isEmpty());
+
+    while (!spillWriters.isEmpty()) {
+      File file = spillWriters.removeFirst().getFile();
+      if (file != null && file.exists()) {
+        if (!file.delete()) {
+          logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
+        }
+      }
+    }
   }
 
   public TaskMemoryManager getTaskMemoryManager() {
@@ -782,7 +865,13 @@ public final class BytesToBytesMap {
     final int oldCapacity = (int) oldBitSet.capacity();
 
     // Allocate the new data structures
-    allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY));
+    try {
+      allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY));
+    } catch (OutOfMemoryError oom) {
+      longArray = oldLongArray;
+      bitset = oldBitSet;
+      throw oom;
+    }
 
     // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it)
     for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) {
@@ -806,6 +895,7 @@ public final class BytesToBytesMap {
         }
       }
     }
+    releaseMemory(oldLongArray.memoryBlock().size());
 
     if (enablePerfMetrics) {
       timeSpentResizingNs += System.nanoTime() - resizeStartTime;
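
A minimal sketch of the cooperative spilling contract this diff relies on, written against hypothetical stand-in types rather than Spark's real classes. `Consumer`, `Manager`, and `MapLike` are names invented for illustration, and the `scanning` flag plays the role of `destructiveIterator != null`. It shows the two behaviours visible in the hunks above: `spill()` returns 0 when the consumer itself is the trigger or when it is not being scanned, and a failed allocation is reported as `false` rather than propagated.

```java
import java.util.ArrayList;
import java.util.List;

public class CooperativeSpillSketch {

  /** Hypothetical stand-in for a memory consumer; not Spark's MemoryConsumer. */
  interface Consumer {
    /** Try to release up to `size` bytes; returns the number of bytes actually freed. */
    long spill(long size, Consumer trigger);
  }

  /** Hypothetical manager with a fixed budget that asks consumers to spill on shortage. */
  static final class Manager {
    private final long capacity;
    private long used = 0;
    private final List<Consumer> consumers = new ArrayList<>();

    Manager(long capacity) { this.capacity = capacity; }

    void register(Consumer c) { consumers.add(c); }

    /** Grants `size` bytes, or throws OutOfMemoryError after asking consumers to spill. */
    void acquire(long size, Consumer requester) {
      for (Consumer c : consumers) {
        if (capacity - used >= size) break;
        // Ask for only the shortfall; the consumer decides whether it can help.
        used -= c.spill(size - (capacity - used), requester);
      }
      if (capacity - used < size) {
        throw new OutOfMemoryError("Unable to acquire " + size + " bytes");
      }
      used += size;
    }

    long used() { return used; }
  }

  /** A map-like consumer that spills only for other consumers, and only while scanning. */
  static final class MapLike implements Consumer {
    private final Manager manager;
    private long held = 0;
    private boolean scanning = false; // assumption: stands in for destructiveIterator != null

    MapLike(Manager manager) {
      this.manager = manager;
      manager.register(this);
    }

    /** Mirrors acquireNewPage(): returns false instead of propagating the error. */
    boolean grow(long size) {
      try {
        manager.acquire(size, this);
        held += size;
        return true;
      } catch (OutOfMemoryError e) {
        return false;
      }
    }

    void startScan() { scanning = true; }

    long held() { return held; }

    @Override
    public long spill(long size, Consumer trigger) {
      if (trigger != this && scanning) {
        long freed = Math.min(size, held);
        held -= freed;
        return freed;
      }
      return 0L; // self-triggered, or still being built: nothing can be released safely
    }
  }

  public static void main(String[] args) {
    Manager manager = new Manager(100);
    MapLike a = new MapLike(manager);
    MapLike b = new MapLike(manager);

    System.out.println(a.grow(80)); // fits in the budget
    System.out.println(b.grow(40)); // a is still being built, so it refuses to spill
    a.startScan();
    System.out.println(b.grow(40)); // a can now give up the shortfall
    System.out.println(a.held() + " " + b.held() + " " + manager.used());
  }
}
```

In this sketch the manager requests only the shortfall from each consumer, which is one possible policy; the real allocator may choose victims and amounts differently.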

