You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by pw...@apache.org on 2015/05/01 08:12:04 UTC
[2/2] spark git commit: [SPARK-4550] In sort-based shuffle,
store map outputs in serialized form
[SPARK-4550] In sort-based shuffle, store map outputs in serialized form
Refer to the JIRA for the design doc and some perf results.
I wanted to call out some of the potentially more controversial changes up front:
* Map outputs are only stored in serialized form when Kryo is in use. I'm still unsure whether Java-serialized objects can be relocated. At the very least, Java serialization writes out a stream header which causes problems with the current approach, so I decided to leave investigating this to future work.
* The shuffle now explicitly operates on key-value pairs instead of arbitrary objects. Data is written to shuffle files as alternating keys and values instead of as key-value tuples. `BlockObjectWriter.write` now accepts a key argument and a value argument instead of a single arbitrary object.
* The map output buffer can hold a maximum of Integer.MAX_VALUE bytes, though this wouldn't be terribly difficult to change.
* When spilling occurs, the objects that are still in memory at merge time end up being serialized and deserialized an extra time.
Author: Sandy Ryza <sa...@cloudera.com>
Closes #4450 from sryza/sandy-spark-4550 and squashes the following commits:
8c70dd9 [Sandy Ryza] Fix serialization
9c16fe6 [Sandy Ryza] Fix a couple tests and move getAutoReset to KryoSerializerInstance
6c54e06 [Sandy Ryza] Fix scalastyle
d8462d8 [Sandy Ryza] SPARK-4550
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0a2b15ce
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0a2b15ce
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0a2b15ce
Branch: refs/heads/master
Commit: 0a2b15ce43cf6096e1a7ae060b7c8a4010ce3b92
Parents: a9fc505
Author: Sandy Ryza <sa...@cloudera.com>
Authored: Thu Apr 30 23:14:14 2015 -0700
Committer: Patrick Wendell <pa...@databricks.com>
Committed: Thu Apr 30 23:14:14 2015 -0700
----------------------------------------------------------------------
.../spark/serializer/KryoSerializer.scala | 10 +
.../apache/spark/serializer/Serializer.scala | 31 +++
.../spark/shuffle/hash/HashShuffleWriter.scala | 2 +-
.../spark/storage/BlockObjectWriter.scala | 37 ++-
.../storage/ShuffleBlockFetcherIterator.scala | 6 +-
.../spark/util/collection/ChainedBuffer.scala | 144 +++++++++++
.../util/collection/ExternalAppendOnlyMap.scala | 6 +-
.../spark/util/collection/ExternalSorter.scala | 144 ++++++-----
.../spark/util/collection/PairIterator.scala | 24 ++
.../collection/PartitionedAppendOnlyMap.scala | 44 ++++
.../util/collection/PartitionedPairBuffer.scala | 92 +++++++
.../PartitionedSerializedPairBuffer.scala | 254 +++++++++++++++++++
.../collection/SizeTrackingAppendOnlyMap.scala | 2 +-
.../collection/SizeTrackingPairBuffer.scala | 86 -------
.../collection/SizeTrackingPairCollection.scala | 34 ---
.../WritablePartitionedPairCollection.scala | 113 +++++++++
.../spark/serializer/KryoSerializerSuite.scala | 15 ++
.../spark/serializer/TestSerializer.scala | 4 +-
.../shuffle/hash/HashShuffleManagerSuite.scala | 12 +-
.../spark/storage/BlockObjectWriterSuite.scala | 8 +-
.../util/collection/ChainedBufferSuite.scala | 143 +++++++++++
.../util/collection/ExternalSorterSuite.scala | 189 ++++++++++----
.../PartitionedSerializedPairBufferSuite.scala | 149 +++++++++++
.../sql/execution/SparkSqlSerializer2.scala | 38 ++-
.../apache/spark/tools/StoragePerfTester.scala | 5 +-
25 files changed, 1321 insertions(+), 271 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 754832b..b7bc087 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -200,6 +200,16 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
override def deserializeStream(s: InputStream): DeserializationStream = {
new KryoDeserializationStream(kryo, s)
}
+
+ /**
+ * Returns true if auto-reset is on. The only reason this would be false is if the user-supplied
+ * registrator explicitly turns auto-reset off.
+ */
+ def getAutoReset(): Boolean = {
+ val field = classOf[Kryo].getDeclaredField("autoReset")
+ field.setAccessible(true)
+ field.get(kryo).asInstanceOf[Boolean]
+ }
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index ca6e971..c381672 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -101,7 +101,12 @@ abstract class SerializerInstance {
*/
@DeveloperApi
abstract class SerializationStream {
+ /** The most general-purpose method to write an object. */
def writeObject[T: ClassTag](t: T): SerializationStream
+ /** Writes the object representing the key of a key-value pair. */
+ def writeKey[T: ClassTag](key: T): SerializationStream = writeObject(key)
+ /** Writes the object representing the value of a key-value pair. */
+ def writeValue[T: ClassTag](value: T): SerializationStream = writeObject(value)
def flush(): Unit
def close(): Unit
@@ -120,7 +125,12 @@ abstract class SerializationStream {
*/
@DeveloperApi
abstract class DeserializationStream {
+ /** The most general-purpose method to read an object. */
def readObject[T: ClassTag](): T
+ /** Reads the object representing the key of a key-value pair. */
+ def readKey[T: ClassTag](): T = readObject[T]()
+ /** Reads the object representing the value of a key-value pair. */
+ def readValue[T: ClassTag](): T = readObject[T]()
def close(): Unit
/**
@@ -141,4 +151,25 @@ abstract class DeserializationStream {
DeserializationStream.this.close()
}
}
+
+ /**
+ * Read the elements of this stream through an iterator over key-value pairs. This can only be
+ * called once, as reading each element will consume data from the input source.
+ */
+ def asKeyValueIterator: Iterator[(Any, Any)] = new NextIterator[(Any, Any)] {
+ override protected def getNext() = {
+ try {
+ (readKey[Any](), readValue[Any]())
+ } catch {
+ case eof: EOFException => {
+ finished = true
+ null
+ }
+ }
+ }
+
+ override protected def close() {
+ DeserializationStream.this.close()
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 755f17d..cd27c9e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -63,7 +63,7 @@ private[spark] class HashShuffleWriter[K, V](
for (elem <- iter) {
val bucketId = dep.partitioner.getPartition(elem._1)
- shuffle.writers(bucketId).write(elem)
+ shuffle.writers(bucketId).write(elem._1, elem._2)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 1483379..499dd97 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -33,7 +33,7 @@ import org.apache.spark.util.Utils
* This interface does not support concurrent writes. Also, once the writer has
* been opened, it cannot be reopened again.
*/
-private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
+private[spark] abstract class BlockObjectWriter(val blockId: BlockId) extends OutputStream {
def open(): BlockObjectWriter
@@ -54,9 +54,14 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
def revertPartialWritesAndClose()
/**
- * Writes an object.
+ * Writes a key-value pair.
*/
- def write(value: Any)
+ def write(key: Any, value: Any)
+
+ /**
+ * Notify the writer that a record's worth of bytes has been written with OutputStream#write.
+ */
+ def recordWritten()
/**
* Returns the file segment of committed data that this Writer has written.
@@ -203,12 +208,32 @@ private[spark] class DiskBlockObjectWriter(
}
}
- override def write(value: Any) {
+ override def write(key: Any, value: Any) {
+ if (!initialized) {
+ open()
+ }
+
+ objOut.writeKey(key)
+ objOut.writeValue(value)
+ numRecordsWritten += 1
+ writeMetrics.incShuffleRecordsWritten(1)
+
+ if (numRecordsWritten % 32 == 0) {
+ updateBytesWritten()
+ }
+ }
+
+ override def write(b: Int): Unit = throw new UnsupportedOperationException()
+
+ override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = {
if (!initialized) {
open()
}
- objOut.writeObject(value)
+ bs.write(kvBytes, offs, len)
+ }
+
+ override def recordWritten(): Unit = {
numRecordsWritten += 1
writeMetrics.incShuffleRecordsWritten(1)
@@ -238,7 +263,7 @@ private[spark] class DiskBlockObjectWriter(
}
// For testing
- private[spark] def flush() {
+ private[spark] override def flush() {
objOut.flush()
bs.flush()
}
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index f337952..d0faab6 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -17,14 +17,12 @@
package org.apache.spark.storage
-import java.io.{InputStream, IOException}
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
-import scala.util.{Failure, Success, Try}
+import scala.util.{Failure, Try}
import org.apache.spark.{Logging, TaskContext}
-import org.apache.spark.network.BlockTransferService
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.serializer.{SerializerInstance, Serializer}
@@ -301,7 +299,7 @@ final class ShuffleBlockFetcherIterator(
// the scheduler gets a FetchFailedException.
Try(buf.createInputStream()).map { is0 =>
val is = blockManager.wrapForCompression(blockId, is0)
- val iter = serializerInstance.deserializeStream(is).asIterator
+ val iter = serializerInstance.deserializeStream(is).asKeyValueIterator
CompletionIterator[Any, Iterator[Any]](iter, {
// Once the iterator is exhausted, release the buffer and set currentResult to null
// so we don't release it again in cleanup.
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
new file mode 100644
index 0000000..a60bffe
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.io.OutputStream
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * A logical byte buffer that wraps a list of byte arrays. All the byte arrays have equal size. The
+ * advantage of this over a standard ArrayBuffer is that it can grow without claiming large amounts
+ * of memory and needing to copy the full contents. The disadvantage is that the contents don't
+ * occupy a contiguous segment of memory.
+ */
+private[spark] class ChainedBuffer(chunkSize: Int) {
+ private val chunkSizeLog2 = (math.log(chunkSize) / math.log(2)).toInt
+ assert(math.pow(2, chunkSizeLog2).toInt == chunkSize,
+ s"ChainedBuffer chunk size $chunkSize must be a power of two")
+ private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]()
+ private var _size: Int = _
+
+ /**
+ * Feed bytes from this buffer into a BlockObjectWriter.
+ *
+ * @param pos Offset in the buffer to read from.
+ * @param os OutputStream to read into.
+ * @param len Number of bytes to read.
+ */
+ def read(pos: Int, os: OutputStream, len: Int): Unit = {
+ if (pos + len > _size) {
+ throw new IndexOutOfBoundsException(
+ s"Read of $len bytes at position $pos would go past size ${_size} of buffer")
+ }
+ var chunkIndex = pos >> chunkSizeLog2
+ var posInChunk = pos - (chunkIndex << chunkSizeLog2)
+ var written = 0
+ while (written < len) {
+ val toRead = math.min(len - written, chunkSize - posInChunk)
+ os.write(chunks(chunkIndex), posInChunk, toRead)
+ written += toRead
+ chunkIndex += 1
+ posInChunk = 0
+ }
+ }
+
+ /**
+ * Read bytes from this buffer into a byte array.
+ *
+ * @param pos Offset in the buffer to read from.
+ * @param bytes Byte array to read into.
+ * @param offs Offset in the byte array to read to.
+ * @param len Number of bytes to read.
+ */
+ def read(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = {
+ if (pos + len > _size) {
+ throw new IndexOutOfBoundsException(
+ s"Read of $len bytes at position $pos would go past size of buffer")
+ }
+ var chunkIndex = pos >> chunkSizeLog2
+ var posInChunk = pos - (chunkIndex << chunkSizeLog2)
+ var written = 0
+ while (written < len) {
+ val toRead = math.min(len - written, chunkSize - posInChunk)
+ System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead)
+ written += toRead
+ chunkIndex += 1
+ posInChunk = 0
+ }
+ }
+
+ /**
+ * Write bytes from a byte array into this buffer.
+ *
+ * @param pos Offset in the buffer to write to.
+ * @param bytes Byte array to write from.
+ * @param offs Offset in the byte array to write from.
+ * @param len Number of bytes to write.
+ */
+ def write(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = {
+ if (pos > _size) {
+ throw new IndexOutOfBoundsException(
+ s"Write at position $pos starts after end of buffer ${_size}")
+ }
+ // Grow if needed
+ val endChunkIndex = (pos + len - 1) >> chunkSizeLog2
+ while (endChunkIndex >= chunks.length) {
+ chunks += new Array[Byte](chunkSize)
+ }
+
+ var chunkIndex = pos >> chunkSizeLog2
+ var posInChunk = pos - (chunkIndex << chunkSizeLog2)
+ var written = 0
+ while (written < len) {
+ val toWrite = math.min(len - written, chunkSize - posInChunk)
+ System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite)
+ written += toWrite
+ chunkIndex += 1
+ posInChunk = 0
+ }
+
+ _size = math.max(_size, pos + len)
+ }
+
+ /**
+ * Total size of buffer that can be written to without allocating additional memory.
+ */
+ def capacity: Int = chunks.size * chunkSize
+
+ /**
+ * Size of the logical buffer.
+ */
+ def size: Int = _size
+}
+
+/**
+ * Output stream that writes to a ChainedBuffer.
+ */
+private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream {
+ private var pos = 0
+
+ override def write(b: Int): Unit = {
+ throw new UnsupportedOperationException()
+ }
+
+ override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = {
+ chainedBuffer.write(pos, bytes, offs, len)
+ pos += len
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index f912049..b850973 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -174,7 +174,7 @@ class ExternalAppendOnlyMap[K, V, C](
val it = currentMap.destructiveSortedIterator(keyComparator)
while (it.hasNext) {
val kv = it.next()
- writer.write(kv)
+ writer.write(kv._1, kv._2)
objectsWritten += 1
if (objectsWritten == serializerBatchSize) {
@@ -435,7 +435,9 @@ class ExternalAppendOnlyMap[K, V, C](
*/
private def readNextItem(): (K, C) = {
try {
- val item = deserializeStream.readObject().asInstanceOf[(K, C)]
+ val k = deserializeStream.readKey().asInstanceOf[K]
+ val c = deserializeStream.readValue().asInstanceOf[C]
+ val item = (k, c)
objectsRead += 1
if (objectsRead == serializerBatchSize) {
objectsRead = 0
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 4ed8a74..b7306cd 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -26,7 +26,7 @@ import scala.collection.mutable
import com.google.common.io.ByteStreams
import org.apache.spark._
-import org.apache.spark.serializer.{DeserializationStream, Serializer}
+import org.apache.spark.serializer._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.storage.{BlockObjectWriter, BlockId}
@@ -66,10 +66,11 @@ import org.apache.spark.storage.{BlockObjectWriter, BlockId}
*
* At a high level, this class works internally as follows:
*
- * - We repeatedly fill up buffers of in-memory data, using either a SizeTrackingAppendOnlyMap if
- * we want to combine by key, or an simple SizeTrackingBuffer if we don't. Inside these buffers,
- * we sort elements of type ((Int, K), C) where the Int is the partition ID. This is done to
- * avoid calling the partitioner multiple times on the same key (e.g. for RangePartitioner).
+ * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if
+ * we want to combine by key, or a PartitionedSerializedPairBuffer or PartitionedPairBuffer if we
+ * don't. Inside these buffers, we sort elements by partition ID and then possibly also by key.
+ * To avoid calling the partitioner multiple times with each key, we store the partition ID
+ * alongside each record.
*
* - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first
* by partition ID and possibly second by key or by hash code of the key, if we want to do
@@ -96,7 +97,7 @@ private[spark] class ExternalSorter[K, V, C](
partitioner: Option[Partitioner] = None,
ordering: Option[Ordering[K]] = None,
serializer: Option[Serializer] = None)
- extends Logging with Spillable[SizeTrackingPairCollection[(Int, K), C]] {
+ extends Logging with Spillable[WritablePartitionedPairCollection[K, C]] {
private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1)
private val shouldPartition = numPartitions > 1
@@ -126,11 +127,22 @@ private[spark] class ExternalSorter[K, V, C](
if (shouldPartition) partitioner.get.getPartition(key) else 0
}
+ private val metaInitialRecords = 256
+ private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
+ private val useSerializedPairBuffer =
+ !ordering.isDefined && conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
+ ser.isInstanceOf[KryoSerializer] &&
+ serInstance.asInstanceOf[KryoSerializerInstance].getAutoReset
+
// Data structures to store in-memory objects before we spill. Depending on whether we have an
// Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
// store them in an array buffer.
- private var map = new SizeTrackingAppendOnlyMap[(Int, K), C]
- private var buffer = new SizeTrackingPairBuffer[(Int, K), C]
+ private var map = new PartitionedAppendOnlyMap[K, C]
+ private var buffer = if (useSerializedPairBuffer) {
+ new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance)
+ } else {
+ new PartitionedPairBuffer[K, C]
+ }
// Total spilling statistics
private var _diskBytesSpilled = 0L
@@ -163,33 +175,6 @@ private[spark] class ExternalSorter[K, V, C](
}
})
- // A comparator for (Int, K) pairs that orders them by only their partition ID
- private val partitionComparator: Comparator[(Int, K)] = new Comparator[(Int, K)] {
- override def compare(a: (Int, K), b: (Int, K)): Int = {
- a._1 - b._1
- }
- }
-
- // A comparator that orders (Int, K) pairs by partition ID and then possibly by key
- private val partitionKeyComparator: Comparator[(Int, K)] = {
- if (ordering.isDefined || aggregator.isDefined) {
- // Sort by partition ID then key comparator
- new Comparator[(Int, K)] {
- override def compare(a: (Int, K), b: (Int, K)): Int = {
- val partitionDiff = a._1 - b._1
- if (partitionDiff != 0) {
- partitionDiff
- } else {
- keyComparator.compare(a._2, b._2)
- }
- }
- }
- } else {
- // Just sort it by partition ID
- partitionComparator
- }
- }
-
// Information about a spilled file. Includes sizes in bytes of "batches" written by the
// serializer as we periodically reset its stream, as well as number of elements in each
// partition, used to efficiently keep track of partitions when merging.
@@ -221,16 +206,18 @@ private[spark] class ExternalSorter[K, V, C](
} else if (bypassMergeSort) {
// SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies
if (records.hasNext) {
- spillToPartitionFiles(records.map { kv =>
- ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
- })
+ spillToPartitionFiles(
+ WritablePartitionedIterator.fromIterator(records.map { kv =>
+ ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
+ })
+ )
}
} else {
// Stick values into our buffer
while (records.hasNext) {
addElementsRead()
val kv = records.next()
- buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
+ buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
maybeSpillCollection(usingMap = false)
}
}
@@ -248,11 +235,15 @@ private[spark] class ExternalSorter[K, V, C](
if (usingMap) {
if (maybeSpill(map, map.estimateSize())) {
- map = new SizeTrackingAppendOnlyMap[(Int, K), C]
+ map = new PartitionedAppendOnlyMap[K, C]
}
} else {
if (maybeSpill(buffer, buffer.estimateSize())) {
- buffer = new SizeTrackingPairBuffer[(Int, K), C]
+ buffer = if (useSerializedPairBuffer) {
+ new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance)
+ } else {
+ new PartitionedPairBuffer[K, C]
+ }
}
}
}
@@ -260,7 +251,7 @@ private[spark] class ExternalSorter[K, V, C](
/**
* Spill the current in-memory collection to disk, adding a new file to spills, and clear it.
*/
- override protected[this] def spill(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
+ override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
if (bypassMergeSort) {
spillToPartitionFiles(collection)
} else {
@@ -277,7 +268,7 @@ private[spark] class ExternalSorter[K, V, C](
*
* @param collection whichever collection we're using (map or buffer)
*/
- private def spillToMergeableFile(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
+ private def spillToMergeableFile(collection: WritablePartitionedPairCollection[K, C]): Unit = {
assert(!bypassMergeSort)
// Because these files may be read during shuffle, their compression must be controlled by
@@ -308,14 +299,10 @@ private[spark] class ExternalSorter[K, V, C](
var success = false
try {
- val it = collection.destructiveSortedIterator(partitionKeyComparator)
+ val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
while (it.hasNext) {
- val elem = it.next()
- val partitionId = elem._1._1
- val key = elem._1._2
- val value = elem._2
- writer.write(key)
- writer.write(value)
+ val partitionId = it.nextPartition()
+ it.writeNext(writer)
elementsPerPartition(partitionId) += 1
objectsWritten += 1
@@ -357,11 +344,11 @@ private[spark] class ExternalSorter[K, V, C](
*
* @param collection whichever collection we're using (map or buffer)
*/
- private def spillToPartitionFiles(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
- spillToPartitionFiles(collection.iterator)
+ private def spillToPartitionFiles(collection: WritablePartitionedPairCollection[K, C]): Unit = {
+ spillToPartitionFiles(collection.writablePartitionedIterator())
}
- private def spillToPartitionFiles(iterator: Iterator[((Int, K), C)]): Unit = {
+ private def spillToPartitionFiles(iterator: WritablePartitionedIterator): Unit = {
assert(bypassMergeSort)
// Create our file writers if we haven't done so yet
@@ -385,11 +372,8 @@ private[spark] class ExternalSorter[K, V, C](
// No need to sort stuff, just write each element out
while (iterator.hasNext) {
- val elem = iterator.next()
- val partitionId = elem._1._1
- val key = elem._1._2
- val value = elem._2
- partitionWriters(partitionId).write((key, value))
+ val partitionId = iterator.nextPartition()
+ iterator.writeNext(partitionWriters(partitionId))
}
}
@@ -618,8 +602,8 @@ private[spark] class ExternalSorter[K, V, C](
if (finished || deserializeStream == null) {
return null
}
- val k = deserializeStream.readObject().asInstanceOf[K]
- val c = deserializeStream.readObject().asInstanceOf[C]
+ val k = deserializeStream.readKey().asInstanceOf[K]
+ val c = deserializeStream.readValue().asInstanceOf[C]
lastPartitionId = partitionId
// Start reading the next batch if we're done with this one
indexInBatch += 1
@@ -695,27 +679,27 @@ private[spark] class ExternalSorter[K, V, C](
*/
def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
val usingMap = aggregator.isDefined
- val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
+ val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
if (spills.isEmpty && partitionWriters == null) {
// Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
// we don't even need to sort by anything other than partition ID
if (!ordering.isDefined) {
// The user hasn't requested sorted keys, so only sort by partition ID, not key
- groupByPartition(collection.destructiveSortedIterator(partitionComparator))
+ groupByPartition(collection.partitionedDestructiveSortedIterator(None))
} else {
// We do need to sort by both partition ID and key
- groupByPartition(collection.destructiveSortedIterator(partitionKeyComparator))
+ groupByPartition(collection.partitionedDestructiveSortedIterator(Some(keyComparator)))
}
} else if (bypassMergeSort) {
// Read data from each partition file and merge it together with the data in memory;
// note that there's no ordering or aggregator in this case -- we just partition objects
- val collIter = groupByPartition(collection.destructiveSortedIterator(partitionComparator))
+ val collIter = groupByPartition(collection.partitionedDestructiveSortedIterator(None))
collIter.map { case (partitionId, values) =>
(partitionId, values ++ readPartitionFile(partitionWriters(partitionId)))
}
} else {
// Merge spilled and in-memory data
- merge(spills, collection.destructiveSortedIterator(partitionKeyComparator))
+ merge(spills, collection.partitionedDestructiveSortedIterator(comparator))
}
}
@@ -762,15 +746,29 @@ private[spark] class ExternalSorter[K, V, C](
context.taskMetrics.shuffleWriteMetrics.foreach(
_.incShuffleWriteTime(System.nanoTime - writeStartTime))
}
+ } else if (spills.isEmpty && partitionWriters == null) {
+ // Case where we only have in-memory data
+ val collection = if (aggregator.isDefined) map else buffer
+ val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
+ while (it.hasNext) {
+ val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
+ context.taskMetrics.shuffleWriteMetrics.get)
+ val partitionId = it.nextPartition()
+ while (it.hasNext && it.nextPartition() == partitionId) {
+ it.writeNext(writer)
+ }
+ writer.commitAndClose()
+ val segment = writer.fileSegment()
+ lengths(partitionId) = segment.length
+ }
} else {
- // Either we're not bypassing merge-sort or we have only in-memory data; get an iterator by
- // partition and just write everything directly.
+ // Not bypassing merge-sort; get an iterator by partition and just write everything directly.
for ((id, elements) <- this.partitionedIterator) {
if (elements.hasNext) {
val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
context.taskMetrics.shuffleWriteMetrics.get)
for (elem <- elements) {
- writer.write(elem)
+ writer.write(elem._1, elem._2)
}
writer.commitAndClose()
val segment = writer.fileSegment()
@@ -799,7 +797,7 @@ private[spark] class ExternalSorter[K, V, C](
if (writer.isOpen) {
writer.commitAndClose()
}
- blockManager.diskStore.getValues(writer.blockId, ser).get.asInstanceOf[Iterator[Product2[K, C]]]
+ new PairIterator[K, C](blockManager.diskStore.getValues(writer.blockId, ser).get)
}
def stop(): Unit = {
@@ -829,6 +827,14 @@ private[spark] class ExternalSorter[K, V, C](
(0 until numPartitions).iterator.map(p => (p, new IteratorForPartition(p, buffered)))
}
+ private def comparator: Option[Comparator[K]] = {
+ if (ordering.isDefined || aggregator.isDefined) {
+ Some(keyComparator)
+ } else {
+ None
+ }
+ }
+
/**
* An iterator that reads only the elements for a given partition ID from an underlying buffered
* stream, assuming this partition is the next one to be read. Used to make it easier to return
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala b/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
new file mode 100644
index 0000000..d75959f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+private[spark] class PairIterator[K, V](iter: Iterator[Any]) extends Iterator[(K, V)] {
+ def hasNext: Boolean = iter.hasNext
+
+ def next(): (K, V) = (iter.next().asInstanceOf[K], iter.next().asInstanceOf[V])
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
new file mode 100644
index 0000000..e2e2f1f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.util.Comparator
+
+import org.apache.spark.util.collection.WritablePartitionedPairCollection._
+
+/**
+ * Implementation of WritablePartitionedPairCollection that wraps a map in which the keys are tuples
+ * of (partition ID, K)
+ */
+private[spark] class PartitionedAppendOnlyMap[K, V]
+ extends SizeTrackingAppendOnlyMap[(Int, K), V] with WritablePartitionedPairCollection[K, V] {
+
+ def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
+ : Iterator[((Int, K), V)] = {
+ val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator)
+ destructiveSortedIterator(comparator)
+ }
+
+ def writablePartitionedIterator(): WritablePartitionedIterator = {
+ WritablePartitionedIterator.fromIterator(super.iterator)
+ }
+
+ def insert(partition: Int, key: K, value: V): Unit = {
+ update((partition, key), value)
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
new file mode 100644
index 0000000..e8332e1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.util.Comparator
+
+import org.apache.spark.storage.BlockObjectWriter
+import org.apache.spark.util.collection.WritablePartitionedPairCollection._
+
+/**
+ * Append-only buffer of key-value pairs, each with a corresponding partition ID, that keeps track
+ * of its estimated size in bytes.
+ */
+private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64)
+ extends WritablePartitionedPairCollection[K, V] with SizeTracker
+{
+ require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
+ require(initialCapacity >= 1, "Invalid initial capacity")
+
+ // Basic growable array data structure. We use a single array of AnyRef to hold both the keys
+ // and the values, so that we can sort them efficiently with KVArraySortDataFormat.
+ private var capacity = initialCapacity
+ private var curSize = 0
+ private var data = new Array[AnyRef](2 * initialCapacity)
+
+ /** Add an element into the buffer */
+ def insert(partition: Int, key: K, value: V): Unit = {
+ if (curSize == capacity) {
+ growArray()
+ }
+ data(2 * curSize) = (partition, key.asInstanceOf[AnyRef])
+ data(2 * curSize + 1) = value.asInstanceOf[AnyRef]
+ curSize += 1
+ afterUpdate()
+ }
+
+ /** Double the size of the array because we've reached capacity */
+ private def growArray(): Unit = {
+ if (capacity == (1 << 29)) {
+ // Doubling the capacity would create an array bigger than Int.MaxValue, so don't
+ throw new Exception("Can't grow buffer beyond 2^29 elements")
+ }
+ val newCapacity = capacity * 2
+ val newArray = new Array[AnyRef](2 * newCapacity)
+ System.arraycopy(data, 0, newArray, 0, 2 * capacity)
+ data = newArray
+ capacity = newCapacity
+ resetSamples()
+ }
+
+ /** Iterate through the data in a given order. For this class this is not really destructive. */
+ override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
+ : Iterator[((Int, K), V)] = {
+ val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator)
+ new Sorter(new KVArraySortDataFormat[(Int, K), AnyRef]).sort(data, 0, curSize, comparator)
+ iterator
+ }
+
+ override def writablePartitionedIterator(): WritablePartitionedIterator = {
+ WritablePartitionedIterator.fromIterator(iterator)
+ }
+
+ private def iterator(): Iterator[((Int, K), V)] = new Iterator[((Int, K), V)] {
+ var pos = 0
+
+ override def hasNext: Boolean = pos < curSize
+
+ override def next(): ((Int, K), V) = {
+ if (!hasNext) {
+ throw new NoSuchElementException
+ }
+ val pair = (data(2 * pos).asInstanceOf[(Int, K)], data(2 * pos + 1).asInstanceOf[V])
+ pos += 1
+ pair
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
new file mode 100644
index 0000000..b5ca0c6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.io.InputStream
+import java.nio.IntBuffer
+import java.util.Comparator
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance}
+import org.apache.spark.storage.BlockObjectWriter
+import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._
+
+/**
+ * Append-only buffer of key-value pairs, each with a corresponding partition ID, that serializes
+ * its records upon insert and stores them as raw bytes.
+ *
+ * We use two data-structures to store the contents. The serialized records are stored in a
+ * ChainedBuffer that can expand gracefully as records are added. This buffer is accompanied by a
+ * metadata buffer that stores pointers into the data buffer as well as the partition ID of each
+ * record. Each entry in the metadata buffer takes up a fixed amount of space.
+ *
+ * Sorting the collection means swapping entries in the metadata buffer - the record buffer need not
+ * be modified at all. Storing the partition IDs in the metadata buffer means that comparisons can
+ * happen without following any pointers, which should minimize cache misses.
+ *
+ * Currently, only sorting by partition is supported.
+ *
+ * @param metaInitialRecords The initial number of entries in the metadata buffer.
+ * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records.
+ * @param serializerInstance the serializer used for serializing inserted records.
+ */
+private[spark] class PartitionedSerializedPairBuffer[K, V](
+ metaInitialRecords: Int,
+ kvBlockSize: Int,
+ serializerInstance: SerializerInstance)
+ extends WritablePartitionedPairCollection[K, V] with SizeTracker {
+
+ if (serializerInstance.isInstanceOf[JavaSerializerInstance]) {
+ throw new IllegalArgumentException("PartitionedSerializedPairBuffer does not support" +
+ " Java-serialized objects.")
+ }
+
+ private var metaBuffer = IntBuffer.allocate(metaInitialRecords * RECORD_SIZE)
+
+ private val kvBuffer: ChainedBuffer = new ChainedBuffer(kvBlockSize)
+ private val kvOutputStream = new ChainedBufferOutputStream(kvBuffer)
+ private val kvSerializationStream = serializerInstance.serializeStream(kvOutputStream)
+
+ def insert(partition: Int, key: K, value: V): Unit = {
+ if (metaBuffer.position == metaBuffer.capacity) {
+ growMetaBuffer()
+ }
+ // NOTE(review): a negative size presumably signals Int wrap-around of the byte count.
+ val keyStart = kvBuffer.size
+ if (keyStart < 0) {
+ throw new Exception(s"Can't grow buffer beyond ${Int.MaxValue} bytes")
+ }
+ kvSerializationStream.writeObject[Any](key)
+ kvSerializationStream.flush()
+ val valueStart = kvBuffer.size
+ kvSerializationStream.writeObject[Any](value)
+ kvSerializationStream.flush()
+ val valueEnd = kvBuffer.size
+ // Put order must match the KEY_START, VAL_START, VAL_END, PARTITION slot offsets.
+ metaBuffer.put(keyStart)
+ metaBuffer.put(valueStart)
+ metaBuffer.put(valueEnd)
+ metaBuffer.put(partition)
+ }
+
+ /** Double the size of the metadata buffer because we've reached capacity */
+ private def growMetaBuffer(): Unit = {
+ if (metaBuffer.capacity.toLong * 2 > Int.MaxValue) {
+ // Doubling the capacity would create an array bigger than Int.MaxValue, so don't
+ throw new Exception(s"Can't grow metadata buffer beyond ${Int.MaxValue} ints")
+ }
+ val newMetaBuffer = IntBuffer.allocate(metaBuffer.capacity * 2)
+ newMetaBuffer.put(metaBuffer.array)
+ metaBuffer = newMetaBuffer
+ }
+
+ /** Iterate through the data in a given order. For this class this is not really destructive. */
+ override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
+ : Iterator[((Int, K), V)] = {
+ sort(keyComparator)
+ val is = orderedInputStream
+ val deserStream = serializerInstance.deserializeStream(is)
+ new Iterator[((Int, K), V)] {
+ var metaBufferPos = 0
+ def hasNext: Boolean = metaBufferPos < metaBuffer.position
+ def next(): ((Int, K), V) = {
+ val key = deserStream.readKey[Any]().asInstanceOf[K]
+ val value = deserStream.readValue[Any]().asInstanceOf[V]
+ val partition = metaBuffer.get(metaBufferPos + PARTITION)
+ metaBufferPos += RECORD_SIZE
+ ((partition, key), value)
+ }
+ }
+ }
+
+ override def estimateSize: Long = metaBuffer.capacity * 4L + kvBuffer.capacity // 4L: no Int overflow
+
+ override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
+ : WritablePartitionedIterator = {
+ sort(keyComparator)
+ writablePartitionedIterator
+ }
+
+ override def writablePartitionedIterator(): WritablePartitionedIterator = {
+ new WritablePartitionedIterator {
+ // current position in the meta buffer in ints
+ var pos = 0
+
+ def writeNext(writer: BlockObjectWriter): Unit = {
+ val keyStart = metaBuffer.get(pos + KEY_START)
+ val valueEnd = metaBuffer.get(pos + VAL_END)
+ pos += RECORD_SIZE
+ kvBuffer.read(keyStart, writer, valueEnd - keyStart)
+ writer.recordWritten()
+ }
+ def nextPartition(): Int = metaBuffer.get(pos + PARTITION)
+ def hasNext(): Boolean = pos < metaBuffer.position
+ }
+ }
+
+ // Visible for testing
+ def orderedInputStream: OrderedInputStream = {
+ new OrderedInputStream(metaBuffer, kvBuffer)
+ }
+
+ private def sort(keyComparator: Option[Comparator[K]]): Unit = {
+ val comparator = if (keyComparator.isEmpty) {
+ new Comparator[Int]() {
+ def compare(partition1: Int, partition2: Int): Int = {
+ partition1 - partition2
+ }
+ }
+ } else {
+ throw new UnsupportedOperationException()
+ }
+
+ val sorter = new Sorter(new SerializedSortDataFormat)
+ sorter.sort(metaBuffer, 0, metaBuffer.position / RECORD_SIZE, comparator)
+ }
+}
+
+private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer)
+ extends InputStream {
+
+ private var metaBufferPos = 0
+ private var kvBufferPos =
+ if (metaBuffer.position > 0) metaBuffer.get(metaBufferPos + KEY_START) else 0
+
+ override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length)
+
+ override def read(bytes: Array[Byte], offs: Int, len: Int): Int = {
+ if (metaBufferPos >= metaBuffer.position) {
+ return -1
+ }
+ val bytesRemainingInRecord = metaBuffer.get(metaBufferPos + VAL_END) - kvBufferPos
+ val toRead = math.min(bytesRemainingInRecord, len)
+ kvBuffer.read(kvBufferPos, bytes, offs, toRead)
+ if (toRead == bytesRemainingInRecord) {
+ metaBufferPos += RECORD_SIZE
+ if (metaBufferPos < metaBuffer.position) {
+ kvBufferPos = metaBuffer.get(metaBufferPos + KEY_START)
+ }
+ } else {
+ kvBufferPos += toRead
+ }
+ toRead
+ }
+
+ override def read(): Int = {
+ throw new UnsupportedOperationException()
+ }
+}
+
+private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuffer] {
+
+ private val META_BUFFER_TMP = new Array[Int](RECORD_SIZE)
+
+ /** Return the sort key for the element at the given index. */
+ override protected def getKey(metaBuffer: IntBuffer, pos: Int): Int = {
+ metaBuffer.get(pos * RECORD_SIZE + PARTITION)
+ }
+
+ /** Swap two elements. */
+ override def swap(metaBuffer: IntBuffer, pos0: Int, pos1: Int): Unit = {
+ val iOff = pos0 * RECORD_SIZE
+ val jOff = pos1 * RECORD_SIZE
+ System.arraycopy(metaBuffer.array, iOff, META_BUFFER_TMP, 0, RECORD_SIZE)
+ System.arraycopy(metaBuffer.array, jOff, metaBuffer.array, iOff, RECORD_SIZE)
+ System.arraycopy(META_BUFFER_TMP, 0, metaBuffer.array, jOff, RECORD_SIZE)
+ }
+
+ /** Copy a single element from src(srcPos) to dst(dstPos). */
+ override def copyElement(
+ src: IntBuffer,
+ srcPos: Int,
+ dst: IntBuffer,
+ dstPos: Int): Unit = {
+ val srcOff = srcPos * RECORD_SIZE
+ val dstOff = dstPos * RECORD_SIZE
+ System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE)
+ }
+
+ /**
+ * Copy a range of elements starting at src(srcPos) to dst, starting at dstPos.
+ * Overlapping ranges are allowed.
+ */
+ override def copyRange(
+ src: IntBuffer,
+ srcPos: Int,
+ dst: IntBuffer,
+ dstPos: Int,
+ length: Int): Unit = {
+ val srcOff = srcPos * RECORD_SIZE
+ val dstOff = dstPos * RECORD_SIZE
+ System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE * length)
+ }
+
+ /**
+ * Allocates a Buffer that can hold up to 'length' elements.
+ * All elements of the buffer should be considered invalid until data is explicitly copied in.
+ */
+ override def allocate(length: Int): IntBuffer = {
+ IntBuffer.allocate(length * RECORD_SIZE)
+ }
+}
+
+private[spark] object PartitionedSerializedPairBuffer {
+ val KEY_START = 0
+ val VAL_START = 1
+ val VAL_END = 2
+ val PARTITION = 3
+ val RECORD_SIZE = Seq(KEY_START, VAL_START, VAL_END, PARTITION).size // num ints of metadata
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
index eb4de41..722f78b 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
@@ -21,7 +21,7 @@ package org.apache.spark.util.collection
* An append-only map that keeps track of its estimated size in bytes.
*/
private[spark] class SizeTrackingAppendOnlyMap[K, V]
- extends AppendOnlyMap[K, V] with SizeTracker with SizeTrackingPairCollection[K, V]
+ extends AppendOnlyMap[K, V] with SizeTracker
{
override def update(key: K, value: V): Unit = {
super.update(key, value)
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala
deleted file mode 100644
index 9e9c16c..0000000
--- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala
+++ /dev/null
@@ -1,86 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.util.Comparator
-
-/**
- * Append-only buffer of key-value pairs that keeps track of its estimated size in bytes.
- */
-private[spark] class SizeTrackingPairBuffer[K, V](initialCapacity: Int = 64)
- extends SizeTracker with SizeTrackingPairCollection[K, V]
-{
- require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
- require(initialCapacity >= 1, "Invalid initial capacity")
-
- // Basic growable array data structure. We use a single array of AnyRef to hold both the keys
- // and the values, so that we can sort them efficiently with KVArraySortDataFormat.
- private var capacity = initialCapacity
- private var curSize = 0
- private var data = new Array[AnyRef](2 * initialCapacity)
-
- /** Add an element into the buffer */
- def insert(key: K, value: V): Unit = {
- if (curSize == capacity) {
- growArray()
- }
- data(2 * curSize) = key.asInstanceOf[AnyRef]
- data(2 * curSize + 1) = value.asInstanceOf[AnyRef]
- curSize += 1
- afterUpdate()
- }
-
- /** Total number of elements in buffer */
- override def size: Int = curSize
-
- /** Iterate over the elements of the buffer */
- override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
- var pos = 0
-
- override def hasNext: Boolean = pos < curSize
-
- override def next(): (K, V) = {
- if (!hasNext) {
- throw new NoSuchElementException
- }
- val pair = (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
- pos += 1
- pair
- }
- }
-
- /** Double the size of the array because we've reached capacity */
- private def growArray(): Unit = {
- if (capacity == (1 << 29)) {
- // Doubling the capacity would create an array bigger than Int.MaxValue, so don't
- throw new Exception("Can't grow buffer beyond 2^29 elements")
- }
- val newCapacity = capacity * 2
- val newArray = new Array[AnyRef](2 * newCapacity)
- System.arraycopy(data, 0, newArray, 0, 2 * capacity)
- data = newArray
- capacity = newCapacity
- resetSamples()
- }
-
- /** Iterate through the data in a given order. For this class this is not really destructive. */
- override def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] = {
- new Sorter(new KVArraySortDataFormat[K, AnyRef]).sort(data, 0, curSize, keyComparator)
- iterator
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala
deleted file mode 100644
index faa4e2b..0000000
--- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.util.Comparator
-
-/**
- * A common interface for our size-tracking collections of key-value pairs, which are used in
- * external operations. These all support estimating the size and obtaining a memory-efficient
- * sorted iterator.
- */
-// TODO: should extend Iterable[Product2[K, V]] instead of (K, V)
-private[spark] trait SizeTrackingPairCollection[K, V] extends Iterable[(K, V)] {
- /** Estimate the collection's current memory usage in bytes. */
- def estimateSize(): Long
-
- /** Iterate through the data in a given key order. This may destroy the underlying collection. */
- def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)]
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
new file mode 100644
index 0000000..f26d161
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.util.Comparator
+
+import org.apache.spark.storage.BlockObjectWriter
+
+/**
+ * A common interface for size-tracking collections of key-value pairs that
+ * - Have an associated partition for each key-value pair.
+ * - Support a memory-efficient sorted iterator
+ * - Support a WritablePartitionedIterator for writing the contents directly as bytes.
+ */
+private[spark] trait WritablePartitionedPairCollection[K, V] {
+ /**
+ * Insert a key-value pair with a partition into the collection
+ */
+ def insert(partition: Int, key: K, value: V): Unit
+
+ /**
+ * Iterate through the data in order of partition ID and then the given comparator. This may
+ * destroy the underlying collection.
+ */
+ def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
+ : Iterator[((Int, K), V)]
+
+ /**
+ * Iterate through the data and write out the elements instead of returning them. Records are
+ * returned in order of their partition ID and then the given comparator.
+ * This may destroy the underlying collection.
+ */
+ def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
+ : WritablePartitionedIterator = {
+ WritablePartitionedIterator.fromIterator(partitionedDestructiveSortedIterator(keyComparator))
+ }
+
+ /**
+ * Iterate through the data and write out the elements instead of returning them.
+ */
+ def writablePartitionedIterator(): WritablePartitionedIterator
+}
+
+private[spark] object WritablePartitionedPairCollection {
+ /**
+ * A comparator for (Int, K) pairs that orders them by only their partition ID.
+ */
+ def partitionComparator[K]: Comparator[(Int, K)] = new Comparator[(Int, K)] {
+ override def compare(a: (Int, K), b: (Int, K)): Int = {
+ a._1 - b._1
+ }
+ }
+
+ /**
+ * A comparator for (Int, K) pairs that orders them both by their partition ID and a key ordering.
+ */
+ def partitionKeyComparator[K](keyComparator: Comparator[K]): Comparator[(Int, K)] = {
+ new Comparator[(Int, K)] {
+ override def compare(a: (Int, K), b: (Int, K)): Int = {
+ val partitionDiff = a._1 - b._1
+ if (partitionDiff != 0) {
+ partitionDiff
+ } else {
+ keyComparator.compare(a._2, b._2)
+ }
+ }
+ }
+ }
+}
+
+/**
+ * Iterator that writes elements to a BlockObjectWriter instead of returning them. Each element
+ * has an associated partition.
+ */
+private[spark] trait WritablePartitionedIterator {
+ def writeNext(writer: BlockObjectWriter): Unit
+
+ def hasNext(): Boolean
+
+ def nextPartition(): Int
+}
+
+private[spark] object WritablePartitionedIterator {
+ def fromIterator(it: Iterator[((Int, _), _)]): WritablePartitionedIterator = {
+ new WritablePartitionedIterator {
+ var cur = if (it.hasNext) it.next() else null
+
+ def writeNext(writer: BlockObjectWriter): Unit = {
+ writer.write(cur._1._2, cur._2)
+ cur = if (it.hasNext) it.next() else null
+ }
+
+ def hasNext(): Boolean = cur != null
+
+ def nextPartition(): Int = cur._1._1
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index 1b13559..778a7ee 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -280,6 +280,15 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
val thrown = intercept[SparkException](ser.serialize(largeObject))
assert(thrown.getMessage.contains(kryoBufferMaxProperty))
}
+
+ test("getAutoReset") {
+ val ser = new KryoSerializer(new SparkConf).newInstance().asInstanceOf[KryoSerializerInstance]
+ assert(ser.getAutoReset)
+ val conf = new SparkConf().set("spark.kryo.registrator",
+ classOf[RegistratorWithoutAutoReset].getName)
+ val ser2 = new KryoSerializer(conf).newInstance().asInstanceOf[KryoSerializerInstance]
+ assert(!ser2.getAutoReset)
+ }
}
@@ -313,4 +322,10 @@ object KryoTest {
k.register(classOf[java.util.HashMap[_, _]])
}
}
+
+ class RegistratorWithoutAutoReset extends KryoRegistrator {
+ override def registerClasses(k: Kryo) {
+ k.setAutoReset(false)
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala
index 963264c..86fcf44 100644
--- a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala
@@ -24,7 +24,7 @@ import scala.reflect.ClassTag
/**
- * A serializer implementation that always return a single element in a deserialization stream.
+ * A serializer implementation that always returns two elements in a deserialization stream.
*/
class TestSerializer extends Serializer {
override def newInstance(): TestSerializerInstance = new TestSerializerInstance
@@ -51,7 +51,7 @@ class TestDeserializationStream extends DeserializationStream {
override def readObject[T: ClassTag](): T = {
count += 1
- if (count == 2) {
+ if (count == 3) {
throw new EOFException
}
new Object().asInstanceOf[T]
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
index 7d76435..84384bb 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
@@ -59,8 +59,8 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext {
val shuffle1 = shuffleBlockManager.forMapTask(1, 1, 1, new JavaSerializer(conf),
new ShuffleWriteMetrics)
for (writer <- shuffle1.writers) {
- writer.write("test1")
- writer.write("test2")
+ writer.write("test1", "value")
+ writer.write("test2", "value")
}
for (writer <- shuffle1.writers) {
writer.commitAndClose()
@@ -73,8 +73,8 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext {
new ShuffleWriteMetrics)
for (writer <- shuffle2.writers) {
- writer.write("test3")
- writer.write("test4")
+ writer.write("test3", "value")
+ writer.write("test4", "vlue")
}
for (writer <- shuffle2.writers) {
writer.commitAndClose()
@@ -91,8 +91,8 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext {
val shuffle3 = shuffleBlockManager.forMapTask(1, 3, 1, new JavaSerializer(testConf),
new ShuffleWriteMetrics)
for (writer <- shuffle3.writers) {
- writer.write("test3")
- writer.write("test4")
+ writer.write("test3", "value")
+ writer.write("test4", "value")
}
for (writer <- shuffle3.writers) {
writer.commitAndClose()
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
index 003a728..43ef469 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
@@ -32,7 +32,7 @@ class BlockObjectWriterSuite extends FunSuite {
val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
- writer.write(Long.box(20))
+ writer.write(Long.box(20), Long.box(30))
// Record metrics update on every write
assert(writeMetrics.shuffleRecordsWritten === 1)
// Metrics don't update on every write
@@ -40,7 +40,7 @@ class BlockObjectWriterSuite extends FunSuite {
// After 32 writes, metrics should update
for (i <- 0 until 32) {
writer.flush()
- writer.write(Long.box(i))
+ writer.write(Long.box(i), Long.box(i))
}
assert(writeMetrics.shuffleBytesWritten > 0)
assert(writeMetrics.shuffleRecordsWritten === 33)
@@ -54,7 +54,7 @@ class BlockObjectWriterSuite extends FunSuite {
val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
- writer.write(Long.box(20))
+ writer.write(Long.box(20), Long.box(30))
// Record metrics update on every write
assert(writeMetrics.shuffleRecordsWritten === 1)
// Metrics don't update on every write
@@ -62,7 +62,7 @@ class BlockObjectWriterSuite extends FunSuite {
// After 32 writes, metrics should update
for (i <- 0 until 32) {
writer.flush()
- writer.write(Long.box(i))
+ writer.write(Long.box(i), Long.box(i))
}
assert(writeMetrics.shuffleBytesWritten > 0)
assert(writeMetrics.shuffleRecordsWritten === 33)
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala
new file mode 100644
index 0000000..c0c38cd
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.nio.ByteBuffer
+
+import org.scalatest.FunSuite
+import org.scalatest.Matchers._
+
+/**
+ * Checks that bytes written to a [[ChainedBuffer]] can be read back unchanged, and that the
+ * capacity it reports grows in multiples of 8, the value every test passes to its constructor.
+ */
+class ChainedBufferSuite extends FunSuite {
+  test("write and read at start") {
+    // the capacity is 0 until the first write
+    val buffer = new ChainedBuffer(8)
+    buffer.capacity should be (0)
+
+    // write from start of source array
+    verifyWriteAndRead(buffer, 0, 0, 0, 4)
+    buffer.capacity should be (8)
+
+    // write from middle of source array
+    verifyWriteAndRead(buffer, 0, 5, 0, 4)
+    buffer.capacity should be (8)
+
+    // read to middle of target array
+    verifyWriteAndRead(buffer, 0, 0, 5, 4)
+    buffer.capacity should be (8)
+
+    // write up to border
+    verifyWriteAndRead(buffer, 0, 0, 0, 8)
+    buffer.capacity should be (8)
+
+    // expand into second buffer
+    verifyWriteAndRead(buffer, 0, 0, 0, 12)
+    buffer.capacity should be (16)
+
+    // expand into multiple buffers
+    verifyWriteAndRead(buffer, 0, 0, 0, 28)
+    buffer.capacity should be (32)
+  }
+
+  test("write and read at middle") {
+    val buffer = new ChainedBuffer(8)
+
+    // fill to a middle point
+    verifyWriteAndRead(buffer, 0, 0, 0, 3)
+
+    // write from start of source array
+    verifyWriteAndRead(buffer, 3, 0, 0, 4)
+    buffer.capacity should be (8)
+
+    // write from middle of source array
+    verifyWriteAndRead(buffer, 3, 5, 0, 4)
+    buffer.capacity should be (8)
+
+    // read to middle of target array
+    verifyWriteAndRead(buffer, 3, 0, 5, 4)
+    buffer.capacity should be (8)
+
+    // write up to border
+    verifyWriteAndRead(buffer, 3, 0, 0, 5)
+    buffer.capacity should be (8)
+
+    // expand into second buffer
+    verifyWriteAndRead(buffer, 3, 0, 0, 12)
+    buffer.capacity should be (16)
+
+    // expand into multiple buffers
+    verifyWriteAndRead(buffer, 3, 0, 0, 28)
+    buffer.capacity should be (32)
+  }
+
+  test("write and read at later buffer") {
+    val buffer = new ChainedBuffer(8)
+
+    // fill past the first 8 bytes, to offset 11
+    verifyWriteAndRead(buffer, 0, 0, 0, 11)
+
+    // write from start of source array
+    verifyWriteAndRead(buffer, 11, 0, 0, 4)
+    buffer.capacity should be (16)
+
+    // write from middle of source array
+    verifyWriteAndRead(buffer, 11, 5, 0, 4)
+    buffer.capacity should be (16)
+
+    // read to middle of target array
+    verifyWriteAndRead(buffer, 11, 0, 5, 4)
+    buffer.capacity should be (16)
+
+    // write up to border
+    verifyWriteAndRead(buffer, 11, 0, 0, 5)
+    buffer.capacity should be (16)
+
+    // expand by one more buffer
+    verifyWriteAndRead(buffer, 11, 0, 0, 12)
+    buffer.capacity should be (24)
+
+    // expand into multiple buffers
+    verifyWriteAndRead(buffer, 11, 0, 0, 28)
+    buffer.capacity should be (40)
+  }
+
+  // First value of the byte range used by the next verifyWriteAndRead call. It is advanced after
+  // every call to make sure we're writing different bytes each time.
+  var rangeStart = 0
+
+  /**
+   * Writes `length` bytes into `buffer` at `offsetInBuffer`, reads the same region back, and
+   * asserts that the bytes read equal the bytes written.
+   *
+   * @param buffer The buffer to write to and read from.
+   * @param offsetInBuffer The offset to write to in the buffer.
+   * @param offsetInSource The offset in the array that the bytes are written from.
+   * @param offsetInTarget The offset in the array to read the bytes into.
+   * @param length The number of bytes to read and write
+   */
+  def verifyWriteAndRead(
+      buffer: ChainedBuffer,
+      offsetInBuffer: Int,
+      offsetInSource: Int,
+      offsetInTarget: Int,
+      length: Int): Unit = {
+    // Both arrays are padded in front by their own offset; only the last `length` bytes matter.
+    val source = new Array[Byte](offsetInSource + length)
+    (rangeStart until rangeStart + length).map(_.toByte).copyToArray(source, offsetInSource)
+    buffer.write(offsetInBuffer, source, offsetInSource, length)
+    val target = new Array[Byte](offsetInTarget + length)
+    buffer.read(offsetInBuffer, target, offsetInTarget, length)
+
+    // Wrap just the written/read windows so the differing front padding is ignored.
+    // The expected value must stay inside the parentheses on the same line as `be`: a line break
+    // between `be` and `(` ends the statement at `be`, and then nothing is compared at all.
+    val written = ByteBuffer.wrap(source, offsetInSource, length)
+    val readBack = ByteBuffer.wrap(target, offsetInTarget, length)
+    readBack should be (written)
+
+    // Move to a new byte range so the next call cannot pass on stale buffer contents.
+    rangeStart += 100
+  }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org