Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/14 21:56:21 UTC

spark git commit: [SPARK-9031] Merge BlockObjectWriter and DiskBlockObjectWriter to remove abstract class

Repository: spark
Updated Branches:
  refs/heads/master 8fb3a65cb -> d267c2834


[SPARK-9031] Merge BlockObjectWriter and DiskBlockObjectWriter to remove abstract class

BlockObjectWriter has only one concrete non-test implementation, DiskBlockObjectWriter. To simplify the code in preparation for other refactorings, I think we should remove this base class and keep only DiskBlockObjectWriter.

While we may at one time have planned to support multiple BlockObjectWriter implementations, that never materialized, so the extra abstraction seems unnecessary.
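
The shape of the change is simple: the abstract base class disappears, its contract moves onto the one concrete implementation, and call sites only change the declared type. Below is a minimal, self-contained Scala sketch of the before/after structure; the signatures are deliberately simplified (a String blockId instead of BlockId, stubbed method bodies, no File/SerializerInstance/metrics parameters), so treat it as an illustration of the refactoring rather than the actual classes, which appear in full in the diff further down.

import java.io.OutputStream

object Before {
  // One abstract base class ...
  abstract class BlockObjectWriter(val blockId: String) extends OutputStream {
    def open(): BlockObjectWriter
    def commitAndClose(): Unit
    def revertPartialWritesAndClose(): Unit
    def write(key: Any, value: Any): Unit
    override def write(b: Int): Unit = throw new UnsupportedOperationException()
  }

  // ... with exactly one non-test subclass.
  class DiskBlockObjectWriter(blockId: String) extends BlockObjectWriter(blockId) {
    def open(): DiskBlockObjectWriter = this     // real version opens file/serializer streams
    def commitAndClose(): Unit = ()              // real version flushes and closes
    def revertPartialWritesAndClose(): Unit = () // real version truncates back to the start offset
    def write(key: Any, value: Any): Unit = ()   // real version serializes the key-value pair
  }
}

object After {
  // The base class is gone; the concrete writer extends OutputStream directly and keeps
  // the same public surface, so callers only swap the declared type to DiskBlockObjectWriter.
  class DiskBlockObjectWriter(val blockId: String) extends OutputStream {
    def open(): DiskBlockObjectWriter = this
    def commitAndClose(): Unit = ()
    def revertPartialWritesAndClose(): Unit = ()
    def write(key: Any, value: Any): Unit = ()
    override def write(b: Int): Unit = throw new UnsupportedOperationException()
  }
}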

Author: Josh Rosen <jo...@databricks.com>

Closes #7391 from JoshRosen/shuffle-write-interface-refactoring and squashes the following commits:

c418e33 [Josh Rosen] Fix compilation
5047995 [Josh Rosen] Fix comments
d5dc548 [Josh Rosen] Update references in comments
89dc797 [Josh Rosen] Rename test suite.
5755918 [Josh Rosen] Remove unnecessary val in case class
1607c91 [Josh Rosen] Merge BlockObjectWriter and DiskBlockObjectWriter


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d267c283
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d267c283
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d267c283

Branch: refs/heads/master
Commit: d267c2834a639aaebd0559355c6a82613abb689b
Parents: 8fb3a65
Author: Josh Rosen <jo...@databricks.com>
Authored: Tue Jul 14 12:56:17 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Tue Jul 14 12:56:17 2015 -0700

----------------------------------------------------------------------
 .../sort/BypassMergeSortShuffleWriter.java      |   8 +-
 .../unsafe/UnsafeShuffleExternalSorter.java     |   2 +-
 .../unsafe/sort/UnsafeSorterSpillWriter.java    |   4 +-
 .../shuffle/FileShuffleBlockResolver.scala      |   8 +-
 .../shuffle/IndexShuffleBlockResolver.scala     |   2 +-
 .../spark/shuffle/hash/HashShuffleWriter.scala  |   4 +-
 .../org/apache/spark/storage/BlockManager.scala |   2 +-
 .../spark/storage/BlockObjectWriter.scala       | 256 -------------------
 .../spark/storage/DiskBlockObjectWriter.scala   | 234 +++++++++++++++++
 .../spark/util/collection/ChainedBuffer.scala   |   2 +-
 .../spark/util/collection/ExternalSorter.scala  |   4 +-
 .../util/collection/PartitionedPairBuffer.scala |   1 -
 .../PartitionedSerializedPairBuffer.scala       |   5 +-
 .../WritablePartitionedPairCollection.scala     |   8 +-
 .../BypassMergeSortShuffleWriterSuite.scala     |   4 +-
 .../spark/storage/BlockObjectWriterSuite.scala  | 173 -------------
 .../storage/DiskBlockObjectWriterSuite.scala    | 173 +++++++++++++
 .../PartitionedSerializedPairBufferSuite.scala  |  52 ++--
 18 files changed, 459 insertions(+), 483 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index d3d6280..0b8b604 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -75,7 +75,7 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
   private final Serializer serializer;
 
   /** Array of file writers, one for each partition */
-  private BlockObjectWriter[] partitionWriters;
+  private DiskBlockObjectWriter[] partitionWriters;
 
   public BypassMergeSortShuffleWriter(
       SparkConf conf,
@@ -101,7 +101,7 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
     }
     final SerializerInstance serInstance = serializer.newInstance();
     final long openStartTime = System.nanoTime();
-    partitionWriters = new BlockObjectWriter[numPartitions];
+    partitionWriters = new DiskBlockObjectWriter[numPartitions];
     for (int i = 0; i < numPartitions; i++) {
       final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
         blockManager.diskBlockManager().createTempShuffleBlock();
@@ -121,7 +121,7 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
       partitionWriters[partitioner.getPartition(key)].write(key, record._2());
     }
 
-    for (BlockObjectWriter writer : partitionWriters) {
+    for (DiskBlockObjectWriter writer : partitionWriters) {
       writer.commitAndClose();
     }
   }
@@ -169,7 +169,7 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
     if (partitionWriters != null) {
       try {
         final DiskBlockManager diskBlockManager = blockManager.diskBlockManager();
-        for (BlockObjectWriter writer : partitionWriters) {
+        for (DiskBlockObjectWriter writer : partitionWriters) {
           // This method explicitly does _not_ throw exceptions:
           writer.revertPartialWritesAndClose();
           if (!diskBlockManager.getFile(writer.blockId()).delete()) {

http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index 5628957..1d46043 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -157,7 +157,7 @@ final class UnsafeShuffleExternalSorter {
 
     // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
     // after SPARK-5581 is fixed.
-    BlockObjectWriter writer;
+    DiskBlockObjectWriter writer;
 
     // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
     // be an API to directly transfer bytes from managed memory to the disk writer, we buffer

http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
index b8d6665..71eed29 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -26,7 +26,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.serializer.DummySerializerInstance;
 import org.apache.spark.storage.BlockId;
 import org.apache.spark.storage.BlockManager;
-import org.apache.spark.storage.BlockObjectWriter;
+import org.apache.spark.storage.DiskBlockObjectWriter;
 import org.apache.spark.storage.TempLocalBlockId;
 import org.apache.spark.unsafe.PlatformDependent;
 
@@ -47,7 +47,7 @@ final class UnsafeSorterSpillWriter {
   private final File file;
   private final BlockId blockId;
   private final int numRecordsToWrite;
-  private BlockObjectWriter writer;
+  private DiskBlockObjectWriter writer;
   private int numRecordsSpilled = 0;
 
   public UnsafeSorterSpillWriter(

http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
index 6c3b308..f6a96d8 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVecto
 
 /** A group of writers for a ShuffleMapTask, one writer per reducer. */
 private[spark] trait ShuffleWriterGroup {
-  val writers: Array[BlockObjectWriter]
+  val writers: Array[DiskBlockObjectWriter]
 
   /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */
   def releaseWriters(success: Boolean)
@@ -113,15 +113,15 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
 
       val openStartTime = System.nanoTime
       val serializerInstance = serializer.newInstance()
-      val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
+      val writers: Array[DiskBlockObjectWriter] = if (consolidateShuffleFiles) {
         fileGroup = getUnusedFileGroup()
-        Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+        Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId =>
           val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
           blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize,
             writeMetrics)
         }
       } else {
-        Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+        Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId =>
           val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
           val blockFile = blockManager.diskBlockManager.getFile(blockId)
           // Because of previous failures, the shuffle file may already exist on this machine.

http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index d9c63b6..fae6955 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -114,7 +114,7 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
 }
 
 private[spark] object IndexShuffleBlockResolver {
-  // No-op reduce ID used in interactions with disk store and BlockObjectWriter.
+  // No-op reduce ID used in interactions with disk store and DiskBlockObjectWriter.
   // The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort
   // shuffle outputs for several reduces are glommed into a single file.
   // TODO: Avoid this entirely by having the DiskBlockObjectWriter not require a BlockId.

http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index eb87cee..41df70c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -22,7 +22,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle._
-import org.apache.spark.storage.BlockObjectWriter
+import org.apache.spark.storage.DiskBlockObjectWriter
 
 private[spark] class HashShuffleWriter[K, V](
     shuffleBlockResolver: FileShuffleBlockResolver,
@@ -102,7 +102,7 @@ private[spark] class HashShuffleWriter[K, V](
 
   private def commitWritesAndBuildStatus(): MapStatus = {
     // Commit the writes. Get the size of each bucket block (total block size).
-    val sizes: Array[Long] = shuffle.writers.map { writer: BlockObjectWriter =>
+    val sizes: Array[Long] = shuffle.writers.map { writer: DiskBlockObjectWriter =>
       writer.commitAndClose()
       writer.fileSegment().length
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 1beafa1..8649367 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -648,7 +648,7 @@ private[spark] class BlockManager(
       file: File,
       serializerInstance: SerializerInstance,
       bufferSize: Int,
-      writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = {
+      writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
     val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
     val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
     new DiskBlockObjectWriter(blockId, file, serializerInstance, bufferSize, compressStream,

http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
deleted file mode 100644
index 7eeabd1..0000000
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ /dev/null
@@ -1,256 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream}
-import java.nio.channels.FileChannel
-
-import org.apache.spark.Logging
-import org.apache.spark.serializer.{SerializerInstance, SerializationStream}
-import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.util.Utils
-
-/**
- * An interface for writing JVM objects to some underlying storage. This interface allows
- * appending data to an existing block, and can guarantee atomicity in the case of faults
- * as it allows the caller to revert partial writes.
- *
- * This interface does not support concurrent writes. Also, once the writer has
- * been opened, it cannot be reopened again.
- */
-private[spark] abstract class BlockObjectWriter(val blockId: BlockId) extends OutputStream {
-
-  def open(): BlockObjectWriter
-
-  def close()
-
-  def isOpen: Boolean
-
-  /**
-   * Flush the partial writes and commit them as a single atomic block.
-   */
-  def commitAndClose(): Unit
-
-  /**
-   * Reverts writes that haven't been flushed yet. Callers should invoke this function
-   * when there are runtime exceptions. This method will not throw, though it may be
-   * unsuccessful in truncating written data.
-   */
-  def revertPartialWritesAndClose()
-
-  /**
-   * Writes a key-value pair.
-   */
-  def write(key: Any, value: Any)
-
-  /**
-   * Notify the writer that a record worth of bytes has been written with OutputStream#write.
-   */
-  def recordWritten()
-
-  /**
-   * Returns the file segment of committed data that this Writer has written.
-   * This is only valid after commitAndClose() has been called.
-   */
-  def fileSegment(): FileSegment
-}
-
-/**
- * BlockObjectWriter which writes directly to a file on disk. Appends to the given file.
- */
-private[spark] class DiskBlockObjectWriter(
-    blockId: BlockId,
-    file: File,
-    serializerInstance: SerializerInstance,
-    bufferSize: Int,
-    compressStream: OutputStream => OutputStream,
-    syncWrites: Boolean,
-    // These write metrics concurrently shared with other active BlockObjectWriter's who
-    // are themselves performing writes. All updates must be relative.
-    writeMetrics: ShuffleWriteMetrics)
-  extends BlockObjectWriter(blockId)
-  with Logging
-{
-
-  /** The file channel, used for repositioning / truncating the file. */
-  private var channel: FileChannel = null
-  private var bs: OutputStream = null
-  private var fos: FileOutputStream = null
-  private var ts: TimeTrackingOutputStream = null
-  private var objOut: SerializationStream = null
-  private var initialized = false
-  private var hasBeenClosed = false
-  private var commitAndCloseHasBeenCalled = false
-
-  /**
-   * Cursors used to represent positions in the file.
-   *
-   * xxxxxxxx|--------|---       |
-   *         ^        ^          ^
-   *         |        |        finalPosition
-   *         |      reportedPosition
-   *       initialPosition
-   *
-   * initialPosition: Offset in the file where we start writing. Immutable.
-   * reportedPosition: Position at the time of the last update to the write metrics.
-   * finalPosition: Offset where we stopped writing. Set on closeAndCommit() then never changed.
-   * -----: Current writes to the underlying file.
-   * xxxxx: Existing contents of the file.
-   */
-  private val initialPosition = file.length()
-  private var finalPosition: Long = -1
-  private var reportedPosition = initialPosition
-
-  /**
-   * Keep track of number of records written and also use this to periodically
-   * output bytes written since the latter is expensive to do for each record.
-   */
-  private var numRecordsWritten = 0
-
-  override def open(): BlockObjectWriter = {
-    if (hasBeenClosed) {
-      throw new IllegalStateException("Writer already closed. Cannot be reopened.")
-    }
-    fos = new FileOutputStream(file, true)
-    ts = new TimeTrackingOutputStream(writeMetrics, fos)
-    channel = fos.getChannel()
-    bs = compressStream(new BufferedOutputStream(ts, bufferSize))
-    objOut = serializerInstance.serializeStream(bs)
-    initialized = true
-    this
-  }
-
-  override def close() {
-    if (initialized) {
-      Utils.tryWithSafeFinally {
-        if (syncWrites) {
-          // Force outstanding writes to disk and track how long it takes
-          objOut.flush()
-          val start = System.nanoTime()
-          fos.getFD.sync()
-          writeMetrics.incShuffleWriteTime(System.nanoTime() - start)
-        }
-      } {
-        objOut.close()
-      }
-
-      channel = null
-      bs = null
-      fos = null
-      ts = null
-      objOut = null
-      initialized = false
-      hasBeenClosed = true
-    }
-  }
-
-  override def isOpen: Boolean = objOut != null
-
-  override def commitAndClose(): Unit = {
-    if (initialized) {
-      // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the
-      //       serializer stream and the lower level stream.
-      objOut.flush()
-      bs.flush()
-      close()
-      finalPosition = file.length()
-      // In certain compression codecs, more bytes are written after close() is called
-      writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition)
-    } else {
-      finalPosition = file.length()
-    }
-    commitAndCloseHasBeenCalled = true
-  }
-
-  // Discard current writes. We do this by flushing the outstanding writes and then
-  // truncating the file to its initial position.
-  override def revertPartialWritesAndClose() {
-    try {
-      if (initialized) {
-        writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition)
-        writeMetrics.decShuffleRecordsWritten(numRecordsWritten)
-        objOut.flush()
-        bs.flush()
-        close()
-      }
-
-      val truncateStream = new FileOutputStream(file, true)
-      try {
-        truncateStream.getChannel.truncate(initialPosition)
-      } finally {
-        truncateStream.close()
-      }
-    } catch {
-      case e: Exception =>
-        logError("Uncaught exception while reverting partial writes to file " + file, e)
-    }
-  }
-
-  override def write(key: Any, value: Any) {
-    if (!initialized) {
-      open()
-    }
-
-    objOut.writeKey(key)
-    objOut.writeValue(value)
-    recordWritten()
-  }
-
-  override def write(b: Int): Unit = throw new UnsupportedOperationException()
-
-  override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = {
-    if (!initialized) {
-      open()
-    }
-
-    bs.write(kvBytes, offs, len)
-  }
-
-  override def recordWritten(): Unit = {
-    numRecordsWritten += 1
-    writeMetrics.incShuffleRecordsWritten(1)
-
-    if (numRecordsWritten % 32 == 0) {
-      updateBytesWritten()
-    }
-  }
-
-  override def fileSegment(): FileSegment = {
-    if (!commitAndCloseHasBeenCalled) {
-      throw new IllegalStateException(
-        "fileSegment() is only valid after commitAndClose() has been called")
-    }
-    new FileSegment(file, initialPosition, finalPosition - initialPosition)
-  }
-
-  /**
-   * Report the number of bytes written in this writer's shuffle write metrics.
-   * Note that this is only valid before the underlying streams are closed.
-   */
-  private def updateBytesWritten() {
-    val pos = channel.position()
-    writeMetrics.incShuffleBytesWritten(pos - reportedPosition)
-    reportedPosition = pos
-  }
-
-  // For testing
-  private[spark] override def flush() {
-    objOut.flush()
-    bs.flush()
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
new file mode 100644
index 0000000..49d9154
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream}
+import java.nio.channels.FileChannel
+
+import org.apache.spark.Logging
+import org.apache.spark.serializer.{SerializerInstance, SerializationStream}
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.util.Utils
+
+/**
+ * A class for writing JVM objects directly to a file on disk. This class allows data to be appended
+ * to an existing block and can guarantee atomicity in the case of faults as it allows the caller to
+ * revert partial writes.
+ *
+ * This class does not support concurrent writes. Also, once the writer has been opened it cannot be
+ * reopened again.
+ */
+private[spark] class DiskBlockObjectWriter(
+    val blockId: BlockId,
+    file: File,
+    serializerInstance: SerializerInstance,
+    bufferSize: Int,
+    compressStream: OutputStream => OutputStream,
+    syncWrites: Boolean,
+    // These write metrics concurrently shared with other active DiskBlockObjectWriters who
+    // are themselves performing writes. All updates must be relative.
+    writeMetrics: ShuffleWriteMetrics)
+  extends OutputStream
+  with Logging {
+
+  /** The file channel, used for repositioning / truncating the file. */
+  private var channel: FileChannel = null
+  private var bs: OutputStream = null
+  private var fos: FileOutputStream = null
+  private var ts: TimeTrackingOutputStream = null
+  private var objOut: SerializationStream = null
+  private var initialized = false
+  private var hasBeenClosed = false
+  private var commitAndCloseHasBeenCalled = false
+
+  /**
+   * Cursors used to represent positions in the file.
+   *
+   * xxxxxxxx|--------|---       |
+   *         ^        ^          ^
+   *         |        |        finalPosition
+   *         |      reportedPosition
+   *       initialPosition
+   *
+   * initialPosition: Offset in the file where we start writing. Immutable.
+   * reportedPosition: Position at the time of the last update to the write metrics.
+   * finalPosition: Offset where we stopped writing. Set on closeAndCommit() then never changed.
+   * -----: Current writes to the underlying file.
+   * xxxxx: Existing contents of the file.
+   */
+  private val initialPosition = file.length()
+  private var finalPosition: Long = -1
+  private var reportedPosition = initialPosition
+
+  /**
+   * Keep track of number of records written and also use this to periodically
+   * output bytes written since the latter is expensive to do for each record.
+   */
+  private var numRecordsWritten = 0
+
+  def open(): DiskBlockObjectWriter = {
+    if (hasBeenClosed) {
+      throw new IllegalStateException("Writer already closed. Cannot be reopened.")
+    }
+    fos = new FileOutputStream(file, true)
+    ts = new TimeTrackingOutputStream(writeMetrics, fos)
+    channel = fos.getChannel()
+    bs = compressStream(new BufferedOutputStream(ts, bufferSize))
+    objOut = serializerInstance.serializeStream(bs)
+    initialized = true
+    this
+  }
+
+  override def close() {
+    if (initialized) {
+      Utils.tryWithSafeFinally {
+        if (syncWrites) {
+          // Force outstanding writes to disk and track how long it takes
+          objOut.flush()
+          val start = System.nanoTime()
+          fos.getFD.sync()
+          writeMetrics.incShuffleWriteTime(System.nanoTime() - start)
+        }
+      } {
+        objOut.close()
+      }
+
+      channel = null
+      bs = null
+      fos = null
+      ts = null
+      objOut = null
+      initialized = false
+      hasBeenClosed = true
+    }
+  }
+
+  def isOpen: Boolean = objOut != null
+
+  /**
+   * Flush the partial writes and commit them as a single atomic block.
+   */
+  def commitAndClose(): Unit = {
+    if (initialized) {
+      // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the
+      //       serializer stream and the lower level stream.
+      objOut.flush()
+      bs.flush()
+      close()
+      finalPosition = file.length()
+      // In certain compression codecs, more bytes are written after close() is called
+      writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition)
+    } else {
+      finalPosition = file.length()
+    }
+    commitAndCloseHasBeenCalled = true
+  }
+
+
+  /**
+   * Reverts writes that haven't been flushed yet. Callers should invoke this function
+   * when there are runtime exceptions. This method will not throw, though it may be
+   * unsuccessful in truncating written data.
+   */
+  def revertPartialWritesAndClose() {
+    // Discard current writes. We do this by flushing the outstanding writes and then
+    // truncating the file to its initial position.
+    try {
+      if (initialized) {
+        writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition)
+        writeMetrics.decShuffleRecordsWritten(numRecordsWritten)
+        objOut.flush()
+        bs.flush()
+        close()
+      }
+
+      val truncateStream = new FileOutputStream(file, true)
+      try {
+        truncateStream.getChannel.truncate(initialPosition)
+      } finally {
+        truncateStream.close()
+      }
+    } catch {
+      case e: Exception =>
+        logError("Uncaught exception while reverting partial writes to file " + file, e)
+    }
+  }
+
+  /**
+   * Writes a key-value pair.
+   */
+  def write(key: Any, value: Any) {
+    if (!initialized) {
+      open()
+    }
+
+    objOut.writeKey(key)
+    objOut.writeValue(value)
+    recordWritten()
+  }
+
+  override def write(b: Int): Unit = throw new UnsupportedOperationException()
+
+  override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = {
+    if (!initialized) {
+      open()
+    }
+
+    bs.write(kvBytes, offs, len)
+  }
+
+  /**
+   * Notify the writer that a record worth of bytes has been written with OutputStream#write.
+   */
+  def recordWritten(): Unit = {
+    numRecordsWritten += 1
+    writeMetrics.incShuffleRecordsWritten(1)
+
+    if (numRecordsWritten % 32 == 0) {
+      updateBytesWritten()
+    }
+  }
+
+  /**
+   * Returns the file segment of committed data that this Writer has written.
+   * This is only valid after commitAndClose() has been called.
+   */
+  def fileSegment(): FileSegment = {
+    if (!commitAndCloseHasBeenCalled) {
+      throw new IllegalStateException(
+        "fileSegment() is only valid after commitAndClose() has been called")
+    }
+    new FileSegment(file, initialPosition, finalPosition - initialPosition)
+  }
+
+  /**
+   * Report the number of bytes written in this writer's shuffle write metrics.
+   * Note that this is only valid before the underlying streams are closed.
+   */
+  private def updateBytesWritten() {
+    val pos = channel.position()
+    writeMetrics.incShuffleBytesWritten(pos - reportedPosition)
+    reportedPosition = pos
+  }
+
+  // For testing
+  private[spark] override def flush() {
+    objOut.flush()
+    bs.flush()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
index 516aaa4..ae60f3b 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
@@ -37,7 +37,7 @@ private[spark] class ChainedBuffer(chunkSize: Int) {
   private var _size: Long = 0
 
   /**
-   * Feed bytes from this buffer into a BlockObjectWriter.
+   * Feed bytes from this buffer into a DiskBlockObjectWriter.
    *
    * @param pos Offset in the buffer to read from.
    * @param os OutputStream to read into.

http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 757dec6..ba7ec83 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -30,7 +30,7 @@ import org.apache.spark._
 import org.apache.spark.serializer._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter}
-import org.apache.spark.storage.{BlockId, BlockObjectWriter}
+import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
 
 /**
  * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -250,7 +250,7 @@ private[spark] class ExternalSorter[K, V, C](
     // These variables are reset after each flush
     var objectsWritten: Long = 0
     var spillMetrics: ShuffleWriteMetrics = null
-    var writer: BlockObjectWriter = null
+    var writer: DiskBlockObjectWriter = null
     def openWriter(): Unit = {
       assert (writer == null && spillMetrics == null)
       spillMetrics = new ShuffleWriteMetrics

http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
index 04bb7fc..f5844d5 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
@@ -19,7 +19,6 @@ package org.apache.spark.util.collection
 
 import java.util.Comparator
 
-import org.apache.spark.storage.BlockObjectWriter
 import org.apache.spark.util.collection.WritablePartitionedPairCollection._
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
index ae9a487..87a786b 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
@@ -21,9 +21,8 @@ import java.io.InputStream
 import java.nio.IntBuffer
 import java.util.Comparator
 
-import org.apache.spark.SparkEnv
 import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance}
-import org.apache.spark.storage.BlockObjectWriter
+import org.apache.spark.storage.DiskBlockObjectWriter
 import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._
 
 /**
@@ -136,7 +135,7 @@ private[spark] class PartitionedSerializedPairBuffer[K, V](
       // current position in the meta buffer in ints
       var pos = 0
 
-      def writeNext(writer: BlockObjectWriter): Unit = {
+      def writeNext(writer: DiskBlockObjectWriter): Unit = {
         val keyStart = getKeyStartPos(metaBuffer, pos)
         val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN)
         pos += RECORD_SIZE

http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
index 7bc5989..38848e9 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
@@ -19,7 +19,7 @@ package org.apache.spark.util.collection
 
 import java.util.Comparator
 
-import org.apache.spark.storage.BlockObjectWriter
+import org.apache.spark.storage.DiskBlockObjectWriter
 
 /**
  * A common interface for size-tracking collections of key-value pairs that
@@ -51,7 +51,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] {
     new WritablePartitionedIterator {
       private[this] var cur = if (it.hasNext) it.next() else null
 
-      def writeNext(writer: BlockObjectWriter): Unit = {
+      def writeNext(writer: DiskBlockObjectWriter): Unit = {
         writer.write(cur._1._2, cur._2)
         cur = if (it.hasNext) it.next() else null
       }
@@ -91,11 +91,11 @@ private[spark] object WritablePartitionedPairCollection {
 }
 
 /**
- * Iterator that writes elements to a BlockObjectWriter instead of returning them. Each element
+ * Iterator that writes elements to a DiskBlockObjectWriter instead of returning them. Each element
  * has an associated partition.
  */
 private[spark] trait WritablePartitionedIterator {
-  def writeNext(writer: BlockObjectWriter): Unit
+  def writeNext(writer: DiskBlockObjectWriter): Unit
 
   def hasNext(): Boolean
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index 542f8f4..cc7342f 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -68,8 +68,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
       any[SerializerInstance],
       anyInt(),
       any[ShuffleWriteMetrics]
-    )).thenAnswer(new Answer[BlockObjectWriter] {
-      override def answer(invocation: InvocationOnMock): BlockObjectWriter = {
+    )).thenAnswer(new Answer[DiskBlockObjectWriter] {
+      override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = {
         val args = invocation.getArguments
         new DiskBlockObjectWriter(
           args(0).asInstanceOf[BlockId],

http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
deleted file mode 100644
index 7bdea72..0000000
--- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
+++ /dev/null
@@ -1,173 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.storage
-
-import java.io.File
-
-import org.scalatest.BeforeAndAfterEach
-
-import org.apache.spark.SparkConf
-import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.serializer.JavaSerializer
-import org.apache.spark.util.Utils
-
-class BlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
-
-  var tempDir: File = _
-
-  override def beforeEach(): Unit = {
-    tempDir = Utils.createTempDir()
-  }
-
-  override def afterEach(): Unit = {
-    Utils.deleteRecursively(tempDir)
-  }
-
-  test("verify write metrics") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
-      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
-
-    writer.write(Long.box(20), Long.box(30))
-    // Record metrics update on every write
-    assert(writeMetrics.shuffleRecordsWritten === 1)
-    // Metrics don't update on every write
-    assert(writeMetrics.shuffleBytesWritten == 0)
-    // After 32 writes, metrics should update
-    for (i <- 0 until 32) {
-      writer.flush()
-      writer.write(Long.box(i), Long.box(i))
-    }
-    assert(writeMetrics.shuffleBytesWritten > 0)
-    assert(writeMetrics.shuffleRecordsWritten === 33)
-    writer.commitAndClose()
-    assert(file.length() == writeMetrics.shuffleBytesWritten)
-  }
-
-  test("verify write metrics on revert") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
-      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
-
-    writer.write(Long.box(20), Long.box(30))
-    // Record metrics update on every write
-    assert(writeMetrics.shuffleRecordsWritten === 1)
-    // Metrics don't update on every write
-    assert(writeMetrics.shuffleBytesWritten == 0)
-    // After 32 writes, metrics should update
-    for (i <- 0 until 32) {
-      writer.flush()
-      writer.write(Long.box(i), Long.box(i))
-    }
-    assert(writeMetrics.shuffleBytesWritten > 0)
-    assert(writeMetrics.shuffleRecordsWritten === 33)
-    writer.revertPartialWritesAndClose()
-    assert(writeMetrics.shuffleBytesWritten == 0)
-    assert(writeMetrics.shuffleRecordsWritten == 0)
-  }
-
-  test("Reopening a closed block writer") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
-      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
-
-    writer.open()
-    writer.close()
-    intercept[IllegalStateException] {
-      writer.open()
-    }
-  }
-
-  test("calling revertPartialWritesAndClose() on a closed block writer should have no effect") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
-      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
-    for (i <- 1 to 1000) {
-      writer.write(i, i)
-    }
-    writer.commitAndClose()
-    val bytesWritten = writeMetrics.shuffleBytesWritten
-    assert(writeMetrics.shuffleRecordsWritten === 1000)
-    writer.revertPartialWritesAndClose()
-    assert(writeMetrics.shuffleRecordsWritten === 1000)
-    assert(writeMetrics.shuffleBytesWritten === bytesWritten)
-  }
-
-  test("commitAndClose() should be idempotent") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
-      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
-    for (i <- 1 to 1000) {
-      writer.write(i, i)
-    }
-    writer.commitAndClose()
-    val bytesWritten = writeMetrics.shuffleBytesWritten
-    val writeTime = writeMetrics.shuffleWriteTime
-    assert(writeMetrics.shuffleRecordsWritten === 1000)
-    writer.commitAndClose()
-    assert(writeMetrics.shuffleRecordsWritten === 1000)
-    assert(writeMetrics.shuffleBytesWritten === bytesWritten)
-    assert(writeMetrics.shuffleWriteTime === writeTime)
-  }
-
-  test("revertPartialWritesAndClose() should be idempotent") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
-      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
-    for (i <- 1 to 1000) {
-      writer.write(i, i)
-    }
-    writer.revertPartialWritesAndClose()
-    val bytesWritten = writeMetrics.shuffleBytesWritten
-    val writeTime = writeMetrics.shuffleWriteTime
-    assert(writeMetrics.shuffleRecordsWritten === 0)
-    writer.revertPartialWritesAndClose()
-    assert(writeMetrics.shuffleRecordsWritten === 0)
-    assert(writeMetrics.shuffleBytesWritten === bytesWritten)
-    assert(writeMetrics.shuffleWriteTime === writeTime)
-  }
-
-  test("fileSegment() can only be called after commitAndClose() has been called") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
-      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
-    for (i <- 1 to 1000) {
-      writer.write(i, i)
-    }
-    intercept[IllegalStateException] {
-      writer.fileSegment()
-    }
-    writer.close()
-  }
-
-  test("commitAndClose() without ever opening or writing") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
-      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
-    writer.commitAndClose()
-    assert(writer.fileSegment().length === 0)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
new file mode 100644
index 0000000..66af6e1
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.storage
+
+import java.io.File
+
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.util.Utils
+
+class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
+
+  var tempDir: File = _
+
+  override def beforeEach(): Unit = {
+    tempDir = Utils.createTempDir()
+  }
+
+  override def afterEach(): Unit = {
+    Utils.deleteRecursively(tempDir)
+  }
+
+  test("verify write metrics") {
+    val file = new File(tempDir, "somefile")
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+
+    writer.write(Long.box(20), Long.box(30))
+    // Record metrics update on every write
+    assert(writeMetrics.shuffleRecordsWritten === 1)
+    // Metrics don't update on every write
+    assert(writeMetrics.shuffleBytesWritten == 0)
+    // After 32 writes, metrics should update
+    for (i <- 0 until 32) {
+      writer.flush()
+      writer.write(Long.box(i), Long.box(i))
+    }
+    assert(writeMetrics.shuffleBytesWritten > 0)
+    assert(writeMetrics.shuffleRecordsWritten === 33)
+    writer.commitAndClose()
+    assert(file.length() == writeMetrics.shuffleBytesWritten)
+  }
+
+  test("verify write metrics on revert") {
+    val file = new File(tempDir, "somefile")
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+
+    writer.write(Long.box(20), Long.box(30))
+    // Record metrics update on every write
+    assert(writeMetrics.shuffleRecordsWritten === 1)
+    // Metrics don't update on every write
+    assert(writeMetrics.shuffleBytesWritten == 0)
+    // After 32 writes, metrics should update
+    for (i <- 0 until 32) {
+      writer.flush()
+      writer.write(Long.box(i), Long.box(i))
+    }
+    assert(writeMetrics.shuffleBytesWritten > 0)
+    assert(writeMetrics.shuffleRecordsWritten === 33)
+    writer.revertPartialWritesAndClose()
+    assert(writeMetrics.shuffleBytesWritten == 0)
+    assert(writeMetrics.shuffleRecordsWritten == 0)
+  }
+
+  test("Reopening a closed block writer") {
+    val file = new File(tempDir, "somefile")
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+
+    writer.open()
+    writer.close()
+    intercept[IllegalStateException] {
+      writer.open()
+    }
+  }
+
+  test("calling revertPartialWritesAndClose() on a closed block writer should have no effect") {
+    val file = new File(tempDir, "somefile")
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    for (i <- 1 to 1000) {
+      writer.write(i, i)
+    }
+    writer.commitAndClose()
+    val bytesWritten = writeMetrics.shuffleBytesWritten
+    assert(writeMetrics.shuffleRecordsWritten === 1000)
+    writer.revertPartialWritesAndClose()
+    assert(writeMetrics.shuffleRecordsWritten === 1000)
+    assert(writeMetrics.shuffleBytesWritten === bytesWritten)
+  }
+
+  test("commitAndClose() should be idempotent") {
+    val file = new File(tempDir, "somefile")
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    for (i <- 1 to 1000) {
+      writer.write(i, i)
+    }
+    writer.commitAndClose()
+    val bytesWritten = writeMetrics.shuffleBytesWritten
+    val writeTime = writeMetrics.shuffleWriteTime
+    assert(writeMetrics.shuffleRecordsWritten === 1000)
+    writer.commitAndClose()
+    assert(writeMetrics.shuffleRecordsWritten === 1000)
+    assert(writeMetrics.shuffleBytesWritten === bytesWritten)
+    assert(writeMetrics.shuffleWriteTime === writeTime)
+  }
+
+  test("revertPartialWritesAndClose() should be idempotent") {
+    val file = new File(tempDir, "somefile")
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    for (i <- 1 to 1000) {
+      writer.write(i, i)
+    }
+    writer.revertPartialWritesAndClose()
+    val bytesWritten = writeMetrics.shuffleBytesWritten
+    val writeTime = writeMetrics.shuffleWriteTime
+    assert(writeMetrics.shuffleRecordsWritten === 0)
+    writer.revertPartialWritesAndClose()
+    assert(writeMetrics.shuffleRecordsWritten === 0)
+    assert(writeMetrics.shuffleBytesWritten === bytesWritten)
+    assert(writeMetrics.shuffleWriteTime === writeTime)
+  }
+
+  test("fileSegment() can only be called after commitAndClose() has been called") {
+    val file = new File(tempDir, "somefile")
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    for (i <- 1 to 1000) {
+      writer.write(i, i)
+    }
+    intercept[IllegalStateException] {
+      writer.fileSegment()
+    }
+    writer.close()
+  }
+
+  test("commitAndClose() without ever opening or writing") {
+    val file = new File(tempDir, "somefile")
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    writer.commitAndClose()
+    assert(writer.fileSegment().length === 0)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
index 6d2459d..3b67f62 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
@@ -17,15 +17,20 @@
 
 package org.apache.spark.util.collection
 
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 
 import com.google.common.io.ByteStreams
 
+import org.mockito.Matchers.any
+import org.mockito.Mockito._
+import org.mockito.Mockito.RETURNS_SMART_NULLS
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
 import org.scalatest.Matchers._
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.serializer.KryoSerializer
-import org.apache.spark.storage.{FileSegment, BlockObjectWriter}
+import org.apache.spark.storage.DiskBlockObjectWriter
 
 class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
   test("OrderedInputStream single record") {
@@ -79,13 +84,13 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
     val struct = SomeStruct("something", 5)
     buffer.insert(4, 10, struct)
     val it = buffer.destructiveSortedWritablePartitionedIterator(None)
-    val writer = new SimpleBlockObjectWriter
+    val (writer, baos) = createMockWriter()
     assert(it.hasNext)
     it.nextPartition should be (4)
     it.writeNext(writer)
     assert(!it.hasNext)
 
-    val stream = serializerInstance.deserializeStream(writer.getInputStream)
+    val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
     stream.readObject[AnyRef]() should be (10)
     stream.readObject[AnyRef]() should be (struct)
   }
@@ -101,7 +106,7 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
     buffer.insert(5, 3, struct3)
 
     val it = buffer.destructiveSortedWritablePartitionedIterator(None)
-    val writer = new SimpleBlockObjectWriter
+    val (writer, baos) = createMockWriter()
     assert(it.hasNext)
     it.nextPartition should be (4)
     it.writeNext(writer)
@@ -113,7 +118,7 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
     it.writeNext(writer)
     assert(!it.hasNext)
 
-    val stream = serializerInstance.deserializeStream(writer.getInputStream)
+    val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
     val iter = stream.asIterator
     iter.next() should be (2)
     iter.next() should be (struct2)
@@ -123,26 +128,21 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
     iter.next() should be (struct1)
     assert(!iter.hasNext)
   }
-}
-
-case class SomeStruct(val str: String, val num: Int)
-
-class SimpleBlockObjectWriter extends BlockObjectWriter(null) {
-  val baos = new ByteArrayOutputStream()
 
-  override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = {
-    baos.write(bytes, offs, len)
+  def createMockWriter(): (DiskBlockObjectWriter, ByteArrayOutputStream) = {
+    val writer = mock(classOf[DiskBlockObjectWriter], RETURNS_SMART_NULLS)
+    val baos = new ByteArrayOutputStream()
+    when(writer.write(any(), any(), any())).thenAnswer(new Answer[Unit] {
+      override def answer(invocationOnMock: InvocationOnMock): Unit = {
+        val args = invocationOnMock.getArguments
+        val bytes = args(0).asInstanceOf[Array[Byte]]
+        val offset = args(1).asInstanceOf[Int]
+        val length = args(2).asInstanceOf[Int]
+        baos.write(bytes, offset, length)
+      }
+    })
+    (writer, baos)
   }
-
-  def getInputStream(): InputStream = new ByteArrayInputStream(baos.toByteArray)
-
-  override def open(): BlockObjectWriter = this
-  override def close(): Unit = { }
-  override def isOpen: Boolean = true
-  override def commitAndClose(): Unit = { }
-  override def revertPartialWritesAndClose(): Unit = { }
-  override def fileSegment(): FileSegment = null
-  override def write(key: Any, value: Any): Unit = { }
-  override def recordWritten(): Unit = { }
-  override def write(b: Int): Unit = { }
 }
+
+case class SomeStruct(str: String, num: Int)

