Posted to commits@spark.apache.org by jo...@apache.org on 2015/10/22 18:46:39 UTC
[1/4] spark git commit: [SPARK-10708] Consolidate sort shuffle implementations
Repository: spark
Updated Branches:
refs/heads/master 94e2064fa -> f6d06adf0
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index 341f56d..b92a302 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -33,7 +33,8 @@ import org.scalatest.BeforeAndAfterEach
import org.apache.spark._
import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics}
-import org.apache.spark.serializer.{SerializerInstance, Serializer, JavaSerializer}
+import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.serializer.{JavaSerializer, SerializerInstance}
import org.apache.spark.storage._
import org.apache.spark.util.Utils
@@ -42,25 +43,31 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
@Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
@Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _
@Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _
+ @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _
+ @Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _
private var taskMetrics: TaskMetrics = _
- private var shuffleWriteMetrics: ShuffleWriteMetrics = _
private var tempDir: File = _
private var outputFile: File = _
private val conf: SparkConf = new SparkConf(loadDefaults = false)
private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]()
private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File]
- private val shuffleBlockId: ShuffleBlockId = new ShuffleBlockId(0, 0, 0)
- private val serializer: Serializer = new JavaSerializer(conf)
+ private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _
override def beforeEach(): Unit = {
tempDir = Utils.createTempDir()
outputFile = File.createTempFile("shuffle", null, tempDir)
- shuffleWriteMetrics = new ShuffleWriteMetrics
taskMetrics = new TaskMetrics
- taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics)
MockitoAnnotations.initMocks(this)
+ shuffleHandle = new BypassMergeSortShuffleHandle[Int, Int](
+ shuffleId = 0,
+ numMaps = 2,
+ dependency = dependency
+ )
+ when(dependency.partitioner).thenReturn(new HashPartitioner(7))
+ when(dependency.serializer).thenReturn(Some(new JavaSerializer(conf)))
when(taskContext.taskMetrics()).thenReturn(taskMetrics)
+ when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
when(blockManager.getDiskWriter(
any[BlockId],
@@ -107,18 +114,20 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
test("write empty iterator") {
val writer = new BypassMergeSortShuffleWriter[Int, Int](
- new SparkConf(loadDefaults = false),
blockManager,
- new HashPartitioner(7),
- shuffleWriteMetrics,
- serializer
+ blockResolver,
+ shuffleHandle,
+ 0, // MapId
+ taskContext,
+ conf
)
- writer.insertAll(Iterator.empty)
- val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
- assert(partitionLengths.sum === 0)
+ writer.write(Iterator.empty)
+ writer.stop( /* success = */ true)
+ assert(writer.getPartitionLengths.sum === 0)
assert(outputFile.exists())
assert(outputFile.length() === 0)
assert(temporaryFilesCreated.isEmpty)
+ val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get
assert(shuffleWriteMetrics.shuffleBytesWritten === 0)
assert(shuffleWriteMetrics.shuffleRecordsWritten === 0)
assert(taskMetrics.diskBytesSpilled === 0)
@@ -129,17 +138,19 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
def records: Iterator[(Int, Int)] =
Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
val writer = new BypassMergeSortShuffleWriter[Int, Int](
- new SparkConf(loadDefaults = false),
blockManager,
- new HashPartitioner(7),
- shuffleWriteMetrics,
- serializer
+ blockResolver,
+ shuffleHandle,
+ 0, // MapId
+ taskContext,
+ conf
)
- writer.insertAll(records)
+ writer.write(records)
+ writer.stop( /* success = */ true)
assert(temporaryFilesCreated.nonEmpty)
- val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
- assert(partitionLengths.sum === outputFile.length())
+ assert(writer.getPartitionLengths.sum === outputFile.length())
assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted
+ val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get
assert(shuffleWriteMetrics.shuffleBytesWritten === outputFile.length())
assert(shuffleWriteMetrics.shuffleRecordsWritten === records.length)
assert(taskMetrics.diskBytesSpilled === 0)
@@ -148,14 +159,15 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
test("cleanup of intermediate files after errors") {
val writer = new BypassMergeSortShuffleWriter[Int, Int](
- new SparkConf(loadDefaults = false),
blockManager,
- new HashPartitioner(7),
- shuffleWriteMetrics,
- serializer
+ blockResolver,
+ shuffleHandle,
+ 0, // MapId
+ taskContext,
+ conf
)
intercept[SparkException] {
- writer.insertAll((0 until 100000).iterator.map(i => {
+ writer.write((0 until 100000).iterator.map(i => {
if (i == 99990) {
throw new SparkException("Intentional failure")
}
@@ -163,7 +175,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
}))
}
assert(temporaryFilesCreated.nonEmpty)
- writer.stop()
+ writer.stop( /* success = */ false)
assert(temporaryFilesCreated.count(_.exists()) === 0)
}
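As an aside, the updated tests above show the new writer lifecycle: the old insertAll()/writePartitionedFile() pair is replaced by write() followed by stop(success), after which the per-partition lengths are available from the writer itself. A minimal sketch of that flow (illustrative only; runWriter is a hypothetical helper and the writer is assumed to be constructed as in beforeEach above):

// Hypothetical helper summarizing the writer lifecycle exercised by the tests above.
def runWriter(
    writer: BypassMergeSortShuffleWriter[Int, Int],
    records: Iterator[(Int, Int)]): Long = {
  writer.write(records)               // partition and write every record
  writer.stop( /* success = */ true)  // signal success (returns the map output status)
  writer.getPartitionLengths.sum      // total bytes written across all partitions
}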
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala
new file mode 100644
index 0000000..8744a07
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.Matchers
+
+import org.apache.spark._
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer}
+
+/**
+ * Tests for the fallback logic in SortShuffleManager. Actual tests of shuffling data are
+ * performed in other suites.
+ */
+class SortShuffleManagerSuite extends SparkFunSuite with Matchers {
+
+ import SortShuffleManager.canUseSerializedShuffle
+
+ private class RuntimeExceptionAnswer extends Answer[Object] {
+ override def answer(invocation: InvocationOnMock): Object = {
+ throw new RuntimeException("Called non-stubbed method, " + invocation.getMethod.getName)
+ }
+ }
+
+ private def shuffleDep(
+ partitioner: Partitioner,
+ serializer: Option[Serializer],
+ keyOrdering: Option[Ordering[Any]],
+ aggregator: Option[Aggregator[Any, Any, Any]],
+ mapSideCombine: Boolean): ShuffleDependency[Any, Any, Any] = {
+ val dep = mock(classOf[ShuffleDependency[Any, Any, Any]], new RuntimeExceptionAnswer())
+ doReturn(0).when(dep).shuffleId
+ doReturn(partitioner).when(dep).partitioner
+ doReturn(serializer).when(dep).serializer
+ doReturn(keyOrdering).when(dep).keyOrdering
+ doReturn(aggregator).when(dep).aggregator
+ doReturn(mapSideCombine).when(dep).mapSideCombine
+ dep
+ }
+
+ test("supported shuffle dependencies for serialized shuffle") {
+ val kryo = Some(new KryoSerializer(new SparkConf()))
+
+ assert(canUseSerializedShuffle(shuffleDep(
+ partitioner = new HashPartitioner(2),
+ serializer = kryo,
+ keyOrdering = None,
+ aggregator = None,
+ mapSideCombine = false
+ )))
+
+ val rangePartitioner = mock(classOf[RangePartitioner[Any, Any]])
+ when(rangePartitioner.numPartitions).thenReturn(2)
+ assert(canUseSerializedShuffle(shuffleDep(
+ partitioner = rangePartitioner,
+ serializer = kryo,
+ keyOrdering = None,
+ aggregator = None,
+ mapSideCombine = false
+ )))
+
+ // Shuffles with key orderings are supported as long as no aggregator is specified
+ assert(canUseSerializedShuffle(shuffleDep(
+ partitioner = new HashPartitioner(2),
+ serializer = kryo,
+ keyOrdering = Some(mock(classOf[Ordering[Any]])),
+ aggregator = None,
+ mapSideCombine = false
+ )))
+
+ }
+
+ test("unsupported shuffle dependencies for serialized shuffle") {
+ val kryo = Some(new KryoSerializer(new SparkConf()))
+ val java = Some(new JavaSerializer(new SparkConf()))
+
+ // We only support serializers that support object relocation
+ assert(!canUseSerializedShuffle(shuffleDep(
+ partitioner = new HashPartitioner(2),
+ serializer = java,
+ keyOrdering = None,
+ aggregator = None,
+ mapSideCombine = false
+ )))
+
+ // The serialized shuffle path does not support shuffles with more than 16 million output
+ // partitions, due to a limitation in its sorter implementation.
+ assert(!canUseSerializedShuffle(shuffleDep(
+ partitioner = new HashPartitioner(
+ SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE + 1),
+ serializer = kryo,
+ keyOrdering = None,
+ aggregator = None,
+ mapSideCombine = false
+ )))
+
+ // We do not support shuffles that perform aggregation
+ assert(!canUseSerializedShuffle(shuffleDep(
+ partitioner = new HashPartitioner(2),
+ serializer = kryo,
+ keyOrdering = None,
+ aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
+ mapSideCombine = false
+ )))
+ assert(!canUseSerializedShuffle(shuffleDep(
+ partitioner = new HashPartitioner(2),
+ serializer = kryo,
+ keyOrdering = Some(mock(classOf[Ordering[Any]])),
+ aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
+ mapSideCombine = true
+ )))
+ }
+
+}
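For readers skimming the diff, the conditions this suite exercises boil down to: the serializer must support relocation of serialized objects (e.g. Kryo but not Java serialization), the dependency must not need aggregation or map-side combine, and the number of output partitions must stay within the serialized sorter's limit of roughly 16 million. A minimal standalone sketch of that rule, not the committed SortShuffleManager.canUseSerializedShuffle implementation; the parameter names are hypothetical and the 1 << 24 default is an assumption matching the "16 million" limit noted above:

def serializedShuffleSupported(
    serializerSupportsRelocation: Boolean,  // true for KryoSerializer, false for JavaSerializer
    needsAggregation: Boolean,              // aggregator present or map-side combine requested
    numPartitions: Int,
    maxPartitions: Int = 1 << 24): Boolean = {
  serializerSupportsRelocation && !needsAggregation && numPartitions <= maxPartitions
}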
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
deleted file mode 100644
index 34b4984..0000000
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.sort
-
-import org.mockito.Mockito._
-
-import org.apache.spark.{Aggregator, SparkConf, SparkFunSuite}
-
-class SortShuffleWriterSuite extends SparkFunSuite {
-
- import SortShuffleWriter._
-
- test("conditions for bypassing merge-sort") {
- val conf = new SparkConf(loadDefaults = false)
- val agg = mock(classOf[Aggregator[_, _, _]], RETURNS_SMART_NULLS)
- val ord = implicitly[Ordering[Int]]
-
- // Numbers of partitions that are above and below the default bypassMergeThreshold
- val FEW_PARTITIONS = 50
- val MANY_PARTITIONS = 10000
-
- // Shuffles with no ordering or aggregator: should bypass unless # of partitions is high
- assert(shouldBypassMergeSort(conf, FEW_PARTITIONS, None, None))
- assert(!shouldBypassMergeSort(conf, MANY_PARTITIONS, None, None))
-
- // Shuffles with an ordering or aggregator: should not bypass even if they have few partitions
- assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, None, Some(ord)))
- assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, Some(agg), None))
- }
-}
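For context, the deleted suite asserted that the bypass path applies only when the shuffle has no aggregator, no key ordering, and a partition count at or below spark.shuffle.sort.bypassMergeThreshold (200 by default). A minimal sketch of that predicate with hypothetical parameter names, not the shouldBypassMergeSort method itself:

def bypassMergeSortApplies(
    numPartitions: Int,
    hasAggregator: Boolean,
    hasKeyOrdering: Boolean,
    bypassMergeThreshold: Int = 200): Boolean = {
  // Bypass merge-sort only for simple shuffles with relatively few partitions.
  !hasAggregator && !hasKeyOrdering && numPartitions <= bypassMergeThreshold
}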
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
deleted file mode 100644
index 6727934..0000000
--- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
+++ /dev/null
@@ -1,129 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe
-
-import org.mockito.Mockito._
-import org.mockito.invocation.InvocationOnMock
-import org.mockito.stubbing.Answer
-import org.scalatest.Matchers
-
-import org.apache.spark._
-import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer}
-
-/**
- * Tests for the fallback logic in UnsafeShuffleManager. Actual tests of shuffling data are
- * performed in other suites.
- */
-class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
-
- import UnsafeShuffleManager.canUseUnsafeShuffle
-
- private class RuntimeExceptionAnswer extends Answer[Object] {
- override def answer(invocation: InvocationOnMock): Object = {
- throw new RuntimeException("Called non-stubbed method, " + invocation.getMethod.getName)
- }
- }
-
- private def shuffleDep(
- partitioner: Partitioner,
- serializer: Option[Serializer],
- keyOrdering: Option[Ordering[Any]],
- aggregator: Option[Aggregator[Any, Any, Any]],
- mapSideCombine: Boolean): ShuffleDependency[Any, Any, Any] = {
- val dep = mock(classOf[ShuffleDependency[Any, Any, Any]], new RuntimeExceptionAnswer())
- doReturn(0).when(dep).shuffleId
- doReturn(partitioner).when(dep).partitioner
- doReturn(serializer).when(dep).serializer
- doReturn(keyOrdering).when(dep).keyOrdering
- doReturn(aggregator).when(dep).aggregator
- doReturn(mapSideCombine).when(dep).mapSideCombine
- dep
- }
-
- test("supported shuffle dependencies") {
- val kryo = Some(new KryoSerializer(new SparkConf()))
-
- assert(canUseUnsafeShuffle(shuffleDep(
- partitioner = new HashPartitioner(2),
- serializer = kryo,
- keyOrdering = None,
- aggregator = None,
- mapSideCombine = false
- )))
-
- val rangePartitioner = mock(classOf[RangePartitioner[Any, Any]])
- when(rangePartitioner.numPartitions).thenReturn(2)
- assert(canUseUnsafeShuffle(shuffleDep(
- partitioner = rangePartitioner,
- serializer = kryo,
- keyOrdering = None,
- aggregator = None,
- mapSideCombine = false
- )))
-
- // Shuffles with key orderings are supported as long as no aggregator is specified
- assert(canUseUnsafeShuffle(shuffleDep(
- partitioner = new HashPartitioner(2),
- serializer = kryo,
- keyOrdering = Some(mock(classOf[Ordering[Any]])),
- aggregator = None,
- mapSideCombine = false
- )))
-
- }
-
- test("unsupported shuffle dependencies") {
- val kryo = Some(new KryoSerializer(new SparkConf()))
- val java = Some(new JavaSerializer(new SparkConf()))
-
- // We only support serializers that support object relocation
- assert(!canUseUnsafeShuffle(shuffleDep(
- partitioner = new HashPartitioner(2),
- serializer = java,
- keyOrdering = None,
- aggregator = None,
- mapSideCombine = false
- )))
-
- // We do not support shuffles with more than 16 million output partitions
- assert(!canUseUnsafeShuffle(shuffleDep(
- partitioner = new HashPartitioner(UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS + 1),
- serializer = kryo,
- keyOrdering = None,
- aggregator = None,
- mapSideCombine = false
- )))
-
- // We do not support shuffles that perform aggregation
- assert(!canUseUnsafeShuffle(shuffleDep(
- partitioner = new HashPartitioner(2),
- serializer = kryo,
- keyOrdering = None,
- aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
- mapSideCombine = false
- )))
- assert(!canUseUnsafeShuffle(shuffleDep(
- partitioner = new HashPartitioner(2),
- serializer = kryo,
- keyOrdering = Some(mock(classOf[Ordering[Any]])),
- aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
- mapSideCombine = true
- )))
- }
-
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
deleted file mode 100644
index 259020a..0000000
--- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
+++ /dev/null
@@ -1,102 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe
-
-import java.io.File
-
-import scala.collection.JavaConverters._
-
-import org.apache.commons.io.FileUtils
-import org.apache.commons.io.filefilter.TrueFileFilter
-import org.scalatest.BeforeAndAfterAll
-
-import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite}
-import org.apache.spark.rdd.ShuffledRDD
-import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
-import org.apache.spark.util.Utils
-
-class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
-
- // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle.
-
- override def beforeAll() {
- conf.set("spark.shuffle.manager", "tungsten-sort")
- }
-
- test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") {
- val tmpDir = Utils.createTempDir()
- try {
- val myConf = conf.clone()
- .set("spark.local.dir", tmpDir.getAbsolutePath)
- sc = new SparkContext("local", "test", myConf)
- // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path
- val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
- val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
- .setSerializer(new KryoSerializer(myConf))
- val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
- assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
- def getAllFiles: Set[File] =
- FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
- val filesBeforeShuffle = getAllFiles
- // Force the shuffle to be performed
- shuffledRdd.count()
- // Ensure that the shuffle actually created files that will need to be cleaned up
- val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
- filesCreatedByShuffle.map(_.getName) should be
- Set("shuffle_0_0_0.data", "shuffle_0_0_0.index")
- // Check that the cleanup actually removes the files
- sc.env.blockManager.master.removeShuffle(0, blocking = true)
- for (file <- filesCreatedByShuffle) {
- assert (!file.exists(), s"Shuffle file $file was not cleaned up")
- }
- } finally {
- Utils.deleteRecursively(tmpDir)
- }
- }
-
- test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") {
- val tmpDir = Utils.createTempDir()
- try {
- val myConf = conf.clone()
- .set("spark.local.dir", tmpDir.getAbsolutePath)
- sc = new SparkContext("local", "test", myConf)
- // Create a shuffled RDD and verify that it will actually use the old SortShuffle path
- val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
- val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
- .setSerializer(new JavaSerializer(myConf))
- val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
- assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
- def getAllFiles: Set[File] =
- FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
- val filesBeforeShuffle = getAllFiles
- // Force the shuffle to be performed
- shuffledRdd.count()
- // Ensure that the shuffle actually created files that will need to be cleaned up
- val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
- filesCreatedByShuffle.map(_.getName) should be
- Set("shuffle_0_0_0.data", "shuffle_0_0_0.index")
- // Check that the cleanup actually removes the files
- sc.env.blockManager.master.removeShuffle(0, blocking = true)
- for (file <- filesCreatedByShuffle) {
- assert (!file.exists(), s"Shuffle file $file was not cleaned up")
- }
- } finally {
- Utils.deleteRecursively(tmpDir)
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala
deleted file mode 100644
index 05306f4..0000000
--- a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala
+++ /dev/null
@@ -1,144 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.nio.ByteBuffer
-
-import org.scalatest.Matchers._
-
-import org.apache.spark.SparkFunSuite
-
-class ChainedBufferSuite extends SparkFunSuite {
- test("write and read at start") {
- // write from start of source array
- val buffer = new ChainedBuffer(8)
- buffer.capacity should be (0)
- verifyWriteAndRead(buffer, 0, 0, 0, 4)
- buffer.capacity should be (8)
-
- // write from middle of source array
- verifyWriteAndRead(buffer, 0, 5, 0, 4)
- buffer.capacity should be (8)
-
- // read to middle of target array
- verifyWriteAndRead(buffer, 0, 0, 5, 4)
- buffer.capacity should be (8)
-
- // write up to border
- verifyWriteAndRead(buffer, 0, 0, 0, 8)
- buffer.capacity should be (8)
-
- // expand into second buffer
- verifyWriteAndRead(buffer, 0, 0, 0, 12)
- buffer.capacity should be (16)
-
- // expand into multiple buffers
- verifyWriteAndRead(buffer, 0, 0, 0, 28)
- buffer.capacity should be (32)
- }
-
- test("write and read at middle") {
- val buffer = new ChainedBuffer(8)
-
- // fill to a middle point
- verifyWriteAndRead(buffer, 0, 0, 0, 3)
-
- // write from start of source array
- verifyWriteAndRead(buffer, 3, 0, 0, 4)
- buffer.capacity should be (8)
-
- // write from middle of source array
- verifyWriteAndRead(buffer, 3, 5, 0, 4)
- buffer.capacity should be (8)
-
- // read to middle of target array
- verifyWriteAndRead(buffer, 3, 0, 5, 4)
- buffer.capacity should be (8)
-
- // write up to border
- verifyWriteAndRead(buffer, 3, 0, 0, 5)
- buffer.capacity should be (8)
-
- // expand into second buffer
- verifyWriteAndRead(buffer, 3, 0, 0, 12)
- buffer.capacity should be (16)
-
- // expand into multiple buffers
- verifyWriteAndRead(buffer, 3, 0, 0, 28)
- buffer.capacity should be (32)
- }
-
- test("write and read at later buffer") {
- val buffer = new ChainedBuffer(8)
-
- // fill to a middle point
- verifyWriteAndRead(buffer, 0, 0, 0, 11)
-
- // write from start of source array
- verifyWriteAndRead(buffer, 11, 0, 0, 4)
- buffer.capacity should be (16)
-
- // write from middle of source array
- verifyWriteAndRead(buffer, 11, 5, 0, 4)
- buffer.capacity should be (16)
-
- // read to middle of target array
- verifyWriteAndRead(buffer, 11, 0, 5, 4)
- buffer.capacity should be (16)
-
- // write up to border
- verifyWriteAndRead(buffer, 11, 0, 0, 5)
- buffer.capacity should be (16)
-
- // expand into second buffer
- verifyWriteAndRead(buffer, 11, 0, 0, 12)
- buffer.capacity should be (24)
-
- // expand into multiple buffers
- verifyWriteAndRead(buffer, 11, 0, 0, 28)
- buffer.capacity should be (40)
- }
-
-
- // Used to make sure we're writing different bytes each time
- var rangeStart = 0
-
- /**
- * @param buffer The buffer to write to and read from.
- * @param offsetInBuffer The offset to write to in the buffer.
- * @param offsetInSource The offset in the array that the bytes are written from.
- * @param offsetInTarget The offset in the array to read the bytes into.
- * @param length The number of bytes to read and write
- */
- def verifyWriteAndRead(
- buffer: ChainedBuffer,
- offsetInBuffer: Int,
- offsetInSource: Int,
- offsetInTarget: Int,
- length: Int): Unit = {
- val source = new Array[Byte](offsetInSource + length)
- (rangeStart until rangeStart + length).map(_.toByte).copyToArray(source, offsetInSource)
- buffer.write(offsetInBuffer, source, offsetInSource, length)
- val target = new Array[Byte](offsetInTarget + length)
- buffer.read(offsetInBuffer, target, offsetInTarget, length)
- ByteBuffer.wrap(source, offsetInSource, length) should be
- (ByteBuffer.wrap(target, offsetInTarget, length))
-
- rangeStart += 100
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
deleted file mode 100644
index 3b67f62..0000000
--- a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
+++ /dev/null
@@ -1,148 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
-
-import com.google.common.io.ByteStreams
-
-import org.mockito.Matchers.any
-import org.mockito.Mockito._
-import org.mockito.Mockito.RETURNS_SMART_NULLS
-import org.mockito.invocation.InvocationOnMock
-import org.mockito.stubbing.Answer
-import org.scalatest.Matchers._
-
-import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.serializer.KryoSerializer
-import org.apache.spark.storage.DiskBlockObjectWriter
-
-class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
- test("OrderedInputStream single record") {
- val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
-
- val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
- val struct = SomeStruct("something", 5)
- buffer.insert(4, 10, struct)
-
- val bytes = ByteStreams.toByteArray(buffer.orderedInputStream)
-
- val baos = new ByteArrayOutputStream()
- val stream = serializerInstance.serializeStream(baos)
- stream.writeObject(10)
- stream.writeObject(struct)
- stream.close()
-
- baos.toByteArray should be (bytes)
- }
-
- test("insert single record") {
- val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
- val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
- val struct = SomeStruct("something", 5)
- buffer.insert(4, 10, struct)
- val elements = buffer.partitionedDestructiveSortedIterator(None).toArray
- elements.size should be (1)
- elements.head should be (((4, 10), struct))
- }
-
- test("insert multiple records") {
- val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
- val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
- val struct1 = SomeStruct("something1", 8)
- buffer.insert(6, 1, struct1)
- val struct2 = SomeStruct("something2", 9)
- buffer.insert(4, 2, struct2)
- val struct3 = SomeStruct("something3", 10)
- buffer.insert(5, 3, struct3)
-
- val elements = buffer.partitionedDestructiveSortedIterator(None).toArray
- elements.size should be (3)
- elements(0) should be (((4, 2), struct2))
- elements(1) should be (((5, 3), struct3))
- elements(2) should be (((6, 1), struct1))
- }
-
- test("write single record") {
- val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
- val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
- val struct = SomeStruct("something", 5)
- buffer.insert(4, 10, struct)
- val it = buffer.destructiveSortedWritablePartitionedIterator(None)
- val (writer, baos) = createMockWriter()
- assert(it.hasNext)
- it.nextPartition should be (4)
- it.writeNext(writer)
- assert(!it.hasNext)
-
- val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
- stream.readObject[AnyRef]() should be (10)
- stream.readObject[AnyRef]() should be (struct)
- }
-
- test("write multiple records") {
- val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
- val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
- val struct1 = SomeStruct("something1", 8)
- buffer.insert(6, 1, struct1)
- val struct2 = SomeStruct("something2", 9)
- buffer.insert(4, 2, struct2)
- val struct3 = SomeStruct("something3", 10)
- buffer.insert(5, 3, struct3)
-
- val it = buffer.destructiveSortedWritablePartitionedIterator(None)
- val (writer, baos) = createMockWriter()
- assert(it.hasNext)
- it.nextPartition should be (4)
- it.writeNext(writer)
- assert(it.hasNext)
- it.nextPartition should be (5)
- it.writeNext(writer)
- assert(it.hasNext)
- it.nextPartition should be (6)
- it.writeNext(writer)
- assert(!it.hasNext)
-
- val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
- val iter = stream.asIterator
- iter.next() should be (2)
- iter.next() should be (struct2)
- iter.next() should be (3)
- iter.next() should be (struct3)
- iter.next() should be (1)
- iter.next() should be (struct1)
- assert(!iter.hasNext)
- }
-
- def createMockWriter(): (DiskBlockObjectWriter, ByteArrayOutputStream) = {
- val writer = mock(classOf[DiskBlockObjectWriter], RETURNS_SMART_NULLS)
- val baos = new ByteArrayOutputStream()
- when(writer.write(any(), any(), any())).thenAnswer(new Answer[Unit] {
- override def answer(invocationOnMock: InvocationOnMock): Unit = {
- val args = invocationOnMock.getArguments
- val bytes = args(0).asInstanceOf[Array[Byte]]
- val offset = args(1).asInstanceOf[Int]
- val length = args(2).asInstanceOf[Int]
- baos.write(bytes, offset, length)
- }
- })
- (writer, baos)
- }
-}
-
-case class SomeStruct(str: String, num: Int)
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/docs/configuration.md
----------------------------------------------------------------------
diff --git a/docs/configuration.md b/docs/configuration.md
index 46d92ce..be9c36b 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -437,12 +437,9 @@ Apart from these, the following properties are also available, and may be useful
<td><code>spark.shuffle.manager</code></td>
<td>sort</td>
<td>
- Implementation to use for shuffling data. There are three implementations available:
- <code>sort</code>, <code>hash</code> and the new (1.5+) <code>tungsten-sort</code>.
+ Implementation to use for shuffling data. There are two implementations available:
+ <code>sort</code> and <code>hash</code>.
Sort-based shuffle is more memory-efficient and is the default option starting in 1.2.
- Tungsten-sort is similar to the sort based shuffle, with a direct binary cache-friendly
- implementation with a fall back to regular sort based shuffle if its requirements are not
- met.
</td>
</tr>
<tr>
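For reference, the setting itself is selected the same way as before; a minimal usage sketch:

import org.apache.spark.SparkConf

// Explicitly select the (default) sort-based shuffle implementation.
val conf = new SparkConf().set("spark.shuffle.manager", "sort")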
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 0872d3f..b5e661d 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -37,6 +37,7 @@ object MimaExcludes {
Seq(
MimaBuild.excludeSparkPackage("deploy"),
MimaBuild.excludeSparkPackage("network"),
+ MimaBuild.excludeSparkPackage("unsafe"),
// These are needed if checking against the sbt build, since they are part of
// the maven-generated artifacts in 1.3.
excludePackage("org.spark-project.jetty"),
@@ -44,7 +45,11 @@ object MimaExcludes {
// SQL execution is considered private.
excludePackage("org.apache.spark.sql.execution"),
// SQL columnar is considered private.
- excludePackage("org.apache.spark.sql.columnar")
+ excludePackage("org.apache.spark.sql.columnar"),
+ // The shuffle package is considered private.
+ excludePackage("org.apache.spark.shuffle"),
+ // The collection utilities are considered private.
+ excludePackage("org.apache.spark.util.collection")
) ++
MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++
MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++
@@ -750,4 +755,4 @@ object MimaExcludes {
MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD")
case _ => Seq()
}
-}
\ No newline at end of file
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 1d3379a..7f60c8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -24,7 +24,6 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.shuffle.unsafe.UnsafeShuffleManager
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
@@ -87,10 +86,8 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
// fewer partitions (like RangePartitioner, for example).
val conf = child.sqlContext.sparkContext.conf
val shuffleManager = SparkEnv.get.shuffleManager
- val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] ||
- shuffleManager.isInstanceOf[UnsafeShuffleManager]
+ val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager]
val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
- val serializeMapOutputs = conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true)
if (sortBasedShuffleOn) {
val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) {
@@ -99,22 +96,18 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
// doesn't buffer deserialized records.
// Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
false
- } else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) {
- // SPARK-4550 extended sort-based shuffle to serialize individual records prior to sorting
- // them. This optimization is guarded by a feature-flag and is only applied in cases where
- // shuffle dependency does not specify an aggregator or ordering and the record serializer
- // has certain properties. If this optimization is enabled, we can safely avoid the copy.
+ } else if (serializer.supportsRelocationOfSerializedObjects) {
+ // SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records
+ // prior to sorting them. This optimization is only applied in cases where shuffle
+ // dependency does not specify an aggregator or ordering and the record serializer has
+ // certain properties. If this optimization is enabled, we can safely avoid the copy.
//
// Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only
// need to check whether the optimization is enabled and supported by our serializer.
- //
- // This optimization also applies to UnsafeShuffleManager (added in SPARK-7081).
false
} else {
- // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory. This code
- // path is used both when SortShuffleManager is used and when UnsafeShuffleManager falls
- // back to SortShuffleManager to perform a shuffle that the new fast path can't handle. In
- // both cases, we must copy.
+ // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we must
+ // copy.
true
}
} else if (shuffleManager.isInstanceOf[HashShuffleManager]) {
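The updated comments above amount to a small decision rule for whether Exchange must defensively copy rows before handing them to the shuffle. A hedged sketch of that rule for the sort-based branch only (hypothetical names; the hash-shuffle branch that follows in the hunk is omitted):

def mustCopyBeforeShuffle(
    numPartitions: Int,
    bypassMergeThreshold: Int,              // spark.shuffle.sort.bypassMergeThreshold, default 200
    serializerSupportsRelocation: Boolean): Boolean = {
  if (numPartitions <= bypassMergeThreshold) {
    // The bypass writer streams records straight into per-partition files and never
    // buffers deserialized records, so no copy is needed (see the SPARK-6026 note above).
    false
  } else if (serializerSupportsRelocation) {
    // Records are serialized before sorting (SPARK-4550 / SPARK-7081), so the sorter
    // never holds deserialized rows and no copy is needed.
    false
  } else {
    // ExternalSorter buffers deserialized records in memory, so we must copy.
    true
  }
}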
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index 75d1fce..1680d7e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -101,7 +101,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten
Utils.tryWithSafeFinally {
val conf = new SparkConf()
- .set("spark.shuffle.spill.initialMemoryThreshold", "1024")
+ .set("spark.shuffle.spill.initialMemoryThreshold", "1")
.set("spark.shuffle.sort.bypassMergeThreshold", "0")
.set("spark.testing.memory", "80000")
@@ -109,7 +109,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
// prepare data
val converter = unsafeRowConverter(Array(IntegerType))
- val data = (1 to 1000).iterator.map { i =>
+ val data = (1 to 10000).iterator.map { i =>
(i, converter(Row(i)))
}
val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
@@ -141,9 +141,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
}
}
- test("SPARK-10403: unsafe row serializer with UnsafeShuffleManager") {
- val conf = new SparkConf()
- .set("spark.shuffle.manager", "tungsten-sort")
+ test("SPARK-10403: unsafe row serializer with SortShuffleManager") {
+ val conf = new SparkConf().set("spark.shuffle.manager", "sort")
sc = new SparkContext("local", "test", conf)
val row = Row("Hello", 123)
val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
[3/4] spark git commit: [SPARK-10708] Consolidate sort shuffle implementations
Posted by jo...@apache.org.
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
deleted file mode 100644
index e73ba39..0000000
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ /dev/null
@@ -1,479 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-import javax.annotation.Nullable;
-import java.io.File;
-import java.io.IOException;
-import java.util.LinkedList;
-
-import scala.Tuple2;
-
-import com.google.common.annotations.VisibleForTesting;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.TaskContext;
-import org.apache.spark.executor.ShuffleWriteMetrics;
-import org.apache.spark.serializer.DummySerializerInstance;
-import org.apache.spark.serializer.SerializerInstance;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
-import org.apache.spark.storage.BlockManager;
-import org.apache.spark.storage.DiskBlockObjectWriter;
-import org.apache.spark.storage.TempShuffleBlockId;
-import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.array.ByteArrayMethods;
-import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
-import org.apache.spark.util.Utils;
-
-/**
- * An external sorter that is specialized for sort-based shuffle.
- * <p>
- * Incoming records are appended to data pages. When all records have been inserted (or when the
- * current thread's shuffle memory limit is reached), the in-memory records are sorted according to
- * their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then
- * written to a single output file (or multiple files, if we've spilled). The format of the output
- * files is the same as the format of the final output file written by
- * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are
- * written as a single serialized, compressed stream that can be read with a new decompression and
- * deserialization stream.
- * <p>
- * Unlike {@link org.apache.spark.util.collection.ExternalSorter}, this sorter does not merge its
- * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
- * specialized merge procedure that avoids extra serialization/deserialization.
- */
-final class UnsafeShuffleExternalSorter {
-
- private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);
-
- @VisibleForTesting
- static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
-
- private final int initialSize;
- private final int numPartitions;
- private final int pageSizeBytes;
- @VisibleForTesting
- final int maxRecordSizeBytes;
- private final TaskMemoryManager taskMemoryManager;
- private final ShuffleMemoryManager shuffleMemoryManager;
- private final BlockManager blockManager;
- private final TaskContext taskContext;
- private final ShuffleWriteMetrics writeMetrics;
-
- /** The buffer size to use when writing spills using DiskBlockObjectWriter */
- private final int fileBufferSizeBytes;
-
- /**
- * Memory pages that hold the records being sorted. The pages in this list are freed when
- * spilling, although in principle we could recycle these pages across spills (on the other hand,
- * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager
- * itself).
- */
- private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
-
- private final LinkedList<SpillInfo> spills = new LinkedList<SpillInfo>();
-
- /** Peak memory used by this sorter so far, in bytes. **/
- private long peakMemoryUsedBytes;
-
- // These variables are reset after spilling:
- @Nullable private UnsafeShuffleInMemorySorter inMemSorter;
- @Nullable private MemoryBlock currentPage = null;
- private long currentPagePosition = -1;
- private long freeSpaceInCurrentPage = 0;
-
- public UnsafeShuffleExternalSorter(
- TaskMemoryManager memoryManager,
- ShuffleMemoryManager shuffleMemoryManager,
- BlockManager blockManager,
- TaskContext taskContext,
- int initialSize,
- int numPartitions,
- SparkConf conf,
- ShuffleWriteMetrics writeMetrics) throws IOException {
- this.taskMemoryManager = memoryManager;
- this.shuffleMemoryManager = shuffleMemoryManager;
- this.blockManager = blockManager;
- this.taskContext = taskContext;
- this.initialSize = initialSize;
- this.peakMemoryUsedBytes = initialSize;
- this.numPartitions = numPartitions;
- // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
- this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
- this.pageSizeBytes = (int) Math.min(
- PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes());
- this.maxRecordSizeBytes = pageSizeBytes - 4;
- this.writeMetrics = writeMetrics;
- initializeForWriting();
-
- // preserve first page to ensure that we have at least one page to work with. Otherwise,
- // other operators in the same task may starve this sorter (SPARK-9709).
- acquireNewPageIfNecessary(pageSizeBytes);
- }
-
- /**
- * Allocates new sort data structures. Called when creating the sorter and after each spill.
- */
- private void initializeForWriting() throws IOException {
- // TODO: move this sizing calculation logic into a static method of sorter:
- final long memoryRequested = initialSize * 8L;
- final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested);
- if (memoryAcquired != memoryRequested) {
- shuffleMemoryManager.release(memoryAcquired);
- throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
- }
-
- this.inMemSorter = new UnsafeShuffleInMemorySorter(initialSize);
- }
-
- /**
- * Sorts the in-memory records and writes the sorted records to an on-disk file.
- * This method does not free the sort data structures.
- *
- * @param isLastFile if true, this indicates that we're writing the final output file and that the
- * bytes written should be counted towards shuffle spill metrics rather than
- * shuffle write metrics.
- */
- private void writeSortedFile(boolean isLastFile) throws IOException {
-
- final ShuffleWriteMetrics writeMetricsToUse;
-
- if (isLastFile) {
- // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
- writeMetricsToUse = writeMetrics;
- } else {
- // We're spilling, so bytes written should be counted towards spill rather than write.
- // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count
- // them towards shuffle bytes written.
- writeMetricsToUse = new ShuffleWriteMetrics();
- }
-
- // This call performs the actual sort.
- final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords =
- inMemSorter.getSortedIterator();
-
- // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
- // after SPARK-5581 is fixed.
- DiskBlockObjectWriter writer;
-
- // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
- // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
- // data through a byte array. This array does not need to be large enough to hold a single
- // record;
- final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];
-
- // Because this output will be read during shuffle, its compression codec must be controlled by
- // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
- // createTempShuffleBlock here; see SPARK-3426 for more details.
- final Tuple2<TempShuffleBlockId, File> spilledFileInfo =
- blockManager.diskBlockManager().createTempShuffleBlock();
- final File file = spilledFileInfo._2();
- final TempShuffleBlockId blockId = spilledFileInfo._1();
- final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId);
-
- // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
- // Our write path doesn't actually use this serializer (since we end up calling the `write()`
- // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
- // around this, we pass a dummy no-op serializer.
- final SerializerInstance ser = DummySerializerInstance.INSTANCE;
-
- writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
-
- int currentPartition = -1;
- while (sortedRecords.hasNext()) {
- sortedRecords.loadNext();
- final int partition = sortedRecords.packedRecordPointer.getPartitionId();
- assert (partition >= currentPartition);
- if (partition != currentPartition) {
- // Switch to the new partition
- if (currentPartition != -1) {
- writer.commitAndClose();
- spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
- }
- currentPartition = partition;
- writer =
- blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
- }
-
- final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
- final Object recordPage = taskMemoryManager.getPage(recordPointer);
- final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer);
- int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage);
- long recordReadPosition = recordOffsetInPage + 4; // skip over record length
- while (dataRemaining > 0) {
- final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining);
- Platform.copyMemory(
- recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer);
- writer.write(writeBuffer, 0, toTransfer);
- recordReadPosition += toTransfer;
- dataRemaining -= toTransfer;
- }
- writer.recordWritten();
- }
-
- if (writer != null) {
- writer.commitAndClose();
- // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted,
- // then the file might be empty. Note that it might be better to avoid calling
- // writeSortedFile() in that case.
- if (currentPartition != -1) {
- spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
- spills.add(spillInfo);
- }
- }
-
- if (!isLastFile) { // i.e. this is a spill file
- // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records
- // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter
- // relies on its `recordWritten()` method being called in order to trigger periodic updates to
- // `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that
- // counter at a higher-level, then the in-progress metrics for records written and bytes
- // written would get out of sync.
- //
- // When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter;
- // in all other cases, we pass in a dummy write metrics to capture metrics, then copy those
- // metrics to the true write metrics here. The reason for performing this copying is so that
- // we can avoid reporting spilled bytes as shuffle write bytes.
- //
- // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`.
- // Consistent with ExternalSorter, we do not count this IO towards shuffle write time.
- // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this.
- writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten());
- taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.shuffleBytesWritten());
- }
- }
-
- /**
- * Sort and spill the current records in response to memory pressure.
- */
- @VisibleForTesting
- void spill() throws IOException {
- logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
- Thread.currentThread().getId(),
- Utils.bytesToString(getMemoryUsage()),
- spills.size(),
- spills.size() > 1 ? " times" : " time");
-
- writeSortedFile(false);
- final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage();
- inMemSorter = null;
- shuffleMemoryManager.release(inMemSorterMemoryUsage);
- final long spillSize = freeMemory();
- taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
-
- initializeForWriting();
- }
-
- private long getMemoryUsage() {
- long totalPageSize = 0;
- for (MemoryBlock page : allocatedPages) {
- totalPageSize += page.size();
- }
- return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize;
- }
-
- private void updatePeakMemoryUsed() {
- long mem = getMemoryUsage();
- if (mem > peakMemoryUsedBytes) {
- peakMemoryUsedBytes = mem;
- }
- }
-
- /**
- * Return the peak memory used so far, in bytes.
- */
- long getPeakMemoryUsedBytes() {
- updatePeakMemoryUsed();
- return peakMemoryUsedBytes;
- }
-
- private long freeMemory() {
- updatePeakMemoryUsed();
- long memoryFreed = 0;
- for (MemoryBlock block : allocatedPages) {
- taskMemoryManager.freePage(block);
- shuffleMemoryManager.release(block.size());
- memoryFreed += block.size();
- }
- allocatedPages.clear();
- currentPage = null;
- currentPagePosition = -1;
- freeSpaceInCurrentPage = 0;
- return memoryFreed;
- }
-
- /**
- * Force all memory and spill files to be deleted; called by shuffle error-handling code.
- */
- public void cleanupResources() {
- freeMemory();
- for (SpillInfo spill : spills) {
- if (spill.file.exists() && !spill.file.delete()) {
- logger.error("Unable to delete spill file {}", spill.file.getPath());
- }
- }
- if (inMemSorter != null) {
- shuffleMemoryManager.release(inMemSorter.getMemoryUsage());
- inMemSorter = null;
- }
- }
-
- /**
- * Checks whether there is enough space to insert an additional record into the sort pointer
- * array and grows the array if additional space is required. If the required space cannot be
- * obtained, then the in-memory data will be spilled to disk.
- */
- private void growPointerArrayIfNecessary() throws IOException {
- assert(inMemSorter != null);
- if (!inMemSorter.hasSpaceForAnotherRecord()) {
- logger.debug("Attempting to expand sort pointer array");
- final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage();
- final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
- final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
- if (memoryAcquired < memoryToGrowPointerArray) {
- shuffleMemoryManager.release(memoryAcquired);
- spill();
- } else {
- inMemSorter.expandPointerArray();
- shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
- }
- }
- }
-
- /**
- * Allocates more memory in order to insert an additional record. This will request additional
- * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
- * obtained.
- *
- * @param requiredSpace the required space in the data page, in bytes, including space for storing
- * the record size. This must be less than or equal to the page size (records
- * that exceed the page size are handled via a different code path which uses
- * special overflow pages).
- */
- private void acquireNewPageIfNecessary(int requiredSpace) throws IOException {
- growPointerArrayIfNecessary();
- if (requiredSpace > freeSpaceInCurrentPage) {
- logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
- freeSpaceInCurrentPage);
- // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
- // without using the free space at the end of the current page. We should also do this for
- // BytesToBytesMap.
- if (requiredSpace > pageSizeBytes) {
- throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
- pageSizeBytes + ")");
- } else {
- final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
- if (memoryAcquired < pageSizeBytes) {
- shuffleMemoryManager.release(memoryAcquired);
- spill();
- final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
- if (memoryAcquiredAfterSpilling != pageSizeBytes) {
- shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
- throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
- }
- }
- currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
- currentPagePosition = currentPage.getBaseOffset();
- freeSpaceInCurrentPage = pageSizeBytes;
- allocatedPages.add(currentPage);
- }
- }
- }
-
- /**
- * Write a record to the shuffle sorter.
- */
- public void insertRecord(
- Object recordBaseObject,
- long recordBaseOffset,
- int lengthInBytes,
- int partitionId) throws IOException {
-
- growPointerArrayIfNecessary();
- // Need 4 bytes to store the record length.
- final int totalSpaceRequired = lengthInBytes + 4;
-
- // --- Figure out where to insert the new record ----------------------------------------------
-
- final MemoryBlock dataPage;
- long dataPagePosition;
- boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
- if (useOverflowPage) {
- long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
- // The record is larger than the page size, so allocate a special overflow page just to hold
- // that record.
- final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
- if (memoryGranted != overflowPageSize) {
- shuffleMemoryManager.release(memoryGranted);
- spill();
- final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
- if (memoryGrantedAfterSpill != overflowPageSize) {
- shuffleMemoryManager.release(memoryGrantedAfterSpill);
- throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
- }
- }
- MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
- allocatedPages.add(overflowPage);
- dataPage = overflowPage;
- dataPagePosition = overflowPage.getBaseOffset();
- } else {
- // The record is small enough to fit in a regular data page, but the current page might not
- // have enough space to hold it (or no pages have been allocated yet).
- acquireNewPageIfNecessary(totalSpaceRequired);
- dataPage = currentPage;
- dataPagePosition = currentPagePosition;
- // Update bookkeeping information
- freeSpaceInCurrentPage -= totalSpaceRequired;
- currentPagePosition += totalSpaceRequired;
- }
- final Object dataPageBaseObject = dataPage.getBaseObject();
-
- final long recordAddress =
- taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
- Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes);
- dataPagePosition += 4;
- Platform.copyMemory(
- recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);
- assert(inMemSorter != null);
- inMemSorter.insertRecord(recordAddress, partitionId);
- }
-
- /**
- * Close the sorter, causing any buffered data to be sorted and written out to disk.
- *
- * @return metadata for the spill files written by this sorter. If no records were ever inserted
- * into this sorter, then this will return an empty array.
- * @throws IOException
- */
- public SpillInfo[] closeAndGetSpills() throws IOException {
- try {
- if (inMemSorter != null) {
- // Do not count the final file towards the spill count.
- writeSortedFile(true);
- freeMemory();
- }
- return spills.toArray(new SpillInfo[spills.size()]);
- } catch (IOException e) {
- cleanupResources();
- throw e;
- }
- }
-
-}
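The writeSortedFile() loop above streams each record out of its data page through a small reusable buffer: a record is stored as a 4-byte length followed by its payload, and the payload is copied to the disk writer in DISK_WRITE_BUFFER_SIZE chunks. A minimal Scala sketch of that chunked-copy pattern, using plain on-heap arrays and a hypothetical write(buf, off, len) sink in place of Platform and DiskBlockObjectWriter (byte order and names are illustrative):

  import java.nio.ByteBuffer

  // Sketch only: `page` stands in for a sorter data page and `write` for the
  // disk writer; both are hypothetical stand-ins, not Spark APIs.
  def writeLengthPrefixedRecord(
      page: Array[Byte],
      recordOffset: Int,
      write: (Array[Byte], Int, Int) => Unit,
      diskWriteBufferSize: Int = 1024 * 1024): Unit = {
    val writeBuffer = new Array[Byte](diskWriteBufferSize)
    // The first 4 bytes hold the record length; the payload follows.
    var dataRemaining = ByteBuffer.wrap(page, recordOffset, 4).getInt
    var readPosition = recordOffset + 4 // skip over the record length
    while (dataRemaining > 0) {
      val toTransfer = math.min(diskWriteBufferSize, dataRemaining)
      System.arraycopy(page, readPosition, writeBuffer, 0, toTransfer)
      write(writeBuffer, 0, toTransfer)
      readPosition += toTransfer
      dataRemaining -= toTransfer
    }
  }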
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java
deleted file mode 100644
index 5bab501..0000000
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java
+++ /dev/null
@@ -1,124 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-import java.util.Comparator;
-
-import org.apache.spark.util.collection.Sorter;
-
-final class UnsafeShuffleInMemorySorter {
-
- private final Sorter<PackedRecordPointer, long[]> sorter;
- private static final class SortComparator implements Comparator<PackedRecordPointer> {
- @Override
- public int compare(PackedRecordPointer left, PackedRecordPointer right) {
- return left.getPartitionId() - right.getPartitionId();
- }
- }
- private static final SortComparator SORT_COMPARATOR = new SortComparator();
-
- /**
- * An array of record pointers and partition ids that have been encoded by
- * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating
- * records.
- */
- private long[] pointerArray;
-
- /**
- * The position in the pointer array where new records can be inserted.
- */
- private int pointerArrayInsertPosition = 0;
-
- public UnsafeShuffleInMemorySorter(int initialSize) {
- assert (initialSize > 0);
- this.pointerArray = new long[initialSize];
- this.sorter = new Sorter<PackedRecordPointer, long[]>(UnsafeShuffleSortDataFormat.INSTANCE);
- }
-
- public void expandPointerArray() {
- final long[] oldArray = pointerArray;
- // Guard against overflow:
- final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
- pointerArray = new long[newLength];
- System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
- }
-
- public boolean hasSpaceForAnotherRecord() {
- return pointerArrayInsertPosition + 1 < pointerArray.length;
- }
-
- public long getMemoryUsage() {
- return pointerArray.length * 8L;
- }
-
- /**
- * Inserts a record to be sorted.
- *
- * @param recordPointer a pointer to the record, encoded by the task memory manager. Due to
- * certain pointer compression techniques used by the sorter, the sort can
- * only operate on pointers that point to locations in the first
- * {@link PackedRecordPointer#MAXIMUM_PAGE_SIZE_BYTES} bytes of a data page.
- * @param partitionId the partition id, which must be less than or equal to
- * {@link PackedRecordPointer#MAXIMUM_PARTITION_ID}.
- */
- public void insertRecord(long recordPointer, int partitionId) {
- if (!hasSpaceForAnotherRecord()) {
- if (pointerArray.length == Integer.MAX_VALUE) {
- throw new IllegalStateException("Sort pointer array has reached maximum size");
- } else {
- expandPointerArray();
- }
- }
- pointerArray[pointerArrayInsertPosition] =
- PackedRecordPointer.packPointer(recordPointer, partitionId);
- pointerArrayInsertPosition++;
- }
-
- /**
- * An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining.
- */
- public static final class UnsafeShuffleSorterIterator {
-
- private final long[] pointerArray;
- private final int numRecords;
- final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
- private int position = 0;
-
- public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) {
- this.numRecords = numRecords;
- this.pointerArray = pointerArray;
- }
-
- public boolean hasNext() {
- return position < numRecords;
- }
-
- public void loadNext() {
- packedRecordPointer.set(pointerArray[position]);
- position++;
- }
- }
-
- /**
- * Return an iterator over record pointers in sorted order.
- */
- public UnsafeShuffleSorterIterator getSortedIterator() {
- sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR);
- return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
- }
-}
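The sorter above never touches record bytes while sorting: each array entry is a single long that packs the partition id next to a compressed record pointer (see PackedRecordPointer), so ordering entries by their decoded partition id groups all records for a partition together at only 8 bytes per record. A small Scala sketch of the idea; the bit layout shown is an assumption for illustration, not the exact PackedRecordPointer encoding:

  object PackedPointerSketch {
    // Illustrative layout (an assumption): 24-bit partition id in the upper
    // bits, 40-bit record address in the lower bits.
    def pack(recordAddress: Long, partitionId: Int): Long =
      (partitionId.toLong << 40) | (recordAddress & ((1L << 40) - 1))

    def partitionId(packed: Long): Int = ((packed >>> 40) & 0xFFFFFF).toInt

    def main(args: Array[String]): Unit = {
      val pointers = Array(pack(100L, 5), pack(200L, 1), pack(300L, 5), pack(400L, 0))
      // Like the SortComparator above, order entries by their decoded partition
      // id; records for the same partition end up contiguous.
      val sorted = pointers.sortBy(partitionId)
      println(sorted.map(partitionId).mkString(", ")) // prints: 0, 1, 5, 5
    }
  }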
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
deleted file mode 100644
index a66d74e..0000000
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-import org.apache.spark.util.collection.SortDataFormat;
-
-final class UnsafeShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
-
- public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat();
-
- private UnsafeShuffleSortDataFormat() { }
-
- @Override
- public PackedRecordPointer getKey(long[] data, int pos) {
- // Since we re-use keys, this method shouldn't be called.
- throw new UnsupportedOperationException();
- }
-
- @Override
- public PackedRecordPointer newKey() {
- return new PackedRecordPointer();
- }
-
- @Override
- public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) {
- reuse.set(data[pos]);
- return reuse;
- }
-
- @Override
- public void swap(long[] data, int pos0, int pos1) {
- final long temp = data[pos0];
- data[pos0] = data[pos1];
- data[pos1] = temp;
- }
-
- @Override
- public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
- dst[dstPos] = src[srcPos];
- }
-
- @Override
- public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
- System.arraycopy(src, srcPos, dst, dstPos, length);
- }
-
- @Override
- public long[] allocate(int length) {
- return new long[length];
- }
-
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
deleted file mode 100644
index fdb309e..0000000
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ /dev/null
@@ -1,489 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-import javax.annotation.Nullable;
-import java.io.*;
-import java.nio.channels.FileChannel;
-import java.util.Iterator;
-
-import scala.Option;
-import scala.Product2;
-import scala.collection.JavaConverters;
-import scala.collection.immutable.Map;
-import scala.reflect.ClassTag;
-import scala.reflect.ClassTag$;
-
-import com.google.common.annotations.VisibleForTesting;
-import com.google.common.io.ByteStreams;
-import com.google.common.io.Closeables;
-import com.google.common.io.Files;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import org.apache.spark.*;
-import org.apache.spark.annotation.Private;
-import org.apache.spark.executor.ShuffleWriteMetrics;
-import org.apache.spark.io.CompressionCodec;
-import org.apache.spark.io.CompressionCodec$;
-import org.apache.spark.io.LZFCompressionCodec;
-import org.apache.spark.network.util.LimitedInputStream;
-import org.apache.spark.scheduler.MapStatus;
-import org.apache.spark.scheduler.MapStatus$;
-import org.apache.spark.serializer.SerializationStream;
-import org.apache.spark.serializer.Serializer;
-import org.apache.spark.serializer.SerializerInstance;
-import org.apache.spark.shuffle.IndexShuffleBlockResolver;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
-import org.apache.spark.shuffle.ShuffleWriter;
-import org.apache.spark.storage.BlockManager;
-import org.apache.spark.storage.TimeTrackingOutputStream;
-import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
-
-@Private
-public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
-
- private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);
-
- private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
-
- @VisibleForTesting
- static final int INITIAL_SORT_BUFFER_SIZE = 4096;
-
- private final BlockManager blockManager;
- private final IndexShuffleBlockResolver shuffleBlockResolver;
- private final TaskMemoryManager memoryManager;
- private final ShuffleMemoryManager shuffleMemoryManager;
- private final SerializerInstance serializer;
- private final Partitioner partitioner;
- private final ShuffleWriteMetrics writeMetrics;
- private final int shuffleId;
- private final int mapId;
- private final TaskContext taskContext;
- private final SparkConf sparkConf;
- private final boolean transferToEnabled;
-
- @Nullable private MapStatus mapStatus;
- @Nullable private UnsafeShuffleExternalSorter sorter;
- private long peakMemoryUsedBytes = 0;
-
- /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
- private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
- public MyByteArrayOutputStream(int size) { super(size); }
- public byte[] getBuf() { return buf; }
- }
-
- private MyByteArrayOutputStream serBuffer;
- private SerializationStream serOutputStream;
-
- /**
- * Are we in the process of stopping? Because map tasks can call stop() with success = true
- * and then call stop() with success = false if they get an exception, we want to make sure
- * we don't try deleting files, etc twice.
- */
- private boolean stopping = false;
-
- public UnsafeShuffleWriter(
- BlockManager blockManager,
- IndexShuffleBlockResolver shuffleBlockResolver,
- TaskMemoryManager memoryManager,
- ShuffleMemoryManager shuffleMemoryManager,
- UnsafeShuffleHandle<K, V> handle,
- int mapId,
- TaskContext taskContext,
- SparkConf sparkConf) throws IOException {
- final int numPartitions = handle.dependency().partitioner().numPartitions();
- if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) {
- throw new IllegalArgumentException(
- "UnsafeShuffleWriter can only be used for shuffles with at most " +
- UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions");
- }
- this.blockManager = blockManager;
- this.shuffleBlockResolver = shuffleBlockResolver;
- this.memoryManager = memoryManager;
- this.shuffleMemoryManager = shuffleMemoryManager;
- this.mapId = mapId;
- final ShuffleDependency<K, V, V> dep = handle.dependency();
- this.shuffleId = dep.shuffleId();
- this.serializer = Serializer.getSerializer(dep.serializer()).newInstance();
- this.partitioner = dep.partitioner();
- this.writeMetrics = new ShuffleWriteMetrics();
- taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
- this.taskContext = taskContext;
- this.sparkConf = sparkConf;
- this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
- open();
- }
-
- @VisibleForTesting
- public int maxRecordSizeBytes() {
- assert(sorter != null);
- return sorter.maxRecordSizeBytes;
- }
-
- private void updatePeakMemoryUsed() {
- // sorter can be null if this writer is closed
- if (sorter != null) {
- long mem = sorter.getPeakMemoryUsedBytes();
- if (mem > peakMemoryUsedBytes) {
- peakMemoryUsedBytes = mem;
- }
- }
- }
-
- /**
- * Return the peak memory used so far, in bytes.
- */
- public long getPeakMemoryUsedBytes() {
- updatePeakMemoryUsed();
- return peakMemoryUsedBytes;
- }
-
- /**
- * This convenience method should only be called in test code.
- */
- @VisibleForTesting
- public void write(Iterator<Product2<K, V>> records) throws IOException {
- write(JavaConverters.asScalaIteratorConverter(records).asScala());
- }
-
- @Override
- public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
- // Keep track of success so we know if we encountered an exception
- // We do this rather than a standard try/catch/re-throw to handle
- // generic throwables.
- boolean success = false;
- try {
- while (records.hasNext()) {
- insertRecordIntoSorter(records.next());
- }
- closeAndWriteOutput();
- success = true;
- } finally {
- if (sorter != null) {
- try {
- sorter.cleanupResources();
- } catch (Exception e) {
- // Only throw this error if we won't be masking another
- // error.
- if (success) {
- throw e;
- } else {
- logger.error("In addition to a failure during writing, we failed during " +
- "cleanup.", e);
- }
- }
- }
- }
- }
-
- private void open() throws IOException {
- assert (sorter == null);
- sorter = new UnsafeShuffleExternalSorter(
- memoryManager,
- shuffleMemoryManager,
- blockManager,
- taskContext,
- INITIAL_SORT_BUFFER_SIZE,
- partitioner.numPartitions(),
- sparkConf,
- writeMetrics);
- serBuffer = new MyByteArrayOutputStream(1024 * 1024);
- serOutputStream = serializer.serializeStream(serBuffer);
- }
-
- @VisibleForTesting
- void closeAndWriteOutput() throws IOException {
- assert(sorter != null);
- updatePeakMemoryUsed();
- serBuffer = null;
- serOutputStream = null;
- final SpillInfo[] spills = sorter.closeAndGetSpills();
- sorter = null;
- final long[] partitionLengths;
- try {
- partitionLengths = mergeSpills(spills);
- } finally {
- for (SpillInfo spill : spills) {
- if (spill.file.exists() && ! spill.file.delete()) {
- logger.error("Error while deleting spill file {}", spill.file.getPath());
- }
- }
- }
- shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
- }
-
- @VisibleForTesting
- void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
- assert(sorter != null);
- final K key = record._1();
- final int partitionId = partitioner.getPartition(key);
- serBuffer.reset();
- serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
- serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
- serOutputStream.flush();
-
- final int serializedRecordSize = serBuffer.size();
- assert (serializedRecordSize > 0);
-
- sorter.insertRecord(
- serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
- }
-
- @VisibleForTesting
- void forceSorterToSpill() throws IOException {
- assert (sorter != null);
- sorter.spill();
- }
-
- /**
- * Merge zero or more spill files together, choosing the fastest merging strategy based on the
- * number of spills and the IO compression codec.
- *
- * @return the partition lengths in the merged file.
- */
- private long[] mergeSpills(SpillInfo[] spills) throws IOException {
- final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
- final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
- final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
- final boolean fastMergeEnabled =
- sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
- final boolean fastMergeIsSupported =
- !compressionEnabled || compressionCodec instanceof LZFCompressionCodec;
- try {
- if (spills.length == 0) {
- new FileOutputStream(outputFile).close(); // Create an empty file
- return new long[partitioner.numPartitions()];
- } else if (spills.length == 1) {
- // Here, we don't need to perform any metrics updates because the bytes written to this
- // output file would have already been counted as shuffle bytes written.
- Files.move(spills[0].file, outputFile);
- return spills[0].partitionLengths;
- } else {
- final long[] partitionLengths;
- // There are multiple spills to merge, so none of these spill files' lengths were counted
- // towards our shuffle write count or shuffle write time. If we use the slow merge path,
- // then the final output file's size won't necessarily be equal to the sum of the spill
- // files' sizes. To guard against this case, we look at the output file's actual size when
- // computing shuffle bytes written.
- //
- // We allow the individual merge methods to report their own IO times since different merge
- // strategies use different IO techniques. We count IO during merge towards the shuffle
- // write time, which appears to be consistent with the "not bypassing merge-sort"
- // branch in ExternalSorter.
- if (fastMergeEnabled && fastMergeIsSupported) {
- // Compression is disabled or we are using an IO compression codec that supports
- // decompression of concatenated compressed streams, so we can perform a fast spill merge
- // that doesn't need to interpret the spilled bytes.
- if (transferToEnabled) {
- logger.debug("Using transferTo-based fast merge");
- partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
- } else {
- logger.debug("Using fileStream-based fast merge");
- partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
- }
- } else {
- logger.debug("Using slow merge");
- partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
- }
- // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
- // in-memory records, we write out the in-memory records to a file but do not count that
- // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs
- // to be counted as shuffle write, but this will lead to double-counting of the final
- // SpillInfo's bytes.
- writeMetrics.decShuffleBytesWritten(spills[spills.length - 1].file.length());
- writeMetrics.incShuffleBytesWritten(outputFile.length());
- return partitionLengths;
- }
- } catch (IOException e) {
- if (outputFile.exists() && !outputFile.delete()) {
- logger.error("Unable to delete output file {}", outputFile.getPath());
- }
- throw e;
- }
- }
-
- /**
- * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge,
- * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in
- * cases where the IO compression codec does not support concatenation of compressed data, or in
- * cases where users have explicitly disabled use of {@code transferTo} in order to work around
- * kernel bugs.
- *
- * @param spills the spills to merge.
- * @param outputFile the file to write the merged data to.
- * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
- * @return the partition lengths in the merged file.
- */
- private long[] mergeSpillsWithFileStream(
- SpillInfo[] spills,
- File outputFile,
- @Nullable CompressionCodec compressionCodec) throws IOException {
- assert (spills.length >= 2);
- final int numPartitions = partitioner.numPartitions();
- final long[] partitionLengths = new long[numPartitions];
- final InputStream[] spillInputStreams = new FileInputStream[spills.length];
- OutputStream mergedFileOutputStream = null;
-
- boolean threwException = true;
- try {
- for (int i = 0; i < spills.length; i++) {
- spillInputStreams[i] = new FileInputStream(spills[i].file);
- }
- for (int partition = 0; partition < numPartitions; partition++) {
- final long initialFileLength = outputFile.length();
- mergedFileOutputStream =
- new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true));
- if (compressionCodec != null) {
- mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream);
- }
-
- for (int i = 0; i < spills.length; i++) {
- final long partitionLengthInSpill = spills[i].partitionLengths[partition];
- if (partitionLengthInSpill > 0) {
- InputStream partitionInputStream =
- new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill);
- if (compressionCodec != null) {
- partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
- }
- ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
- }
- }
- mergedFileOutputStream.flush();
- mergedFileOutputStream.close();
- partitionLengths[partition] = (outputFile.length() - initialFileLength);
- }
- threwException = false;
- } finally {
- // To avoid masking exceptions that caused us to prematurely enter the finally block, only
- // throw exceptions during cleanup if threwException == false.
- for (InputStream stream : spillInputStreams) {
- Closeables.close(stream, threwException);
- }
- Closeables.close(mergedFileOutputStream, threwException);
- }
- return partitionLengths;
- }
-
- /**
- * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes.
- * This is only safe when the IO compression codec and serializer support concatenation of
- * serialized streams.
- *
- * @return the partition lengths in the merged file.
- */
- private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
- assert (spills.length >= 2);
- final int numPartitions = partitioner.numPartitions();
- final long[] partitionLengths = new long[numPartitions];
- final FileChannel[] spillInputChannels = new FileChannel[spills.length];
- final long[] spillInputChannelPositions = new long[spills.length];
- FileChannel mergedFileOutputChannel = null;
-
- boolean threwException = true;
- try {
- for (int i = 0; i < spills.length; i++) {
- spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
- }
- // This file needs to be opened in append mode in order to work around a Linux kernel bug that
- // affects transferTo; see SPARK-3948 for more details.
- mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel();
-
- long bytesWrittenToMergedFile = 0;
- for (int partition = 0; partition < numPartitions; partition++) {
- for (int i = 0; i < spills.length; i++) {
- final long partitionLengthInSpill = spills[i].partitionLengths[partition];
- long bytesToTransfer = partitionLengthInSpill;
- final FileChannel spillInputChannel = spillInputChannels[i];
- final long writeStartTime = System.nanoTime();
- while (bytesToTransfer > 0) {
- final long actualBytesTransferred = spillInputChannel.transferTo(
- spillInputChannelPositions[i],
- bytesToTransfer,
- mergedFileOutputChannel);
- spillInputChannelPositions[i] += actualBytesTransferred;
- bytesToTransfer -= actualBytesTransferred;
- }
- writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime);
- bytesWrittenToMergedFile += partitionLengthInSpill;
- partitionLengths[partition] += partitionLengthInSpill;
- }
- }
- // Check the position after the transferTo loop to verify that it matches the expected count, and raise an
- // exception if it is incorrect. The position will not be increased to the expected length
- // after calling transferTo in kernel version 2.6.32. This issue is described at
- // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
- if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) {
- throw new IOException(
- "Current position " + mergedFileOutputChannel.position() + " does not equal expected " +
- "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" +
- " version to see if it is 2.6.32, as there is a kernel bug which will lead to " +
- "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " +
- "to disable this NIO feature."
- );
- }
- threwException = false;
- } finally {
- // To avoid masking exceptions that caused us to prematurely enter the finally block, only
- // throw exceptions during cleanup if threwException == false.
- for (int i = 0; i < spills.length; i++) {
- assert(spillInputChannelPositions[i] == spills[i].file.length());
- Closeables.close(spillInputChannels[i], threwException);
- }
- Closeables.close(mergedFileOutputChannel, threwException);
- }
- return partitionLengths;
- }
-
- @Override
- public Option<MapStatus> stop(boolean success) {
- try {
- // Update task metrics from accumulators (null in UnsafeShuffleWriterSuite)
- Map<String, Accumulator<Object>> internalAccumulators =
- taskContext.internalMetricsToAccumulators();
- if (internalAccumulators != null) {
- internalAccumulators.apply(InternalAccumulator.PEAK_EXECUTION_MEMORY())
- .add(getPeakMemoryUsedBytes());
- }
-
- if (stopping) {
- return Option.apply(null);
- } else {
- stopping = true;
- if (success) {
- if (mapStatus == null) {
- throw new IllegalStateException("Cannot call stop(true) without having called write()");
- }
- return Option.apply(mapStatus);
- } else {
- // The map task failed, so delete our output data.
- shuffleBlockResolver.removeDataByMap(shuffleId, mapId);
- return Option.apply(null);
- }
- }
- } finally {
- if (sorter != null) {
- // If sorter is non-null, then this implies that we called stop() in response to an error,
- // so we need to clean up memory and spill files created by the sorter
- sorter.cleanupResources();
- }
- }
- }
-}
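For reference, the transferTo-based fast merge in mergeSpillsWithTransferTo() boils down to concatenating, partition by partition, each spill file's byte range for that partition onto the single map output file. A simplified Scala sketch of that concatenation follows; the names and the partitionLengths bookkeeping are illustrative, and the metrics updates, append-mode kernel-bug check, and exception handling of the real writer are omitted:

  import java.io.{File, FileInputStream, FileOutputStream}

  // partitionLengths(i)(p) is assumed to hold the length of partition p's data
  // inside spill file i; each spill stores its partitions in ascending order.
  def concatSpills(
      spillFiles: Seq[File],
      partitionLengths: Seq[Array[Long]],
      numPartitions: Int,
      outputFile: File): Array[Long] = {
    val inChannels = spillFiles.map(f => new FileInputStream(f).getChannel)
    val positions = Array.fill(spillFiles.size)(0L)
    val out = new FileOutputStream(outputFile, true).getChannel
    val mergedLengths = new Array[Long](numPartitions)
    try {
      for (p <- 0 until numPartitions; i <- spillFiles.indices) {
        var remaining = partitionLengths(i)(p)
        while (remaining > 0) {
          // transferTo may move fewer bytes than requested, so loop.
          val transferred = inChannels(i).transferTo(positions(i), remaining, out)
          positions(i) += transferred
          remaining -= transferred
        }
        mergedLengths(p) += partitionLengths(i)(p)
      }
      mergedLengths
    } finally {
      inChannels.foreach(_.close())
      out.close()
    }
  }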
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/scala/org/apache/spark/SparkEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index c329983..704158b 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -330,7 +330,7 @@ object SparkEnv extends Logging {
val shortShuffleMgrNames = Map(
"hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
"sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager",
- "tungsten-sort" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager")
+ "tungsten-sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")
val shuffleMgrName = conf.get("spark.shuffle.manager", "sort")
val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)
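With this change "tungsten-sort" is kept only as an alias: both it and "sort" now resolve to org.apache.spark.shuffle.sort.SortShuffleManager, and the serialized ("tungsten") write path is selected per shuffle from the dependency rather than by the manager name. An existing configuration such as the following sketch therefore keeps working unchanged:

  import org.apache.spark.SparkConf

  // "sort" and "tungsten-sort" now name the same ShuffleManager implementation.
  val conf = new SparkConf()
    .setAppName("shuffle-demo")
    .set("spark.shuffle.manager", "tungsten-sort")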
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 9df4e55..1105167 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -19,9 +19,53 @@ package org.apache.spark.shuffle.sort
import java.util.concurrent.ConcurrentHashMap
-import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency}
+import org.apache.spark._
+import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle._
+/**
+ * In sort-based shuffle, incoming records are sorted according to their target partition ids, then
+ * written to a single map output file. Reducers fetch contiguous regions of this file in order to
+ * read their portion of the map output. In cases where the map output data is too large to fit in
+ * memory, sorted subsets of the output are spilled to disk and those on-disk files are merged
+ * to produce the final output file.
+ *
+ * Sort-based shuffle has two different write paths for producing its map output files:
+ *
+ * - Serialized sorting: used when all three of the following conditions hold:
+ * 1. The shuffle dependency specifies no aggregation or output ordering.
+ * 2. The shuffle serializer supports relocation of serialized values (this is currently
+ * supported by KryoSerializer and Spark SQL's custom serializers).
+ * 3. The shuffle produces fewer than 16777216 output partitions.
+ * - Deserialized sorting: used to handle all other cases.
+ *
+ * -----------------------
+ * Serialized sorting mode
+ * -----------------------
+ *
+ * In the serialized sorting mode, incoming records are serialized as soon as they are passed to the
+ * shuffle writer and are buffered in a serialized form during sorting. This write path implements
+ * several optimizations:
+ *
+ * - Its sort operates on serialized binary data rather than Java objects, which reduces memory
+ * consumption and GC overheads. This optimization requires the record serializer to have certain
+ * properties to allow serialized records to be re-ordered without requiring deserialization.
+ * See SPARK-4550, where this optimization was first proposed and implemented, for more details.
+ *
+ * - It uses a specialized cache-efficient sorter ([[ShuffleExternalSorter]]) that sorts
+ * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per
+ * record in the sorting array, this fits more of the array into cache.
+ *
+ * - The spill merging procedure operates on blocks of serialized records that belong to the same
+ * partition and does not need to deserialize records during the merge.
+ *
+ * - When the spill compression codec supports concatenation of compressed data, the spill merge
+ * simply concatenates the serialized and compressed spill partitions to produce the final output
+ * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used
+ * and avoids the need to allocate decompression or copying buffers during the merge.
+ *
+ * For more details on these optimizations, see SPARK-7081.
+ */
private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
if (!conf.getBoolean("spark.shuffle.spill", true)) {
@@ -30,8 +74,12 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
" Shuffle will continue to spill to disk when necessary.")
}
- private val indexShuffleBlockResolver = new IndexShuffleBlockResolver(conf)
- private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]()
+ /**
+ * A mapping from shuffle ids to the number of mappers producing output for those shuffles.
+ */
+ private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]()
+
+ override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
/**
* Register a shuffle with the manager and obtain a handle for it to pass to tasks.
@@ -40,7 +88,22 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
shuffleId: Int,
numMaps: Int,
dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
- new BaseShuffleHandle(shuffleId, numMaps, dependency)
+ if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) {
+ // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
+ // need map-side aggregation, then write numPartitions files directly and just concatenate
+ // them at the end. This avoids doing serialization and deserialization twice to merge
+ // together the spilled files, which would happen with the normal code path. The downside is
+ // having multiple files open at a time and thus more memory allocated to buffers.
+ new BypassMergeSortShuffleHandle[K, V](
+ shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+ } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
+ // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
+ new SerializedShuffleHandle[K, V](
+ shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+ } else {
+ // Otherwise, buffer map outputs in a deserialized form:
+ new BaseShuffleHandle(shuffleId, numMaps, dependency)
+ }
}
/**
@@ -52,38 +115,114 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
startPartition: Int,
endPartition: Int,
context: TaskContext): ShuffleReader[K, C] = {
- // We currently use the same block store shuffle fetcher as the hash-based shuffle.
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}
/** Get a writer for a given partition. Called on executors by map tasks. */
- override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
- : ShuffleWriter[K, V] = {
- val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]]
- shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps)
- new SortShuffleWriter(
- shuffleBlockResolver, baseShuffleHandle, mapId, context)
+ override def getWriter[K, V](
+ handle: ShuffleHandle,
+ mapId: Int,
+ context: TaskContext): ShuffleWriter[K, V] = {
+ numMapsForShuffle.putIfAbsent(
+ handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
+ val env = SparkEnv.get
+ handle match {
+ case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
+ new UnsafeShuffleWriter(
+ env.blockManager,
+ shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
+ context.taskMemoryManager(),
+ env.shuffleMemoryManager,
+ unsafeShuffleHandle,
+ mapId,
+ context,
+ env.conf)
+ case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
+ new BypassMergeSortShuffleWriter(
+ env.blockManager,
+ shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
+ bypassMergeSortHandle,
+ mapId,
+ context,
+ env.conf)
+ case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
+ new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
+ }
}
/** Remove a shuffle's metadata from the ShuffleManager. */
override def unregisterShuffle(shuffleId: Int): Boolean = {
- if (shuffleMapNumber.containsKey(shuffleId)) {
- val numMaps = shuffleMapNumber.remove(shuffleId)
- (0 until numMaps).map{ mapId =>
+ Option(numMapsForShuffle.remove(shuffleId)).foreach { numMaps =>
+ (0 until numMaps).foreach { mapId =>
shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
}
}
true
}
- override val shuffleBlockResolver: IndexShuffleBlockResolver = {
- indexShuffleBlockResolver
- }
-
/** Shut down this ShuffleManager. */
override def stop(): Unit = {
shuffleBlockResolver.stop()
}
}
+
+private[spark] object SortShuffleManager extends Logging {
+
+ /**
+ * The maximum number of shuffle output partitions that SortShuffleManager supports when
+ * buffering map outputs in a serialized form. This is an extreme defensive programming measure,
+ * since it's extremely unlikely that a single shuffle produces over 16 million output partitions.
+ * */
+ val MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE =
+ PackedRecordPointer.MAXIMUM_PARTITION_ID + 1
+
+ /**
+ * Helper method for determining whether a shuffle should use an optimized serialized shuffle
+ * path or whether it should fall back to the original path that operates on deserialized objects.
+ */
+ def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = {
+ val shufId = dependency.shuffleId
+ val numPartitions = dependency.partitioner.numPartitions
+ val serializer = Serializer.getSerializer(dependency.serializer)
+ if (!serializer.supportsRelocationOfSerializedObjects) {
+ log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " +
+ s"${serializer.getClass.getName}, does not support object relocation")
+ false
+ } else if (dependency.aggregator.isDefined) {
+ log.debug(
+ s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined")
+ false
+ } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) {
+ log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " +
+ s"$MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE partitions")
+ false
+ } else {
+ log.debug(s"Can use serialized shuffle for shuffle $shufId")
+ true
+ }
+ }
+}
+
+/**
+ * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the
+ * serialized shuffle.
+ */
+private[spark] class SerializedShuffleHandle[K, V](
+ shuffleId: Int,
+ numMaps: Int,
+ dependency: ShuffleDependency[K, V, V])
+ extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
+}
+
+/**
+ * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the
+ * bypass merge sort shuffle path.
+ */
+private[spark] class BypassMergeSortShuffleHandle[K, V](
+ shuffleId: Int,
+ numMaps: Int,
+ dependency: ShuffleDependency[K, V, V])
+ extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
+}
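Condition (2) in the scaladoc above, serializer support for relocating serialized values, is what most commonly decides between SerializedShuffleHandle and BaseShuffleHandle; per that scaladoc, KryoSerializer and Spark SQL's custom serializers advertise the needed support. A hedged configuration example (the per-shuffle decision is still made by canUseSerializedShuffle):

  import org.apache.spark.SparkConf

  // Kryo supports relocation of serialized objects, so shuffles with no
  // aggregation or key ordering and fewer than 16777216 partitions can use
  // the serialized (formerly "unsafe") write path.
  val conf = new SparkConf()
    .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")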
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 5865e76..bbd9c1a 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -20,7 +20,6 @@ package org.apache.spark.shuffle.sort
import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.collection.ExternalSorter
@@ -36,7 +35,7 @@ private[spark] class SortShuffleWriter[K, V, C](
private val blockManager = SparkEnv.get.blockManager
- private var sorter: SortShuffleFileWriter[K, V] = null
+ private var sorter: ExternalSorter[K, V, _] = null
// Are we in the process of stopping? Because map tasks can call stop() with success = true
// and then call stop() with success = false if they get an exception, we want to make sure
@@ -54,15 +53,6 @@ private[spark] class SortShuffleWriter[K, V, C](
require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
new ExternalSorter[K, V, C](
dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
- } else if (SortShuffleWriter.shouldBypassMergeSort(
- SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) {
- // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
- // need local aggregation and sorting, write numPartitions files directly and just concatenate
- // them at the end. This avoids doing serialization and deserialization twice to merge
- // together the spilled files, which would happen with the normal code path. The downside is
- // having multiple files open at a time and thus more memory allocated to buffers.
- new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner,
- writeMetrics, Serializer.getSerializer(dep.serializer))
} else {
// In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
// care whether the keys get sorted in each partition; that will be done on the reduce side
@@ -111,12 +101,14 @@ private[spark] class SortShuffleWriter[K, V, C](
}
private[spark] object SortShuffleWriter {
- def shouldBypassMergeSort(
- conf: SparkConf,
- numPartitions: Int,
- aggregator: Option[Aggregator[_, _, _]],
- keyOrdering: Option[Ordering[_]]): Boolean = {
- val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
- numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty
+ def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
+ // We cannot bypass sorting if we need to do map-side aggregation.
+ if (dep.mapSideCombine) {
+ require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
+ false
+ } else {
+ val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+ dep.partitioner.numPartitions <= bypassMergeThreshold
+ }
}
}
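The rewritten shouldBypassMergeSort above now takes the decision entirely from the dependency: the bypass path is possible only when there is no map-side combine and the partition count does not exceed spark.shuffle.sort.bypassMergeThreshold (default 200). A small standalone Scala sketch of the same predicate, useful for reasoning about which shuffles will take the bypass path (a hypothetical helper, not a Spark API):

  import org.apache.spark.SparkConf

  // Mirrors SortShuffleWriter.shouldBypassMergeSort: no map-side combine and a
  // partition count at or below the configured threshold.
  def wouldBypassMergeSort(
      conf: SparkConf,
      numPartitions: Int,
      mapSideCombine: Boolean): Boolean = {
    val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
    !mapSideCombine && numPartitions <= bypassMergeThreshold
  }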
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
deleted file mode 100644
index 75f22f6..0000000
--- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
+++ /dev/null
@@ -1,202 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe
-
-import java.util.Collections
-import java.util.concurrent.ConcurrentHashMap
-
-import org.apache.spark._
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle._
-import org.apache.spark.shuffle.sort.SortShuffleManager
-
-/**
- * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle.
- */
-private[spark] class UnsafeShuffleHandle[K, V](
- shuffleId: Int,
- numMaps: Int,
- dependency: ShuffleDependency[K, V, V])
- extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
-}
-
-private[spark] object UnsafeShuffleManager extends Logging {
-
- /**
- * The maximum number of shuffle output partitions that UnsafeShuffleManager supports.
- */
- val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1
-
- /**
- * Helper method for determining whether a shuffle should use the optimized unsafe shuffle
- * path or whether it should fall back to the original sort-based shuffle.
- */
- def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = {
- val shufId = dependency.shuffleId
- val serializer = Serializer.getSerializer(dependency.serializer)
- if (!serializer.supportsRelocationOfSerializedObjects) {
- log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " +
- s"${serializer.getClass.getName}, does not support object relocation")
- false
- } else if (dependency.aggregator.isDefined) {
- log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined")
- false
- } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) {
- log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " +
- s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions")
- false
- } else {
- log.debug(s"Can use UnsafeShuffle for shuffle $shufId")
- true
- }
- }
-}
-
-/**
- * A shuffle implementation that uses directly-managed memory to implement several performance
- * optimizations for certain types of shuffles. In cases where the new performance optimizations
- * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those
- * shuffles.
- *
- * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold:
- *
- * - The shuffle dependency specifies no aggregation or output ordering.
- * - The shuffle serializer supports relocation of serialized values (this is currently supported
- * by KryoSerializer and Spark SQL's custom serializers).
- * - The shuffle produces fewer than 16777216 output partitions.
- * - No individual record is larger than 128 MB when serialized.
- *
- * In addition, extra spill-merging optimizations are automatically applied when the shuffle
- * compression codec supports concatenation of serialized streams. This is currently supported by
- * Spark's LZF compression codec.
- *
- * At a high-level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager.
- * In sort-based shuffle, incoming records are sorted according to their target partition ids, then
- * written to a single map output file. Reducers fetch contiguous regions of this file in order to
- * read their portion of the map output. In cases where the map output data is too large to fit in
- * memory, sorted subsets of the output are spilled to disk and those on-disk files are merged
- * to produce the final output file.
- *
- * UnsafeShuffleManager optimizes this process in several ways:
- *
- * - Its sort operates on serialized binary data rather than Java objects, which reduces memory
- * consumption and GC overheads. This optimization requires the record serializer to have certain
- * properties to allow serialized records to be re-ordered without requiring deserialization.
- * See SPARK-4550, where this optimization was first proposed and implemented, for more details.
- *
- * - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts
- * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per
- * record in the sorting array, this fits more of the array into cache.
- *
- * - The spill merging procedure operates on blocks of serialized records that belong to the same
- * partition and does not need to deserialize records during the merge.
- *
- * - When the spill compression codec supports concatenation of compressed data, the spill merge
- * simply concatenates the serialized and compressed spill partitions to produce the final output
- * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used
- * and avoids the need to allocate decompression or copying buffers during the merge.
- *
- * For more details on UnsafeShuffleManager's design, see SPARK-7081.
- */
-private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
-
- if (!conf.getBoolean("spark.shuffle.spill", true)) {
- logWarning(
- "spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " +
- "manager; its optimized shuffles will continue to spill to disk when necessary.")
- }
-
- private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
- private[this] val shufflesThatFellBackToSortShuffle =
- Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]())
- private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]()
-
- /**
- * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
- */
- override def registerShuffle[K, V, C](
- shuffleId: Int,
- numMaps: Int,
- dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
- if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) {
- new UnsafeShuffleHandle[K, V](
- shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
- } else {
- new BaseShuffleHandle(shuffleId, numMaps, dependency)
- }
- }
-
- /**
- * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
- * Called on executors by reduce tasks.
- */
- override def getReader[K, C](
- handle: ShuffleHandle,
- startPartition: Int,
- endPartition: Int,
- context: TaskContext): ShuffleReader[K, C] = {
- sortShuffleManager.getReader(handle, startPartition, endPartition, context)
- }
-
- /** Get a writer for a given partition. Called on executors by map tasks. */
- override def getWriter[K, V](
- handle: ShuffleHandle,
- mapId: Int,
- context: TaskContext): ShuffleWriter[K, V] = {
- handle match {
- case unsafeShuffleHandle: UnsafeShuffleHandle[K @unchecked, V @unchecked] =>
- numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps)
- val env = SparkEnv.get
- new UnsafeShuffleWriter(
- env.blockManager,
- shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
- context.taskMemoryManager(),
- env.shuffleMemoryManager,
- unsafeShuffleHandle,
- mapId,
- context,
- env.conf)
- case other =>
- shufflesThatFellBackToSortShuffle.add(handle.shuffleId)
- sortShuffleManager.getWriter(handle, mapId, context)
- }
- }
-
- /** Remove a shuffle's metadata from the ShuffleManager. */
- override def unregisterShuffle(shuffleId: Int): Boolean = {
- if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) {
- sortShuffleManager.unregisterShuffle(shuffleId)
- } else {
- Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps =>
- (0 until numMaps).foreach { mapId =>
- shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
- }
- }
- true
- }
- }
-
- override val shuffleBlockResolver: IndexShuffleBlockResolver = {
- sortShuffleManager.shuffleBlockResolver
- }
-
- /** Shut down this ShuffleManager. */
- override def stop(): Unit = {
- sortShuffleManager.stop()
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
deleted file mode 100644
index ae60f3b..0000000
--- a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
+++ /dev/null
@@ -1,146 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.io.OutputStream
-
-import scala.collection.mutable.ArrayBuffer
-
-/**
- * A logical byte buffer that wraps a list of byte arrays. All the byte arrays have equal size. The
- * advantage of this over a standard ArrayBuffer is that it can grow without claiming large amounts
- * of memory and needing to copy the full contents. The disadvantage is that the contents don't
- * occupy a contiguous segment of memory.
- */
-private[spark] class ChainedBuffer(chunkSize: Int) {
-
- private val chunkSizeLog2: Int = java.lang.Long.numberOfTrailingZeros(
- java.lang.Long.highestOneBit(chunkSize))
- assert((1 << chunkSizeLog2) == chunkSize,
- s"ChainedBuffer chunk size $chunkSize must be a power of two")
- private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]()
- private var _size: Long = 0
-
- /**
- * Feed bytes from this buffer into a DiskBlockObjectWriter.
- *
- * @param pos Offset in the buffer to read from.
- * @param os OutputStream to read into.
- * @param len Number of bytes to read.
- */
- def read(pos: Long, os: OutputStream, len: Int): Unit = {
- if (pos + len > _size) {
- throw new IndexOutOfBoundsException(
- s"Read of $len bytes at position $pos would go past size ${_size} of buffer")
- }
- var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
- var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
- var written: Int = 0
- while (written < len) {
- val toRead: Int = math.min(len - written, chunkSize - posInChunk)
- os.write(chunks(chunkIndex), posInChunk, toRead)
- written += toRead
- chunkIndex += 1
- posInChunk = 0
- }
- }
-
- /**
- * Read bytes from this buffer into a byte array.
- *
- * @param pos Offset in the buffer to read from.
- * @param bytes Byte array to read into.
- * @param offs Offset in the byte array to read to.
- * @param len Number of bytes to read.
- */
- def read(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = {
- if (pos + len > _size) {
- throw new IndexOutOfBoundsException(
- s"Read of $len bytes at position $pos would go past size of buffer")
- }
- var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
- var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
- var written: Int = 0
- while (written < len) {
- val toRead: Int = math.min(len - written, chunkSize - posInChunk)
- System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead)
- written += toRead
- chunkIndex += 1
- posInChunk = 0
- }
- }
-
- /**
- * Write bytes from a byte array into this buffer.
- *
- * @param pos Offset in the buffer to write to.
- * @param bytes Byte array to write from.
- * @param offs Offset in the byte array to write from.
- * @param len Number of bytes to write.
- */
- def write(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = {
- if (pos > _size) {
- throw new IndexOutOfBoundsException(
- s"Write at position $pos starts after end of buffer ${_size}")
- }
- // Grow if needed
- val endChunkIndex: Int = ((pos + len - 1) >> chunkSizeLog2).toInt
- while (endChunkIndex >= chunks.length) {
- chunks += new Array[Byte](chunkSize)
- }
-
- var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
- var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
- var written: Int = 0
- while (written < len) {
- val toWrite: Int = math.min(len - written, chunkSize - posInChunk)
- System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite)
- written += toWrite
- chunkIndex += 1
- posInChunk = 0
- }
-
- _size = math.max(_size, pos + len)
- }
-
- /**
- * Total size of buffer that can be written to without allocating additional memory.
- */
- def capacity: Long = chunks.size.toLong * chunkSize
-
- /**
- * Size of the logical buffer.
- */
- def size: Long = _size
-}
-
-/**
- * Output stream that writes to a ChainedBuffer.
- */
-private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream {
- private var pos: Long = 0
-
- override def write(b: Int): Unit = {
- throw new UnsupportedOperationException()
- }
-
- override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = {
- chainedBuffer.write(pos, bytes, offs, len)
- pos += len
- }
-}
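For reference, the chunk addressing that the removed ChainedBuffer relied on maps a logical byte position to a (chunk index, offset-in-chunk) pair using its power-of-two chunk size. The snippet below is an illustrative Java sketch of that arithmetic only, not part of this patch, and assumes a hypothetical 4 KB chunk size.

    // Illustrative only; not part of the patch. Shows the chunk addressing used by the
    // removed ChainedBuffer, with a hypothetical 4 KB chunk size.
    public class ChunkAddressingExample {
      public static void main(String[] args) {
        int chunkSize = 4096;  // must be a power of two, as the removed class asserted
        int chunkSizeLog2 = Long.numberOfTrailingZeros(Long.highestOneBit(chunkSize));
        long pos = 10000L;     // logical byte position in the buffer
        int chunkIndex = (int) (pos >> chunkSizeLog2);                        // 2
        int posInChunk = (int) (pos - ((long) chunkIndex << chunkSizeLog2));  // 1808
        System.out.println(chunkIndex + " / " + posInChunk);                  // "2 / 1808"
      }
    }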
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 749be34..c48c453 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -29,7 +29,6 @@ import com.google.common.io.ByteStreams
import org.apache.spark._
import org.apache.spark.serializer._
import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter}
import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
/**
@@ -69,8 +68,8 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
* At a high level, this class works internally as follows:
*
* - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if
- * we want to combine by key, or a PartitionedSerializedPairBuffer or PartitionedPairBuffer if we
- * don't. Inside these buffers, we sort elements by partition ID and then possibly also by key.
+ * we want to combine by key, or a PartitionedPairBuffer if we don't.
+ * Inside these buffers, we sort elements by partition ID and then possibly also by key.
* To avoid calling the partitioner multiple times with each key, we store the partition ID
* alongside each record.
*
@@ -93,8 +92,7 @@ private[spark] class ExternalSorter[K, V, C](
ordering: Option[Ordering[K]] = None,
serializer: Option[Serializer] = None)
extends Logging
- with Spillable[WritablePartitionedPairCollection[K, C]]
- with SortShuffleFileWriter[K, V] {
+ with Spillable[WritablePartitionedPairCollection[K, C]] {
private val conf = SparkEnv.get.conf
@@ -104,13 +102,6 @@ private[spark] class ExternalSorter[K, V, C](
if (shouldPartition) partitioner.get.getPartition(key) else 0
}
- // Since SPARK-7855, bypassMergeSort optimization is no longer performed as part of this class.
- // As a sanity check, make sure that we're not handling a shuffle which should use that path.
- if (SortShuffleWriter.shouldBypassMergeSort(conf, numPartitions, aggregator, ordering)) {
- throw new IllegalArgumentException("ExternalSorter should not be used to handle "
- + " a sort that the BypassMergeSortShuffleWriter should handle")
- }
-
private val blockManager = SparkEnv.get.blockManager
private val diskBlockManager = blockManager.diskBlockManager
private val ser = Serializer.getSerializer(serializer)
@@ -128,23 +119,11 @@ private[spark] class ExternalSorter[K, V, C](
// grow internal data structures by growing + copying every time the number of objects doubles.
private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000)
- private val useSerializedPairBuffer =
- ordering.isEmpty &&
- conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
- ser.supportsRelocationOfSerializedObjects
- private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
- private def newBuffer(): WritablePartitionedPairCollection[K, C] with SizeTracker = {
- if (useSerializedPairBuffer) {
- new PartitionedSerializedPairBuffer(metaInitialRecords = 256, kvChunkSize, serInstance)
- } else {
- new PartitionedPairBuffer[K, C]
- }
- }
// Data structures to store in-memory objects before we spill. Depending on whether we have an
// Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
// store them in an array buffer.
private var map = new PartitionedAppendOnlyMap[K, C]
- private var buffer = newBuffer()
+ private var buffer = new PartitionedPairBuffer[K, C]
// Total spilling statistics
private var _diskBytesSpilled = 0L
@@ -192,7 +171,7 @@ private[spark] class ExternalSorter[K, V, C](
*/
private[spark] def numSpills: Int = spills.size
- override def insertAll(records: Iterator[Product2[K, V]]): Unit = {
+ def insertAll(records: Iterator[Product2[K, V]]): Unit = {
// TODO: stop combining if we find that the reduction factor isn't high
val shouldCombine = aggregator.isDefined
@@ -236,7 +215,7 @@ private[spark] class ExternalSorter[K, V, C](
} else {
estimatedSize = buffer.estimateSize()
if (maybeSpill(buffer, estimatedSize)) {
- buffer = newBuffer()
+ buffer = new PartitionedPairBuffer[K, C]
}
}
@@ -659,7 +638,7 @@ private[spark] class ExternalSorter[K, V, C](
* @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
*/
- override def writePartitionedFile(
+ def writePartitionedFile(
blockId: BlockId,
context: TaskContext,
outputFile: File): Array[Long] = {
[4/4] spark git commit: [SPARK-10708] Consolidate sort shuffle
implementations
Posted by jo...@apache.org.
[SPARK-10708] Consolidate sort shuffle implementations
There's a lot of duplication between SortShuffleManager and UnsafeShuffleManager. Now that UnsafeShuffleManager supports large records, the two provide the same set of functionality, so I think we should replace SortShuffleManager's serialized shuffle implementation with UnsafeShuffleManager's and merge the two managers together.
Author: Josh Rosen <jo...@databricks.com>
Closes #8829 from JoshRosen/consolidate-sort-shuffle-implementations.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f6d06adf
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f6d06adf
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f6d06adf
Branch: refs/heads/master
Commit: f6d06adf05afa9c5386dc2396c94e7a98730289f
Parents: 94e2064
Author: Josh Rosen <jo...@databricks.com>
Authored: Thu Oct 22 09:46:30 2015 -0700
Committer: Josh Rosen <jo...@databricks.com>
Committed: Thu Oct 22 09:46:30 2015 -0700
----------------------------------------------------------------------
.../sort/BypassMergeSortShuffleWriter.java | 106 +++-
.../spark/shuffle/sort/PackedRecordPointer.java | 92 +++
.../shuffle/sort/ShuffleExternalSorter.java | 491 ++++++++++++++++
.../shuffle/sort/ShuffleInMemorySorter.java | 124 ++++
.../shuffle/sort/ShuffleSortDataFormat.java | 67 +++
.../shuffle/sort/SortShuffleFileWriter.java | 53 --
.../apache/spark/shuffle/sort/SpillInfo.java | 37 ++
.../spark/shuffle/sort/UnsafeShuffleWriter.java | 489 ++++++++++++++++
.../shuffle/unsafe/PackedRecordPointer.java | 92 ---
.../apache/spark/shuffle/unsafe/SpillInfo.java | 37 --
.../unsafe/UnsafeShuffleExternalSorter.java | 479 ----------------
.../unsafe/UnsafeShuffleInMemorySorter.java | 124 ----
.../unsafe/UnsafeShuffleSortDataFormat.java | 67 ---
.../shuffle/unsafe/UnsafeShuffleWriter.java | 489 ----------------
.../main/scala/org/apache/spark/SparkEnv.scala | 2 +-
.../spark/shuffle/sort/SortShuffleManager.scala | 175 +++++-
.../spark/shuffle/sort/SortShuffleWriter.scala | 28 +-
.../shuffle/unsafe/UnsafeShuffleManager.scala | 202 -------
.../spark/util/collection/ChainedBuffer.scala | 146 -----
.../spark/util/collection/ExternalSorter.scala | 35 +-
.../PartitionedSerializedPairBuffer.scala | 273 ---------
.../shuffle/sort/PackedRecordPointerSuite.java | 102 ++++
.../sort/ShuffleInMemorySorterSuite.java | 124 ++++
.../shuffle/sort/UnsafeShuffleWriterSuite.java | 560 +++++++++++++++++++
.../unsafe/PackedRecordPointerSuite.java | 101 ----
.../UnsafeShuffleInMemorySorterSuite.java | 124 ----
.../unsafe/UnsafeShuffleWriterSuite.java | 560 -------------------
.../org/apache/spark/SortShuffleSuite.scala | 65 +++
.../spark/scheduler/DAGSchedulerSuite.scala | 6 +-
.../BypassMergeSortShuffleWriterSuite.scala | 64 ++-
.../shuffle/sort/SortShuffleManagerSuite.scala | 131 +++++
.../shuffle/sort/SortShuffleWriterSuite.scala | 45 --
.../unsafe/UnsafeShuffleManagerSuite.scala | 129 -----
.../shuffle/unsafe/UnsafeShuffleSuite.scala | 102 ----
.../util/collection/ChainedBufferSuite.scala | 144 -----
.../PartitionedSerializedPairBufferSuite.scala | 148 -----
docs/configuration.md | 7 +-
project/MimaExcludes.scala | 9 +-
.../apache/spark/sql/execution/Exchange.scala | 23 +-
.../execution/UnsafeRowSerializerSuite.scala | 9 +-
40 files changed, 2600 insertions(+), 3461 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index f5d80bb..ee82d67 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -21,21 +21,30 @@ import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
+import javax.annotation.Nullable;
+import scala.None$;
+import scala.Option;
import scala.Product2;
import scala.Tuple2;
import scala.collection.Iterator;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.Closeables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.Partitioner;
+import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.*;
import org.apache.spark.util.Utils;
@@ -62,7 +71,7 @@ import org.apache.spark.util.Utils;
* <p>
* There have been proposals to completely remove this code path; see SPARK-6026 for details.
*/
-final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<K, V> {
+final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class);
@@ -72,31 +81,52 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
private final BlockManager blockManager;
private final Partitioner partitioner;
private final ShuffleWriteMetrics writeMetrics;
+ private final int shuffleId;
+ private final int mapId;
private final Serializer serializer;
+ private final IndexShuffleBlockResolver shuffleBlockResolver;
/** Array of file writers, one for each partition */
private DiskBlockObjectWriter[] partitionWriters;
+ @Nullable private MapStatus mapStatus;
+ private long[] partitionLengths;
+
+ /**
+ * Are we in the process of stopping? Because map tasks can call stop() with success = true
+ * and then call stop() with success = false if they get an exception, we want to make sure
+ * we don't try deleting files, etc twice.
+ */
+ private boolean stopping = false;
public BypassMergeSortShuffleWriter(
- SparkConf conf,
BlockManager blockManager,
- Partitioner partitioner,
- ShuffleWriteMetrics writeMetrics,
- Serializer serializer) {
+ IndexShuffleBlockResolver shuffleBlockResolver,
+ BypassMergeSortShuffleHandle<K, V> handle,
+ int mapId,
+ TaskContext taskContext,
+ SparkConf conf) {
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
- this.numPartitions = partitioner.numPartitions();
this.blockManager = blockManager;
- this.partitioner = partitioner;
- this.writeMetrics = writeMetrics;
- this.serializer = serializer;
+ final ShuffleDependency<K, V, V> dep = handle.dependency();
+ this.mapId = mapId;
+ this.shuffleId = dep.shuffleId();
+ this.partitioner = dep.partitioner();
+ this.numPartitions = partitioner.numPartitions();
+ this.writeMetrics = new ShuffleWriteMetrics();
+ taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
+ this.serializer = Serializer.getSerializer(dep.serializer());
+ this.shuffleBlockResolver = shuffleBlockResolver;
}
@Override
- public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
+ public void write(Iterator<Product2<K, V>> records) throws IOException {
assert (partitionWriters == null);
if (!records.hasNext()) {
+ partitionLengths = new long[numPartitions];
+ shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
return;
}
final SerializerInstance serInstance = serializer.newInstance();
@@ -124,13 +154,24 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
for (DiskBlockObjectWriter writer : partitionWriters) {
writer.commitAndClose();
}
+
+ partitionLengths =
+ writePartitionedFile(shuffleBlockResolver.getDataFile(shuffleId, mapId));
+ shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}
- @Override
- public long[] writePartitionedFile(
- BlockId blockId,
- TaskContext context,
- File outputFile) throws IOException {
+ @VisibleForTesting
+ long[] getPartitionLengths() {
+ return partitionLengths;
+ }
+
+ /**
+ * Concatenate all of the per-partition files into a single combined file.
+ *
+ * @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
+ */
+ private long[] writePartitionedFile(File outputFile) throws IOException {
// Track location of the partition starts in the output file
final long[] lengths = new long[numPartitions];
if (partitionWriters == null) {
@@ -165,18 +206,33 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
}
@Override
- public void stop() throws IOException {
- if (partitionWriters != null) {
- try {
- for (DiskBlockObjectWriter writer : partitionWriters) {
- // This method explicitly does _not_ throw exceptions:
- File file = writer.revertPartialWritesAndClose();
- if (!file.delete()) {
- logger.error("Error while deleting file {}", file.getAbsolutePath());
+ public Option<MapStatus> stop(boolean success) {
+ if (stopping) {
+ return None$.empty();
+ } else {
+ stopping = true;
+ if (success) {
+ if (mapStatus == null) {
+ throw new IllegalStateException("Cannot call stop(true) without having called write()");
+ }
+ return Option.apply(mapStatus);
+ } else {
+ // The map task failed, so delete our output data.
+ if (partitionWriters != null) {
+ try {
+ for (DiskBlockObjectWriter writer : partitionWriters) {
+ // This method explicitly does _not_ throw exceptions:
+ File file = writer.revertPartialWritesAndClose();
+ if (!file.delete()) {
+ logger.error("Error while deleting file {}", file.getAbsolutePath());
+ }
+ }
+ } finally {
+ partitionWriters = null;
}
}
- } finally {
- partitionWriters = null;
+ shuffleBlockResolver.removeDataByMap(shuffleId, mapId);
+ return None$.empty();
}
}
}
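The writePartitionedFile() method added above concatenates the per-partition spill files into one combined data file and records each partition's length, which is what the index file and MapStatus are built from. The sketch below shows that concatenation pattern in a simplified, self-contained form; the file names and use of plain java.nio I/O are illustrative stand-ins, not Spark APIs (the real writer streams through DiskBlockObjectWriter and can use NIO transferTo).

    import java.io.IOException;
    import java.io.OutputStream;
    import java.nio.file.Files;
    import java.nio.file.Path;

    public class ConcatPartitionsExample {
      public static void main(String[] args) throws IOException {
        Path dir = Files.createTempDirectory("shuffle-example");
        Path combined = dir.resolve("combined.data");

        // Pretend each partition writer produced its own temp file: partition i holds i + 1 bytes.
        Path[] partitionFiles = new Path[3];
        for (int i = 0; i < 3; i++) {
          partitionFiles[i] = dir.resolve("part-" + i);
          Files.write(partitionFiles[i], new byte[i + 1]);
        }

        // Concatenate them and remember each partition's length
        // (the real writer also deletes the per-partition files afterwards).
        long[] lengths = new long[3];
        try (OutputStream out = Files.newOutputStream(combined)) {
          for (int i = 0; i < 3; i++) {
            byte[] bytes = Files.readAllBytes(partitionFiles[i]);
            out.write(bytes);
            lengths[i] = bytes.length;
          }
        }
        System.out.println(java.util.Arrays.toString(lengths)); // [1, 2, 3]
      }
    }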
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
new file mode 100644
index 0000000..c117119
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+/**
+ * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.
+ * <p>
+ * Within the long, the data is laid out as follows:
+ * <pre>
+ * [24 bit partition number][13 bit memory page number][27 bit offset in page]
+ * </pre>
+ * This implies that the maximum addressable page size is 2^27 bytes = 128 megabytes, assuming that
+ * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the
+ * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this
+ * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task.
+ * <p>
+ * Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this
+ * optimization to future work as it will require more careful design to ensure that addresses are
+ * properly aligned (e.g. by padding records).
+ */
+final class PackedRecordPointer {
+
+ static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27; // 128 megabytes
+
+ /**
+ * The maximum partition identifier that can be encoded. Note that partition ids start from 0.
+ */
+ static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215
+
+ /** Bit mask for the lower 40 bits of a long. */
+ private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1;
+
+ /** Bit mask for the upper 24 bits of a long */
+ private static final long MASK_LONG_UPPER_24_BITS = ~MASK_LONG_LOWER_40_BITS;
+
+ /** Bit mask for the lower 27 bits of a long. */
+ private static final long MASK_LONG_LOWER_27_BITS = (1L << 27) - 1;
+
+ /** Bit mask for the lower 51 bits of a long. */
+ private static final long MASK_LONG_LOWER_51_BITS = (1L << 51) - 1;
+
+ /** Bit mask for the upper 13 bits of a long */
+ private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS;
+
+ /**
+ * Pack a record address and partition id into a single word.
+ *
+ * @param recordPointer a record pointer encoded by TaskMemoryManager.
+ * @param partitionId a shuffle partition id (maximum value of 2^24 - 1).
+ * @return a packed pointer that can be decoded using the {@link PackedRecordPointer} class.
+ */
+ public static long packPointer(long recordPointer, int partitionId) {
+ assert (partitionId <= MAXIMUM_PARTITION_ID);
+ // Note that without word alignment we can address 2^27 bytes = 128 megabytes per page.
+ // Also note that this relies on some internals of how TaskMemoryManager encodes its addresses.
+ final long pageNumber = (recordPointer & MASK_LONG_UPPER_13_BITS) >>> 24;
+ final long compressedAddress = pageNumber | (recordPointer & MASK_LONG_LOWER_27_BITS);
+ return (((long) partitionId) << 40) | compressedAddress;
+ }
+
+ private long packedRecordPointer;
+
+ public void set(long packedRecordPointer) {
+ this.packedRecordPointer = packedRecordPointer;
+ }
+
+ public int getPartitionId() {
+ return (int) ((packedRecordPointer & MASK_LONG_UPPER_24_BITS) >>> 40);
+ }
+
+ public long getRecordPointer() {
+ final long pageNumber = (packedRecordPointer << 24) & MASK_LONG_UPPER_13_BITS;
+ final long offsetInPage = packedRecordPointer & MASK_LONG_LOWER_27_BITS;
+ return pageNumber | offsetInPage;
+ }
+
+}
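To make the layout described in the class comment concrete, the snippet below packs a pointer for page 5, offset 1024, partition 3 and decodes it again. It is an illustrative sketch, not part of the patch; it assumes the TaskMemoryManager address encoding (13-bit page number in the upper bits of the address) and has to live in the org.apache.spark.shuffle.sort package because the class is package-private.

    package org.apache.spark.shuffle.sort;

    // Illustrative only; not part of the patch.
    public class PackedRecordPointerExample {
      public static void main(String[] args) {
        // A TaskMemoryManager-style address: 13-bit page number in the upper bits, offset below.
        long recordPointer = (5L << 51) | 1024L;   // page 5, offset 1024 within the page

        long packed = PackedRecordPointer.packPointer(recordPointer, 3);  // partition id 3

        PackedRecordPointer pointer = new PackedRecordPointer();
        pointer.set(packed);
        System.out.println(pointer.getPartitionId());                     // 3 (upper 24 bits)
        System.out.println(pointer.getRecordPointer() == recordPointer);  // true
      }
    }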
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
new file mode 100644
index 0000000..85fdaa8
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -0,0 +1,491 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import javax.annotation.Nullable;
+import java.io.File;
+import java.io.IOException;
+import java.util.LinkedList;
+
+import scala.Tuple2;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.DummySerializerInstance;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.DiskBlockObjectWriter;
+import org.apache.spark.storage.TempShuffleBlockId;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+/**
+ * An external sorter that is specialized for sort-based shuffle.
+ * <p>
+ * Incoming records are appended to data pages. When all records have been inserted (or when the
+ * current thread's shuffle memory limit is reached), the in-memory records are sorted according to
+ * their partition ids (using a {@link ShuffleInMemorySorter}). The sorted records are then
+ * written to a single output file (or multiple files, if we've spilled). The format of the output
+ * files is the same as the format of the final output file written by
+ * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are
+ * written as a single serialized, compressed stream that can be read with a new decompression and
+ * deserialization stream.
+ * <p>
+ * Unlike {@link org.apache.spark.util.collection.ExternalSorter}, this sorter does not merge its
+ * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
+ * specialized merge procedure that avoids extra serialization/deserialization.
+ */
+final class ShuffleExternalSorter {
+
+ private final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class);
+
+ @VisibleForTesting
+ static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
+
+ private final int initialSize;
+ private final int numPartitions;
+ private final int pageSizeBytes;
+ @VisibleForTesting
+ final int maxRecordSizeBytes;
+ private final TaskMemoryManager taskMemoryManager;
+ private final ShuffleMemoryManager shuffleMemoryManager;
+ private final BlockManager blockManager;
+ private final TaskContext taskContext;
+ private final ShuffleWriteMetrics writeMetrics;
+ private long numRecordsInsertedSinceLastSpill = 0;
+
+ /** Force this sorter to spill when there are this many elements in memory. For testing only */
+ private final long numElementsForSpillThreshold;
+
+ /** The buffer size to use when writing spills using DiskBlockObjectWriter */
+ private final int fileBufferSizeBytes;
+
+ /**
+ * Memory pages that hold the records being sorted. The pages in this list are freed when
+ * spilling, although in principle we could recycle these pages across spills (on the other hand,
+ * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager
+ * itself).
+ */
+ private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
+
+ private final LinkedList<SpillInfo> spills = new LinkedList<SpillInfo>();
+
+ /** Peak memory used by this sorter so far, in bytes. **/
+ private long peakMemoryUsedBytes;
+
+ // These variables are reset after spilling:
+ @Nullable private ShuffleInMemorySorter inMemSorter;
+ @Nullable private MemoryBlock currentPage = null;
+ private long currentPagePosition = -1;
+ private long freeSpaceInCurrentPage = 0;
+
+ public ShuffleExternalSorter(
+ TaskMemoryManager memoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
+ BlockManager blockManager,
+ TaskContext taskContext,
+ int initialSize,
+ int numPartitions,
+ SparkConf conf,
+ ShuffleWriteMetrics writeMetrics) throws IOException {
+ this.taskMemoryManager = memoryManager;
+ this.shuffleMemoryManager = shuffleMemoryManager;
+ this.blockManager = blockManager;
+ this.taskContext = taskContext;
+ this.initialSize = initialSize;
+ this.peakMemoryUsedBytes = initialSize;
+ this.numPartitions = numPartitions;
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+ this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+ this.numElementsForSpillThreshold =
+ conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
+ this.pageSizeBytes = (int) Math.min(
+ PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes());
+ this.maxRecordSizeBytes = pageSizeBytes - 4;
+ this.writeMetrics = writeMetrics;
+ initializeForWriting();
+
+ // preserve first page to ensure that we have at least one page to work with. Otherwise,
+ // other operators in the same task may starve this sorter (SPARK-9709).
+ acquireNewPageIfNecessary(pageSizeBytes);
+ }
+
+ /**
+ * Allocates new sort data structures. Called when creating the sorter and after each spill.
+ */
+ private void initializeForWriting() throws IOException {
+ // TODO: move this sizing calculation logic into a static method of sorter:
+ final long memoryRequested = initialSize * 8L;
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested);
+ if (memoryAcquired != memoryRequested) {
+ shuffleMemoryManager.release(memoryAcquired);
+ throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
+ }
+
+ this.inMemSorter = new ShuffleInMemorySorter(initialSize);
+ numRecordsInsertedSinceLastSpill = 0;
+ }
+
+ /**
+ * Sorts the in-memory records and writes the sorted records to an on-disk file.
+ * This method does not free the sort data structures.
+ *
+ * @param isLastFile if true, this indicates that we're writing the final output file and that the
+ * bytes written should be counted towards shuffle write metrics rather than
+ * shuffle spill metrics.
+ */
+ private void writeSortedFile(boolean isLastFile) throws IOException {
+
+ final ShuffleWriteMetrics writeMetricsToUse;
+
+ if (isLastFile) {
+ // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
+ writeMetricsToUse = writeMetrics;
+ } else {
+ // We're spilling, so bytes written should be counted towards spill rather than write.
+ // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count
+ // them towards shuffle bytes written.
+ writeMetricsToUse = new ShuffleWriteMetrics();
+ }
+
+ // This call performs the actual sort.
+ final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
+ inMemSorter.getSortedIterator();
+
+ // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
+ // after SPARK-5581 is fixed.
+ DiskBlockObjectWriter writer;
+
+ // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
+ // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
+ // data through a byte array. This array does not need to be large enough to hold a single
+ // record;
+ final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];
+
+ // Because this output will be read during shuffle, its compression codec must be controlled by
+ // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
+ // createTempShuffleBlock here; see SPARK-3426 for more details.
+ final Tuple2<TempShuffleBlockId, File> spilledFileInfo =
+ blockManager.diskBlockManager().createTempShuffleBlock();
+ final File file = spilledFileInfo._2();
+ final TempShuffleBlockId blockId = spilledFileInfo._1();
+ final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId);
+
+ // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
+ // Our write path doesn't actually use this serializer (since we end up calling the `write()`
+ // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
+ // around this, we pass a dummy no-op serializer.
+ final SerializerInstance ser = DummySerializerInstance.INSTANCE;
+
+ writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
+
+ int currentPartition = -1;
+ while (sortedRecords.hasNext()) {
+ sortedRecords.loadNext();
+ final int partition = sortedRecords.packedRecordPointer.getPartitionId();
+ assert (partition >= currentPartition);
+ if (partition != currentPartition) {
+ // Switch to the new partition
+ if (currentPartition != -1) {
+ writer.commitAndClose();
+ spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
+ }
+ currentPartition = partition;
+ writer =
+ blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
+ }
+
+ final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
+ final Object recordPage = taskMemoryManager.getPage(recordPointer);
+ final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer);
+ int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage);
+ long recordReadPosition = recordOffsetInPage + 4; // skip over record length
+ while (dataRemaining > 0) {
+ final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining);
+ Platform.copyMemory(
+ recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer);
+ writer.write(writeBuffer, 0, toTransfer);
+ recordReadPosition += toTransfer;
+ dataRemaining -= toTransfer;
+ }
+ writer.recordWritten();
+ }
+
+ if (writer != null) {
+ writer.commitAndClose();
+ // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted,
+ // then the file might be empty. Note that it might be better to avoid calling
+ // writeSortedFile() in that case.
+ if (currentPartition != -1) {
+ spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
+ spills.add(spillInfo);
+ }
+ }
+
+ if (!isLastFile) { // i.e. this is a spill file
+ // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records
+ // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter
+ // relies on its `recordWritten()` method being called in order to trigger periodic updates to
+ // `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that
+ // counter at a higher-level, then the in-progress metrics for records written and bytes
+ // written would get out of sync.
+ //
+ // When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter;
+ // in all other cases, we pass in a dummy write metrics to capture metrics, then copy those
+ // metrics to the true write metrics here. The reason for performing this copying is so that
+ // we can avoid reporting spilled bytes as shuffle write bytes.
+ //
+ // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`.
+ // Consistent with ExternalSorter, we do not count this IO towards shuffle write time.
+ // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this.
+ writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten());
+ taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.shuffleBytesWritten());
+ }
+ }
+
+ /**
+ * Sort and spill the current records in response to memory pressure.
+ */
+ @VisibleForTesting
+ void spill() throws IOException {
+ logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
+ Thread.currentThread().getId(),
+ Utils.bytesToString(getMemoryUsage()),
+ spills.size(),
+ spills.size() > 1 ? "times" : "time");
+
+ writeSortedFile(false);
+ final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage();
+ inMemSorter = null;
+ shuffleMemoryManager.release(inMemSorterMemoryUsage);
+ final long spillSize = freeMemory();
+ taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
+
+ initializeForWriting();
+ }
+
+ private long getMemoryUsage() {
+ long totalPageSize = 0;
+ for (MemoryBlock page : allocatedPages) {
+ totalPageSize += page.size();
+ }
+ return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize;
+ }
+
+ private void updatePeakMemoryUsed() {
+ long mem = getMemoryUsage();
+ if (mem > peakMemoryUsedBytes) {
+ peakMemoryUsedBytes = mem;
+ }
+ }
+
+ /**
+ * Return the peak memory used so far, in bytes.
+ */
+ long getPeakMemoryUsedBytes() {
+ updatePeakMemoryUsed();
+ return peakMemoryUsedBytes;
+ }
+
+ private long freeMemory() {
+ updatePeakMemoryUsed();
+ long memoryFreed = 0;
+ for (MemoryBlock block : allocatedPages) {
+ taskMemoryManager.freePage(block);
+ shuffleMemoryManager.release(block.size());
+ memoryFreed += block.size();
+ }
+ allocatedPages.clear();
+ currentPage = null;
+ currentPagePosition = -1;
+ freeSpaceInCurrentPage = 0;
+ return memoryFreed;
+ }
+
+ /**
+ * Force all memory and spill files to be deleted; called by shuffle error-handling code.
+ */
+ public void cleanupResources() {
+ freeMemory();
+ for (SpillInfo spill : spills) {
+ if (spill.file.exists() && !spill.file.delete()) {
+ logger.error("Unable to delete spill file {}", spill.file.getPath());
+ }
+ }
+ if (inMemSorter != null) {
+ shuffleMemoryManager.release(inMemSorter.getMemoryUsage());
+ inMemSorter = null;
+ }
+ }
+
+ /**
+ * Checks whether there is enough space to insert an additional record in to the sort pointer
+ * array and grows the array if additional space is required. If the required space cannot be
+ * obtained, then the in-memory data will be spilled to disk.
+ */
+ private void growPointerArrayIfNecessary() throws IOException {
+ assert(inMemSorter != null);
+ if (!inMemSorter.hasSpaceForAnotherRecord()) {
+ logger.debug("Attempting to expand sort pointer array");
+ final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage();
+ final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
+ if (memoryAcquired < memoryToGrowPointerArray) {
+ shuffleMemoryManager.release(memoryAcquired);
+ spill();
+ } else {
+ inMemSorter.expandPointerArray();
+ shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
+ }
+ }
+ }
+
+ /**
+ * Allocates more memory in order to insert an additional record. This will request additional
+ * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
+ * obtained.
+ *
+ * @param requiredSpace the required space in the data page, in bytes, including space for storing
+ * the record size. This must be less than or equal to the page size (records
+ * that exceed the page size are handled via a different code path which uses
+ * special overflow pages).
+ */
+ private void acquireNewPageIfNecessary(int requiredSpace) throws IOException {
+ growPointerArrayIfNecessary();
+ if (requiredSpace > freeSpaceInCurrentPage) {
+ logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
+ freeSpaceInCurrentPage);
+ // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
+ // without using the free space at the end of the current page. We should also do this for
+ // BytesToBytesMap.
+ if (requiredSpace > pageSizeBytes) {
+ throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
+ pageSizeBytes + ")");
+ } else {
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+ if (memoryAcquired < pageSizeBytes) {
+ shuffleMemoryManager.release(memoryAcquired);
+ spill();
+ final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+ if (memoryAcquiredAfterSpilling != pageSizeBytes) {
+ shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+ throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
+ }
+ }
+ currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
+ currentPagePosition = currentPage.getBaseOffset();
+ freeSpaceInCurrentPage = pageSizeBytes;
+ allocatedPages.add(currentPage);
+ }
+ }
+ }
+
+ /**
+ * Write a record to the shuffle sorter.
+ */
+ public void insertRecord(
+ Object recordBaseObject,
+ long recordBaseOffset,
+ int lengthInBytes,
+ int partitionId) throws IOException {
+
+ if (numRecordsInsertedSinceLastSpill > numElementsForSpillThreshold) {
+ spill();
+ }
+
+ growPointerArrayIfNecessary();
+ // Need 4 bytes to store the record length.
+ final int totalSpaceRequired = lengthInBytes + 4;
+
+ // --- Figure out where to insert the new record ----------------------------------------------
+
+ final MemoryBlock dataPage;
+ long dataPagePosition;
+ boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
+ if (useOverflowPage) {
+ long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
+ // The record is larger than the page size, so allocate a special overflow page just to hold
+ // that record.
+ final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
+ if (memoryGranted != overflowPageSize) {
+ shuffleMemoryManager.release(memoryGranted);
+ spill();
+ final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
+ if (memoryGrantedAfterSpill != overflowPageSize) {
+ shuffleMemoryManager.release(memoryGrantedAfterSpill);
+ throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
+ }
+ }
+ MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+ allocatedPages.add(overflowPage);
+ dataPage = overflowPage;
+ dataPagePosition = overflowPage.getBaseOffset();
+ } else {
+ // The record is small enough to fit in a regular data page, but the current page might not
+ // have enough space to hold it (or no pages have been allocated yet).
+ acquireNewPageIfNecessary(totalSpaceRequired);
+ dataPage = currentPage;
+ dataPagePosition = currentPagePosition;
+ // Update bookkeeping information
+ freeSpaceInCurrentPage -= totalSpaceRequired;
+ currentPagePosition += totalSpaceRequired;
+ }
+ final Object dataPageBaseObject = dataPage.getBaseObject();
+
+ final long recordAddress =
+ taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
+ Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes);
+ dataPagePosition += 4;
+ Platform.copyMemory(
+ recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);
+ assert(inMemSorter != null);
+ inMemSorter.insertRecord(recordAddress, partitionId);
+ numRecordsInsertedSinceLastSpill += 1;
+ }
+
+ /**
+ * Close the sorter, causing any buffered data to be sorted and written out to disk.
+ *
+ * @return metadata for the spill files written by this sorter. If no records were ever inserted
+ * into this sorter, then this will return an empty array.
+ * @throws IOException
+ */
+ public SpillInfo[] closeAndGetSpills() throws IOException {
+ try {
+ if (inMemSorter != null) {
+ // Do not count the final file towards the spill count.
+ writeSortedFile(true);
+ freeMemory();
+ }
+ return spills.toArray(new SpillInfo[spills.size()]);
+ } catch (IOException e) {
+ cleanupResources();
+ throw e;
+ }
+ }
+
+}
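One detail worth noting in writeSortedFile() above is that each record is streamed from the sorter's memory pages to the DiskBlockObjectWriter through a fixed DISK_WRITE_BUFFER_SIZE (1 MB) byte array rather than in a single copy. The sketch below shows the same loop with plain byte arrays and a deliberately tiny buffer so the chunking is visible; it is illustrative only, not Spark code.

    import java.io.ByteArrayOutputStream;

    public class ChunkedRecordCopyExample {
      public static void main(String[] args) {
        byte[] record = new byte[20];        // stands in for a serialized record in a memory page
        byte[] writeBuffer = new byte[8];    // stands in for the 1 MB DISK_WRITE_BUFFER_SIZE buffer
        ByteArrayOutputStream out = new ByteArrayOutputStream();  // stands in for the disk writer

        int dataRemaining = record.length;
        int readPosition = 0;
        while (dataRemaining > 0) {
          int toTransfer = Math.min(writeBuffer.length, dataRemaining);
          // The real code uses Platform.copyMemory to read bytes out of the page.
          System.arraycopy(record, readPosition, writeBuffer, 0, toTransfer);
          out.write(writeBuffer, 0, toTransfer);
          readPosition += toTransfer;
          dataRemaining -= toTransfer;
        }
        System.out.println(out.size());      // 20: all bytes copied, 8 at a time
      }
    }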
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
new file mode 100644
index 0000000..a8dee6c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.util.Comparator;
+
+import org.apache.spark.util.collection.Sorter;
+
+final class ShuffleInMemorySorter {
+
+ private final Sorter<PackedRecordPointer, long[]> sorter;
+ private static final class SortComparator implements Comparator<PackedRecordPointer> {
+ @Override
+ public int compare(PackedRecordPointer left, PackedRecordPointer right) {
+ return left.getPartitionId() - right.getPartitionId();
+ }
+ }
+ private static final SortComparator SORT_COMPARATOR = new SortComparator();
+
+ /**
+ * An array of record pointers and partition ids that have been encoded by
+ * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating
+ * records.
+ */
+ private long[] pointerArray;
+
+ /**
+ * The position in the pointer array where new records can be inserted.
+ */
+ private int pointerArrayInsertPosition = 0;
+
+ public ShuffleInMemorySorter(int initialSize) {
+ assert (initialSize > 0);
+ this.pointerArray = new long[initialSize];
+ this.sorter = new Sorter<PackedRecordPointer, long[]>(ShuffleSortDataFormat.INSTANCE);
+ }
+
+ public void expandPointerArray() {
+ final long[] oldArray = pointerArray;
+ // Guard against overflow:
+ final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
+ pointerArray = new long[newLength];
+ System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
+ }
+
+ public boolean hasSpaceForAnotherRecord() {
+ return pointerArrayInsertPosition + 1 < pointerArray.length;
+ }
+
+ public long getMemoryUsage() {
+ return pointerArray.length * 8L;
+ }
+
+ /**
+ * Inserts a record to be sorted.
+ *
+ * @param recordPointer a pointer to the record, encoded by the task memory manager. Due to
+ * certain pointer compression techniques used by the sorter, the sort can
+ * only operate on pointers that point to locations in the first
+ * {@link PackedRecordPointer#MAXIMUM_PAGE_SIZE_BYTES} bytes of a data page.
+ * @param partitionId the partition id, which must be less than or equal to
+ * {@link PackedRecordPointer#MAXIMUM_PARTITION_ID}.
+ */
+ public void insertRecord(long recordPointer, int partitionId) {
+ if (!hasSpaceForAnotherRecord()) {
+ if (pointerArray.length == Integer.MAX_VALUE) {
+ throw new IllegalStateException("Sort pointer array has reached maximum size");
+ } else {
+ expandPointerArray();
+ }
+ }
+ pointerArray[pointerArrayInsertPosition] =
+ PackedRecordPointer.packPointer(recordPointer, partitionId);
+ pointerArrayInsertPosition++;
+ }
+
+ /**
+ * An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining.
+ */
+ public static final class ShuffleSorterIterator {
+
+ private final long[] pointerArray;
+ private final int numRecords;
+ final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
+ private int position = 0;
+
+ public ShuffleSorterIterator(int numRecords, long[] pointerArray) {
+ this.numRecords = numRecords;
+ this.pointerArray = pointerArray;
+ }
+
+ public boolean hasNext() {
+ return position < numRecords;
+ }
+
+ public void loadNext() {
+ packedRecordPointer.set(pointerArray[position]);
+ position++;
+ }
+ }
+
+ /**
+ * Return an iterator over record pointers in sorted order.
+ */
+ public ShuffleSorterIterator getSortedIterator() {
+ sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR);
+ return new ShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
+ }
+}
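To make the sorter's contract concrete: records go in as (record pointer, partition id) pairs, are packed into single longs, and come back out ordered by partition id via the reusable ShuffleSorterIterator. The snippet below is illustrative only, not part of the patch; it must sit in the same package because the class is package-private, and the record pointers are dummy values since only the partition ids matter for ordering.

    package org.apache.spark.shuffle.sort;

    // Illustrative only; not part of the patch.
    public class ShuffleInMemorySorterExample {
      public static void main(String[] args) {
        ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);

        // Dummy record pointers; only the partition ids influence the sort order.
        sorter.insertRecord(0L, 2);
        sorter.insertRecord(8L, 0);
        sorter.insertRecord(16L, 1);

        ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
        while (iter.hasNext()) {
          iter.loadNext();
          System.out.println(iter.packedRecordPointer.getPartitionId()); // prints 0, then 1, then 2
        }
      }
    }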
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
new file mode 100644
index 0000000..8a1e5ae
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import org.apache.spark.util.collection.SortDataFormat;
+
+final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
+
+ public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat();
+
+ private ShuffleSortDataFormat() { }
+
+ @Override
+ public PackedRecordPointer getKey(long[] data, int pos) {
+ // Since we re-use keys, this method shouldn't be called.
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public PackedRecordPointer newKey() {
+ return new PackedRecordPointer();
+ }
+
+ @Override
+ public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) {
+ reuse.set(data[pos]);
+ return reuse;
+ }
+
+ @Override
+ public void swap(long[] data, int pos0, int pos1) {
+ final long temp = data[pos0];
+ data[pos0] = data[pos1];
+ data[pos1] = temp;
+ }
+
+ @Override
+ public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
+ dst[dstPos] = src[srcPos];
+ }
+
+ @Override
+ public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
+ System.arraycopy(src, srcPos, dst, dstPos, length);
+ }
+
+ @Override
+ public long[] allocate(int length) {
+ return new long[length];
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
deleted file mode 100644
index 656ea04..0000000
--- a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.sort;
-
-import java.io.File;
-import java.io.IOException;
-
-import scala.Product2;
-import scala.collection.Iterator;
-
-import org.apache.spark.annotation.Private;
-import org.apache.spark.TaskContext;
-import org.apache.spark.storage.BlockId;
-
-/**
- * Interface for objects that {@link SortShuffleWriter} uses to write its output files.
- */
-@Private
-public interface SortShuffleFileWriter<K, V> {
-
- void insertAll(Iterator<Product2<K, V>> records) throws IOException;
-
- /**
- * Write all the data added into this shuffle sorter into a file in the disk store. This is
- * called by the SortShuffleWriter and can go through an efficient path of just concatenating
- * binary files if we decided to avoid merge-sorting.
- *
- * @param blockId block ID to write to. The index file will be blockId.name + ".index".
- * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
- * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
- */
- long[] writePartitionedFile(
- BlockId blockId,
- TaskContext context,
- File outputFile) throws IOException;
-
- void stop() throws IOException;
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java
new file mode 100644
index 0000000..df9f7b7
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.File;
+
+import org.apache.spark.storage.TempShuffleBlockId;
+
+/**
+ * Metadata for a block of data written by {@link ShuffleExternalSorter}.
+ */
+final class SpillInfo {
+ final long[] partitionLengths;
+ final File file;
+ final TempShuffleBlockId blockId;
+
+ public SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) {
+ this.partitionLengths = new long[numPartitions];
+ this.file = file;
+ this.blockId = blockId;
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
new file mode 100644
index 0000000..e8f050c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -0,0 +1,489 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import javax.annotation.Nullable;
+import java.io.*;
+import java.nio.channels.FileChannel;
+import java.util.Iterator;
+
+import scala.Option;
+import scala.Product2;
+import scala.collection.JavaConverters;
+import scala.collection.immutable.Map;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.io.ByteStreams;
+import com.google.common.io.Closeables;
+import com.google.common.io.Files;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.*;
+import org.apache.spark.annotation.Private;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.io.CompressionCodec;
+import org.apache.spark.io.CompressionCodec$;
+import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.scheduler.MapStatus$;
+import org.apache.spark.serializer.SerializationStream;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.TimeTrackingOutputStream;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+@Private
+public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
+
+ private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);
+
+ private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
+
+ @VisibleForTesting
+ static final int INITIAL_SORT_BUFFER_SIZE = 4096;
+
+ private final BlockManager blockManager;
+ private final IndexShuffleBlockResolver shuffleBlockResolver;
+ private final TaskMemoryManager memoryManager;
+ private final ShuffleMemoryManager shuffleMemoryManager;
+ private final SerializerInstance serializer;
+ private final Partitioner partitioner;
+ private final ShuffleWriteMetrics writeMetrics;
+ private final int shuffleId;
+ private final int mapId;
+ private final TaskContext taskContext;
+ private final SparkConf sparkConf;
+ private final boolean transferToEnabled;
+
+ @Nullable private MapStatus mapStatus;
+ @Nullable private ShuffleExternalSorter sorter;
+ private long peakMemoryUsedBytes = 0;
+
+ /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
+ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
+ public MyByteArrayOutputStream(int size) { super(size); }
+ public byte[] getBuf() { return buf; }
+ }
+
+ private MyByteArrayOutputStream serBuffer;
+ private SerializationStream serOutputStream;
+
+ /**
+ * Are we in the process of stopping? Because map tasks can call stop() with success = true
+ * and then call stop() with success = false if they get an exception, we want to make sure
+ * we don't try deleting files, etc twice.
+ */
+ private boolean stopping = false;
+
+ public UnsafeShuffleWriter(
+ BlockManager blockManager,
+ IndexShuffleBlockResolver shuffleBlockResolver,
+ TaskMemoryManager memoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
+ SerializedShuffleHandle<K, V> handle,
+ int mapId,
+ TaskContext taskContext,
+ SparkConf sparkConf) throws IOException {
+ final int numPartitions = handle.dependency().partitioner().numPartitions();
+ if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
+ throw new IllegalArgumentException(
+ "UnsafeShuffleWriter can only be used for shuffles with at most " +
+ SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() + " reduce partitions");
+ }
+ this.blockManager = blockManager;
+ this.shuffleBlockResolver = shuffleBlockResolver;
+ this.memoryManager = memoryManager;
+ this.shuffleMemoryManager = shuffleMemoryManager;
+ this.mapId = mapId;
+ final ShuffleDependency<K, V, V> dep = handle.dependency();
+ this.shuffleId = dep.shuffleId();
+ this.serializer = Serializer.getSerializer(dep.serializer()).newInstance();
+ this.partitioner = dep.partitioner();
+ this.writeMetrics = new ShuffleWriteMetrics();
+ taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
+ this.taskContext = taskContext;
+ this.sparkConf = sparkConf;
+ this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
+ open();
+ }
+
+ @VisibleForTesting
+ public int maxRecordSizeBytes() {
+ assert(sorter != null);
+ return sorter.maxRecordSizeBytes;
+ }
+
+ private void updatePeakMemoryUsed() {
+ // sorter can be null if this writer is closed
+ if (sorter != null) {
+ long mem = sorter.getPeakMemoryUsedBytes();
+ if (mem > peakMemoryUsedBytes) {
+ peakMemoryUsedBytes = mem;
+ }
+ }
+ }
+
+ /**
+ * Return the peak memory used so far, in bytes.
+ */
+ public long getPeakMemoryUsedBytes() {
+ updatePeakMemoryUsed();
+ return peakMemoryUsedBytes;
+ }
+
+ /**
+ * This convenience method should only be called in test code.
+ */
+ @VisibleForTesting
+ public void write(Iterator<Product2<K, V>> records) throws IOException {
+ write(JavaConverters.asScalaIteratorConverter(records).asScala());
+ }
+
+ @Override
+ public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
+ // Keep track of success so we know if we encountered an exception
+ // We do this rather than a standard try/catch/re-throw to handle
+ // generic throwables.
+ boolean success = false;
+ try {
+ while (records.hasNext()) {
+ insertRecordIntoSorter(records.next());
+ }
+ closeAndWriteOutput();
+ success = true;
+ } finally {
+ if (sorter != null) {
+ try {
+ sorter.cleanupResources();
+ } catch (Exception e) {
+ // Only throw this error if we won't be masking another
+ // error.
+ if (success) {
+ throw e;
+ } else {
+ logger.error("In addition to a failure during writing, we failed during " +
+ "cleanup.", e);
+ }
+ }
+ }
+ }
+ }
+
+ private void open() throws IOException {
+ assert (sorter == null);
+ sorter = new ShuffleExternalSorter(
+ memoryManager,
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ INITIAL_SORT_BUFFER_SIZE,
+ partitioner.numPartitions(),
+ sparkConf,
+ writeMetrics);
+ serBuffer = new MyByteArrayOutputStream(1024 * 1024);
+ serOutputStream = serializer.serializeStream(serBuffer);
+ }
+
+ @VisibleForTesting
+ void closeAndWriteOutput() throws IOException {
+ assert(sorter != null);
+ updatePeakMemoryUsed();
+ serBuffer = null;
+ serOutputStream = null;
+ final SpillInfo[] spills = sorter.closeAndGetSpills();
+ sorter = null;
+ final long[] partitionLengths;
+ try {
+ partitionLengths = mergeSpills(spills);
+ } finally {
+ for (SpillInfo spill : spills) {
+ if (spill.file.exists() && !spill.file.delete()) {
+ logger.error("Error while deleting spill file {}", spill.file.getPath());
+ }
+ }
+ }
+ shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+ }
+
+ @VisibleForTesting
+ void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
+ assert(sorter != null);
+ final K key = record._1();
+ final int partitionId = partitioner.getPartition(key);
+ serBuffer.reset();
+ serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
+ serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
+ serOutputStream.flush();
+
+ final int serializedRecordSize = serBuffer.size();
+ assert (serializedRecordSize > 0);
+
+ sorter.insertRecord(
+ serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
+ }
+
+ @VisibleForTesting
+ void forceSorterToSpill() throws IOException {
+ assert (sorter != null);
+ sorter.spill();
+ }
+
+ /**
+ * Merge zero or more spill files together, choosing the fastest merging strategy based on the
+ * number of spills and the IO compression codec.
+ *
+ * @return the partition lengths in the merged file.
+ */
+ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
+ final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+ final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
+ final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
+ final boolean fastMergeEnabled =
+ sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
+ final boolean fastMergeIsSupported =
+ !compressionEnabled || compressionCodec instanceof LZFCompressionCodec;
+ try {
+ if (spills.length == 0) {
+ new FileOutputStream(outputFile).close(); // Create an empty file
+ return new long[partitioner.numPartitions()];
+ } else if (spills.length == 1) {
+ // Here, we don't need to perform any metrics updates because the bytes written to this
+ // output file would have already been counted as shuffle bytes written.
+ Files.move(spills[0].file, outputFile);
+ return spills[0].partitionLengths;
+ } else {
+ final long[] partitionLengths;
+ // There are multiple spills to merge, so none of these spill files' lengths were counted
+ // towards our shuffle write count or shuffle write time. If we use the slow merge path,
+ // then the final output file's size won't necessarily be equal to the sum of the spill
+ // files' sizes. To guard against this case, we look at the output file's actual size when
+ // computing shuffle bytes written.
+ //
+ // We allow the individual merge methods to report their own IO times since different merge
+ // strategies use different IO techniques. We count IO during merge towards the shuffle
+ // write time, which appears to be consistent with the "not bypassing merge-sort"
+ // branch in ExternalSorter.
+ if (fastMergeEnabled && fastMergeIsSupported) {
+ // Compression is disabled or we are using an IO compression codec that supports
+ // decompression of concatenated compressed streams, so we can perform a fast spill merge
+ // that doesn't need to interpret the spilled bytes.
+ if (transferToEnabled) {
+ logger.debug("Using transferTo-based fast merge");
+ partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
+ } else {
+ logger.debug("Using fileStream-based fast merge");
+ partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
+ }
+ } else {
+ logger.debug("Using slow merge");
+ partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
+ }
+ // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
+ // in-memory records, we write out the in-memory records to a file but do not count that
+ // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs
+ // to be counted as shuffle write, but this will lead to double-counting of the final
+ // SpillInfo's bytes.
+ writeMetrics.decShuffleBytesWritten(spills[spills.length - 1].file.length());
+ writeMetrics.incShuffleBytesWritten(outputFile.length());
+ return partitionLengths;
+ }
+ } catch (IOException e) {
+ if (outputFile.exists() && !outputFile.delete()) {
+ logger.error("Unable to delete output file {}", outputFile.getPath());
+ }
+ throw e;
+ }
+ }
+
+ /**
+ * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge,
+ * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in
+ * cases where the IO compression codec does not support concatenation of compressed data, or in
+ * cases where users have explicitly disabled use of {@code transferTo} in order to work around
+ * kernel bugs.
+ *
+ * @param spills the spills to merge.
+ * @param outputFile the file to write the merged data to.
+ * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
+ * @return the partition lengths in the merged file.
+ */
+ private long[] mergeSpillsWithFileStream(
+ SpillInfo[] spills,
+ File outputFile,
+ @Nullable CompressionCodec compressionCodec) throws IOException {
+ assert (spills.length >= 2);
+ final int numPartitions = partitioner.numPartitions();
+ final long[] partitionLengths = new long[numPartitions];
+ final InputStream[] spillInputStreams = new FileInputStream[spills.length];
+ OutputStream mergedFileOutputStream = null;
+
+ boolean threwException = true;
+ try {
+ for (int i = 0; i < spills.length; i++) {
+ spillInputStreams[i] = new FileInputStream(spills[i].file);
+ }
+ for (int partition = 0; partition < numPartitions; partition++) {
+ final long initialFileLength = outputFile.length();
+ mergedFileOutputStream =
+ new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true));
+ if (compressionCodec != null) {
+ mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream);
+ }
+
+ for (int i = 0; i < spills.length; i++) {
+ final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+ if (partitionLengthInSpill > 0) {
+ InputStream partitionInputStream =
+ new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill);
+ if (compressionCodec != null) {
+ partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
+ }
+ ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
+ }
+ }
+ mergedFileOutputStream.flush();
+ mergedFileOutputStream.close();
+ partitionLengths[partition] = (outputFile.length() - initialFileLength);
+ }
+ threwException = false;
+ } finally {
+ // To avoid masking exceptions that caused us to prematurely enter the finally block, only
+ // throw exceptions during cleanup if threwException == false.
+ for (InputStream stream : spillInputStreams) {
+ Closeables.close(stream, threwException);
+ }
+ Closeables.close(mergedFileOutputStream, threwException);
+ }
+ return partitionLengths;
+ }
+
+ /**
+ * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes.
+ * This is only safe when the IO compression codec and serializer support concatenation of
+ * serialized streams.
+ *
+ * @return the partition lengths in the merged file.
+ */
+ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
+ assert (spills.length >= 2);
+ final int numPartitions = partitioner.numPartitions();
+ final long[] partitionLengths = new long[numPartitions];
+ final FileChannel[] spillInputChannels = new FileChannel[spills.length];
+ final long[] spillInputChannelPositions = new long[spills.length];
+ FileChannel mergedFileOutputChannel = null;
+
+ boolean threwException = true;
+ try {
+ for (int i = 0; i < spills.length; i++) {
+ spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
+ }
+ // This file needs to be opened in append mode in order to work around a Linux kernel bug that
+ // affects transferTo; see SPARK-3948 for more details.
+ mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel();
+
+ long bytesWrittenToMergedFile = 0;
+ for (int partition = 0; partition < numPartitions; partition++) {
+ for (int i = 0; i < spills.length; i++) {
+ final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+ long bytesToTransfer = partitionLengthInSpill;
+ final FileChannel spillInputChannel = spillInputChannels[i];
+ final long writeStartTime = System.nanoTime();
+ while (bytesToTransfer > 0) {
+ final long actualBytesTransferred = spillInputChannel.transferTo(
+ spillInputChannelPositions[i],
+ bytesToTransfer,
+ mergedFileOutputChannel);
+ spillInputChannelPositions[i] += actualBytesTransferred;
+ bytesToTransfer -= actualBytesTransferred;
+ }
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime);
+ bytesWrittenToMergedFile += partitionLengthInSpill;
+ partitionLengths[partition] += partitionLengthInSpill;
+ }
+ }
+ // Check the position after the transferTo loop and raise an exception if it does not match the
+ // number of bytes written. The position will not be increased to the expected length
+ // after calling transferTo in kernel version 2.6.32. This issue is described at
+ // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
+ if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) {
+ throw new IOException(
+ "Current position " + mergedFileOutputChannel.position() + " does not equal expected " +
+ "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" +
+ " version to see if it is 2.6.32, as there is a kernel bug which will lead to " +
+ "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " +
+ "to disable this NIO feature."
+ );
+ }
+ threwException = false;
+ } finally {
+ // To avoid masking exceptions that caused us to prematurely enter the finally block, only
+ // throw exceptions during cleanup if threwException == false.
+ for (int i = 0; i < spills.length; i++) {
+ assert(spillInputChannelPositions[i] == spills[i].file.length());
+ Closeables.close(spillInputChannels[i], threwException);
+ }
+ Closeables.close(mergedFileOutputChannel, threwException);
+ }
+ return partitionLengths;
+ }
+
+ @Override
+ public Option<MapStatus> stop(boolean success) {
+ try {
+ // Update task metrics from accumulators (null in UnsafeShuffleWriterSuite)
+ Map<String, Accumulator<Object>> internalAccumulators =
+ taskContext.internalMetricsToAccumulators();
+ if (internalAccumulators != null) {
+ internalAccumulators.apply(InternalAccumulator.PEAK_EXECUTION_MEMORY())
+ .add(getPeakMemoryUsedBytes());
+ }
+
+ if (stopping) {
+ return Option.apply(null);
+ } else {
+ stopping = true;
+ if (success) {
+ if (mapStatus == null) {
+ throw new IllegalStateException("Cannot call stop(true) without having called write()");
+ }
+ return Option.apply(mapStatus);
+ } else {
+ // The map task failed, so delete our output data.
+ shuffleBlockResolver.removeDataByMap(shuffleId, mapId);
+ return Option.apply(null);
+ }
+ }
+ } finally {
+ if (sorter != null) {
+ // If sorter is non-null, then this implies that we called stop() in response to an error,
+ // so we need to clean up memory and spill files created by the sorter
+ sorter.cleanupResources();
+ }
+ }
+ }
+}
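The fast merge branch above is only taken when compression is disabled or the codec can concatenate compressed streams (LZF), and the transferTo path can additionally be turned off to work around the kernel bug referenced in SPARK-3948. A sketch of a configuration that would exercise the transferTo-based fast merge; the key names other than spark.io.compression.codec are read verbatim in this file, while spark.io.compression.codec is assumed to be the usual codec setting:

    // Illustrative configuration only.
    SparkConf conf = new SparkConf()
      .set("spark.shuffle.compress", "true")
      .set("spark.io.compression.codec", "lzf")             // concatenation-friendly codec
      .set("spark.shuffle.unsafe.fastMergeEnabled", "true") // allow the fast merge path
      .set("spark.file.transferTo", "true");                // NIO transferTo concatenation

Setting spark.file.transferTo to false falls back to the fileStream-based fast merge, and a non-LZF codec with compression enabled forces the slow, decompress-and-recompress merge.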
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
deleted file mode 100644
index 4ee6a82..0000000
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-/**
- * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.
- * <p>
- * Within the long, the data is laid out as follows:
- * <pre>
- * [24 bit partition number][13 bit memory page number][27 bit offset in page]
- * </pre>
- * This implies that the maximum addressable page size is 2^27 bytes = 128 megabytes, assuming that
- * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the
- * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this
- * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task.
- * <p>
- * Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this
- * optimization to future work as it will require more careful design to ensure that addresses are
- * properly aligned (e.g. by padding records).
- */
-final class PackedRecordPointer {
-
- static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27; // 128 megabytes
-
- /**
- * The maximum partition identifier that can be encoded. Note that partition ids start from 0.
- */
- static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215
-
- /** Bit mask for the lower 40 bits of a long. */
- private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1;
-
- /** Bit mask for the upper 24 bits of a long */
- private static final long MASK_LONG_UPPER_24_BITS = ~MASK_LONG_LOWER_40_BITS;
-
- /** Bit mask for the lower 27 bits of a long. */
- private static final long MASK_LONG_LOWER_27_BITS = (1L << 27) - 1;
-
- /** Bit mask for the lower 51 bits of a long. */
- private static final long MASK_LONG_LOWER_51_BITS = (1L << 51) - 1;
-
- /** Bit mask for the upper 13 bits of a long */
- private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS;
-
- /**
- * Pack a record address and partition id into a single word.
- *
- * @param recordPointer a record pointer encoded by TaskMemoryManager.
- * @param partitionId a shuffle partition id (maximum value of 2^24).
- * @return a packed pointer that can be decoded using the {@link PackedRecordPointer} class.
- */
- public static long packPointer(long recordPointer, int partitionId) {
- assert (partitionId <= MAXIMUM_PARTITION_ID);
- // Note that without word alignment we can address 2^27 bytes = 128 megabytes per page.
- // Also note that this relies on some internals of how TaskMemoryManager encodes its addresses.
- final long pageNumber = (recordPointer & MASK_LONG_UPPER_13_BITS) >>> 24;
- final long compressedAddress = pageNumber | (recordPointer & MASK_LONG_LOWER_27_BITS);
- return (((long) partitionId) << 40) | compressedAddress;
- }
-
- private long packedRecordPointer;
-
- public void set(long packedRecordPointer) {
- this.packedRecordPointer = packedRecordPointer;
- }
-
- public int getPartitionId() {
- return (int) ((packedRecordPointer & MASK_LONG_UPPER_24_BITS) >>> 40);
- }
-
- public long getRecordPointer() {
- final long pageNumber = (packedRecordPointer << 24) & MASK_LONG_UPPER_13_BITS;
- final long offsetInPage = packedRecordPointer & MASK_LONG_LOWER_27_BITS;
- return pageNumber | offsetInPage;
- }
-
-}
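The packing scheme documented above can be made concrete with a small worked example. The numbers are hypothetical and the snippet ignores package-private visibility; it simply follows the bit masks defined in this (relocated) class:

    // Hypothetical record: page 3, offset 64 within the page, partition id 5.
    // A TaskMemoryManager address keeps the 13-bit page number in the upper bits.
    long recordPointer = (3L << 51) | 64L;
    long packed = PackedRecordPointer.packPointer(recordPointer, 5);
    // packed layout: [24-bit partition = 5][13-bit page = 3][27-bit offset = 64]
    PackedRecordPointer pointer = new PackedRecordPointer();
    pointer.set(packed);
    assert pointer.getPartitionId() == 5;
    assert pointer.getRecordPointer() == recordPointer;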
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
deleted file mode 100644
index 7bac0dc..0000000
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-import java.io.File;
-
-import org.apache.spark.storage.TempShuffleBlockId;
-
-/**
- * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}.
- */
-final class SpillInfo {
- final long[] partitionLengths;
- final File file;
- final TempShuffleBlockId blockId;
-
- public SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) {
- this.partitionLengths = new long[numPartitions];
- this.file = file;
- this.blockId = blockId;
- }
-}
[2/4] spark git commit: [SPARK-10708] Consolidate sort shuffle implementations
Posted by jo...@apache.org.
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
deleted file mode 100644
index 87a786b..0000000
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
+++ /dev/null
@@ -1,273 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.io.InputStream
-import java.nio.IntBuffer
-import java.util.Comparator
-
-import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance}
-import org.apache.spark.storage.DiskBlockObjectWriter
-import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._
-
-/**
- * Append-only buffer of key-value pairs, each with a corresponding partition ID, that serializes
- * its records upon insert and stores them as raw bytes.
- *
- * We use two data-structures to store the contents. The serialized records are stored in a
- * ChainedBuffer that can expand gracefully as records are added. This buffer is accompanied by a
- * metadata buffer that stores pointers into the data buffer as well as the partition ID of each
- * record. Each entry in the metadata buffer takes up a fixed amount of space.
- *
- * Sorting the collection means swapping entries in the metadata buffer - the record buffer need not
- * be modified at all. Storing the partition IDs in the metadata buffer means that comparisons can
- * happen without following any pointers, which should minimize cache misses.
- *
- * Currently, only sorting by partition is supported.
- *
- * Each record is laid out inside the metaBuffer as follows. keyStart, a long, is split across
- * two integers:
- *
- * +-------------+------------+------------+-------------+
- * |         keyStart         |  keyValLen | partitionId |
- * +-------------+------------+------------+-------------+
- *
- * The buffer can support up to `536870911 (2 ^ 29 - 1)` records.
- *
- * @param metaInitialRecords The initial number of entries in the metadata buffer.
- * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records.
- * @param serializerInstance the serializer used for serializing inserted records.
- */
-private[spark] class PartitionedSerializedPairBuffer[K, V](
- metaInitialRecords: Int,
- kvBlockSize: Int,
- serializerInstance: SerializerInstance)
- extends WritablePartitionedPairCollection[K, V] with SizeTracker {
-
- if (serializerInstance.isInstanceOf[JavaSerializerInstance]) {
- throw new IllegalArgumentException("PartitionedSerializedPairBuffer does not support" +
- " Java-serialized objects.")
- }
-
- require(metaInitialRecords <= MAXIMUM_RECORDS,
- s"Can't make capacity bigger than ${MAXIMUM_RECORDS} records")
- private var metaBuffer = IntBuffer.allocate(metaInitialRecords * RECORD_SIZE)
-
- private val kvBuffer: ChainedBuffer = new ChainedBuffer(kvBlockSize)
- private val kvOutputStream = new ChainedBufferOutputStream(kvBuffer)
- private val kvSerializationStream = serializerInstance.serializeStream(kvOutputStream)
-
- def insert(partition: Int, key: K, value: V): Unit = {
- if (metaBuffer.position == metaBuffer.capacity) {
- growMetaBuffer()
- }
-
- val keyStart = kvBuffer.size
- kvSerializationStream.writeKey[Any](key)
- kvSerializationStream.writeValue[Any](value)
- kvSerializationStream.flush()
- val keyValLen = (kvBuffer.size - keyStart).toInt
-
- // keyStart, a long, gets split across two ints
- metaBuffer.put(keyStart.toInt)
- metaBuffer.put((keyStart >> 32).toInt)
- metaBuffer.put(keyValLen)
- metaBuffer.put(partition)
- }
-
- /** Double the size of the array because we've reached capacity */
- private def growMetaBuffer(): Unit = {
- if (metaBuffer.capacity >= MAXIMUM_META_BUFFER_CAPACITY) {
- throw new IllegalStateException(s"Can't insert more than ${MAXIMUM_RECORDS} records")
- }
- val newCapacity =
- if (metaBuffer.capacity * 2 < 0 || metaBuffer.capacity * 2 > MAXIMUM_META_BUFFER_CAPACITY) {
- // Overflow
- MAXIMUM_META_BUFFER_CAPACITY
- } else {
- metaBuffer.capacity * 2
- }
- val newMetaBuffer = IntBuffer.allocate(newCapacity)
- newMetaBuffer.put(metaBuffer.array)
- metaBuffer = newMetaBuffer
- }
-
- /** Iterate through the data in a given order. For this class this is not really destructive. */
- override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
- : Iterator[((Int, K), V)] = {
- sort(keyComparator)
- val is = orderedInputStream
- val deserStream = serializerInstance.deserializeStream(is)
- new Iterator[((Int, K), V)] {
- var metaBufferPos = 0
- def hasNext: Boolean = metaBufferPos < metaBuffer.position
- def next(): ((Int, K), V) = {
- val key = deserStream.readKey[Any]().asInstanceOf[K]
- val value = deserStream.readValue[Any]().asInstanceOf[V]
- val partition = metaBuffer.get(metaBufferPos + PARTITION)
- metaBufferPos += RECORD_SIZE
- ((partition, key), value)
- }
- }
- }
-
- override def estimateSize: Long = metaBuffer.capacity * 4L + kvBuffer.capacity
-
- override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
- : WritablePartitionedIterator = {
- sort(keyComparator)
- new WritablePartitionedIterator {
- // current position in the meta buffer in ints
- var pos = 0
-
- def writeNext(writer: DiskBlockObjectWriter): Unit = {
- val keyStart = getKeyStartPos(metaBuffer, pos)
- val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN)
- pos += RECORD_SIZE
- kvBuffer.read(keyStart, writer, keyValLen)
- writer.recordWritten()
- }
- def nextPartition(): Int = metaBuffer.get(pos + PARTITION)
- def hasNext(): Boolean = pos < metaBuffer.position
- }
- }
-
- // Visible for testing
- def orderedInputStream: OrderedInputStream = {
- new OrderedInputStream(metaBuffer, kvBuffer)
- }
-
- private def sort(keyComparator: Option[Comparator[K]]): Unit = {
- val comparator = if (keyComparator.isEmpty) {
- new Comparator[Int]() {
- def compare(partition1: Int, partition2: Int): Int = {
- partition1 - partition2
- }
- }
- } else {
- throw new UnsupportedOperationException()
- }
-
- val sorter = new Sorter(new SerializedSortDataFormat)
- sorter.sort(metaBuffer, 0, metaBuffer.position / RECORD_SIZE, comparator)
- }
-}
-
-private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer)
- extends InputStream {
-
- import PartitionedSerializedPairBuffer._
-
- private var metaBufferPos = 0
- private var kvBufferPos =
- if (metaBuffer.position > 0) getKeyStartPos(metaBuffer, metaBufferPos) else 0
-
- override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length)
-
- override def read(bytes: Array[Byte], offs: Int, len: Int): Int = {
- if (metaBufferPos >= metaBuffer.position) {
- return -1
- }
- val bytesRemainingInRecord = (metaBuffer.get(metaBufferPos + KEY_VAL_LEN) -
- (kvBufferPos - getKeyStartPos(metaBuffer, metaBufferPos))).toInt
- val toRead = math.min(bytesRemainingInRecord, len)
- kvBuffer.read(kvBufferPos, bytes, offs, toRead)
- if (toRead == bytesRemainingInRecord) {
- metaBufferPos += RECORD_SIZE
- if (metaBufferPos < metaBuffer.position) {
- kvBufferPos = getKeyStartPos(metaBuffer, metaBufferPos)
- }
- } else {
- kvBufferPos += toRead
- }
- toRead
- }
-
- override def read(): Int = {
- throw new UnsupportedOperationException()
- }
-}
-
-private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuffer] {
-
- private val META_BUFFER_TMP = new Array[Int](RECORD_SIZE)
-
- /** Return the sort key for the element at the given index. */
- override protected def getKey(metaBuffer: IntBuffer, pos: Int): Int = {
- metaBuffer.get(pos * RECORD_SIZE + PARTITION)
- }
-
- /** Swap two elements. */
- override def swap(metaBuffer: IntBuffer, pos0: Int, pos1: Int): Unit = {
- val iOff = pos0 * RECORD_SIZE
- val jOff = pos1 * RECORD_SIZE
- System.arraycopy(metaBuffer.array, iOff, META_BUFFER_TMP, 0, RECORD_SIZE)
- System.arraycopy(metaBuffer.array, jOff, metaBuffer.array, iOff, RECORD_SIZE)
- System.arraycopy(META_BUFFER_TMP, 0, metaBuffer.array, jOff, RECORD_SIZE)
- }
-
- /** Copy a single element from src(srcPos) to dst(dstPos). */
- override def copyElement(
- src: IntBuffer,
- srcPos: Int,
- dst: IntBuffer,
- dstPos: Int): Unit = {
- val srcOff = srcPos * RECORD_SIZE
- val dstOff = dstPos * RECORD_SIZE
- System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE)
- }
-
- /**
- * Copy a range of elements starting at src(srcPos) to dst, starting at dstPos.
- * Overlapping ranges are allowed.
- */
- override def copyRange(
- src: IntBuffer,
- srcPos: Int,
- dst: IntBuffer,
- dstPos: Int,
- length: Int): Unit = {
- val srcOff = srcPos * RECORD_SIZE
- val dstOff = dstPos * RECORD_SIZE
- System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE * length)
- }
-
- /**
- * Allocates a Buffer that can hold up to 'length' elements.
- * All elements of the buffer should be considered invalid until data is explicitly copied in.
- */
- override def allocate(length: Int): IntBuffer = {
- IntBuffer.allocate(length * RECORD_SIZE)
- }
-}
-
-private object PartitionedSerializedPairBuffer {
- val KEY_START = 0 // keyStart, a long, gets split across two ints
- val KEY_VAL_LEN = 2
- val PARTITION = 3
- val RECORD_SIZE = PARTITION + 1 // num ints of metadata
-
- val MAXIMUM_RECORDS = Int.MaxValue / RECORD_SIZE // (2 ^ 29) - 1
- val MAXIMUM_META_BUFFER_CAPACITY = MAXIMUM_RECORDS * RECORD_SIZE // (2 ^ 31) - 4
-
- def getKeyStartPos(metaBuffer: IntBuffer, metaBufferPos: Int): Long = {
- val lower32 = metaBuffer.get(metaBufferPos + KEY_START)
- val upper32 = metaBuffer.get(metaBufferPos + KEY_START + 1)
- (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL)
- }
-}
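The keyStart bookkeeping in this removed class (a long split across two ints in the metadata buffer and reassembled by getKeyStartPos) reduces to the following round trip, shown here as a self-contained Java sketch with a hypothetical offset value:

    // Mirrors insert() writing keyStart as two ints and getKeyStartPos() reading it back.
    long keyStart = (7L << 32) | 123L;                 // hypothetical offset into the kv buffer
    int lower32 = (int) keyStart;                      // metaBuffer.put(keyStart.toInt)
    int upper32 = (int) (keyStart >> 32);              // metaBuffer.put((keyStart >> 32).toInt)
    long roundTripped = ((long) upper32 << 32) | (lower32 & 0xFFFFFFFFL);
    assert roundTripped == keyStart;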
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
new file mode 100644
index 0000000..232ae4d
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import org.apache.spark.shuffle.sort.PackedRecordPointer;
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import static org.apache.spark.shuffle.sort.PackedRecordPointer.*;
+
+public class PackedRecordPointerSuite {
+
+ @Test
+ public void heap() {
+ final TaskMemoryManager memoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ final MemoryBlock page0 = memoryManager.allocatePage(128);
+ final MemoryBlock page1 = memoryManager.allocatePage(128);
+ final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
+ page1.getBaseOffset() + 42);
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
+ assertEquals(360, packedPointer.getPartitionId());
+ final long recordPointer = packedPointer.getRecordPointer();
+ assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
+ assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
+ assertEquals(addressInPage1, recordPointer);
+ memoryManager.cleanUpAllAllocatedMemory();
+ }
+
+ @Test
+ public void offHeap() {
+ final TaskMemoryManager memoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
+ final MemoryBlock page0 = memoryManager.allocatePage(128);
+ final MemoryBlock page1 = memoryManager.allocatePage(128);
+ final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
+ page1.getBaseOffset() + 42);
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
+ assertEquals(360, packedPointer.getPartitionId());
+ final long recordPointer = packedPointer.getRecordPointer();
+ assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
+ assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
+ assertEquals(addressInPage1, recordPointer);
+ memoryManager.cleanUpAllAllocatedMemory();
+ }
+
+ @Test
+ public void maximumPartitionIdCanBeEncoded() {
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID));
+ assertEquals(MAXIMUM_PARTITION_ID, packedPointer.getPartitionId());
+ }
+
+ @Test
+ public void partitionIdsGreaterThanMaximumPartitionIdWillOverflowOrTriggerError() {
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ try {
+ // Partition ids greater than the maximum partition ID will overflow or trigger an assertion error
+ packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1));
+ assertFalse(MAXIMUM_PARTITION_ID + 1 == packedPointer.getPartitionId());
+ } catch (AssertionError e) {
+ // pass
+ }
+ }
+
+ @Test
+ public void maximumOffsetInPageCanBeEncoded() {
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES - 1);
+ packedPointer.set(PackedRecordPointer.packPointer(address, 0));
+ assertEquals(address, packedPointer.getRecordPointer());
+ }
+
+ @Test
+ public void offsetsPastMaxOffsetInPageWillOverflow() {
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES);
+ packedPointer.set(PackedRecordPointer.packPointer(address, 0));
+ assertEquals(0, packedPointer.getRecordPointer());
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
new file mode 100644
index 0000000..1ef3c5f
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.util.Arrays;
+import java.util.Random;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+public class ShuffleInMemorySorterSuite {
+
+ private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) {
+ final byte[] strBytes = new byte[strLength];
+ Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength);
+ return new String(strBytes);
+ }
+
+ @Test
+ public void testSortingEmptyInput() {
+ final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(100);
+ final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
+ assert(!iter.hasNext());
+ }
+
+ @Test
+ public void testBasicSorting() throws Exception {
+ final String[] dataToSort = new String[] {
+ "Boba",
+ "Pearls",
+ "Tapioca",
+ "Taho",
+ "Condensed Milk",
+ "Jasmine",
+ "Milk Tea",
+ "Lychee",
+ "Mango"
+ };
+ final TaskMemoryManager memoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+ final Object baseObject = dataPage.getBaseObject();
+ final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
+ final HashPartitioner hashPartitioner = new HashPartitioner(4);
+
+ // Write the records into the data page and store pointers into the sorter
+ long position = dataPage.getBaseOffset();
+ for (String str : dataToSort) {
+ final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position);
+ final byte[] strBytes = str.getBytes("utf-8");
+ Platform.putInt(baseObject, position, strBytes.length);
+ position += 4;
+ Platform.copyMemory(
+ strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length);
+ position += strBytes.length;
+ sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str));
+ }
+
+ // Sort the records
+ final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
+ int prevPartitionId = -1;
+ Arrays.sort(dataToSort);
+ for (int i = 0; i < dataToSort.length; i++) {
+ Assert.assertTrue(iter.hasNext());
+ iter.loadNext();
+ final int partitionId = iter.packedRecordPointer.getPartitionId();
+ Assert.assertTrue(partitionId >= 0 && partitionId <= 3);
+ Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId,
+ partitionId >= prevPartitionId);
+ final long recordAddress = iter.packedRecordPointer.getRecordPointer();
+ final int recordLength = Platform.getInt(
+ memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress));
+ final String str = getStringFromDataPage(
+ memoryManager.getPage(recordAddress),
+ memoryManager.getOffsetInPage(recordAddress) + 4, // skip over record length
+ recordLength);
+ Assert.assertTrue(Arrays.binarySearch(dataToSort, str) != -1);
+ }
+ Assert.assertFalse(iter.hasNext());
+ }
+
+ @Test
+ public void testSortingManyNumbers() throws Exception {
+ ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
+ int[] numbersToSort = new int[128000];
+ Random random = new Random(16);
+ for (int i = 0; i < numbersToSort.length; i++) {
+ numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1);
+ sorter.insertRecord(0, numbersToSort[i]);
+ }
+ Arrays.sort(numbersToSort);
+ int[] sorterResult = new int[numbersToSort.length];
+ ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
+ int j = 0;
+ while (iter.hasNext()) {
+ iter.loadNext();
+ sorterResult[j] = iter.packedRecordPointer.getPartitionId();
+ j += 1;
+ }
+ Assert.assertArrayEquals(numbersToSort, sorterResult);
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
new file mode 100644
index 0000000..29d9823
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -0,0 +1,560 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.*;
+import java.nio.ByteBuffer;
+import java.util.*;
+
+import scala.*;
+import scala.collection.Iterator;
+import scala.runtime.AbstractFunction1;
+
+import com.google.common.collect.Iterators;
+import com.google.common.collect.HashMultiset;
+import com.google.common.io.ByteStreams;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.lessThan;
+import static org.junit.Assert.*;
+import static org.mockito.AdditionalAnswers.returnsFirstArg;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.*;
+import org.apache.spark.io.CompressionCodec$;
+import org.apache.spark.io.LZ4CompressionCodec;
+import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.io.SnappyCompressionCodec;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.serializer.*;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.shuffle.sort.SerializedShuffleHandle;
+import org.apache.spark.storage.*;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+public class UnsafeShuffleWriterSuite {
+
+ static final int NUM_PARTITITONS = 4;
+ final TaskMemoryManager taskMemoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS);
+ File mergedOutputFile;
+ File tempDir;
+ long[] partitionSizesInMergedFile;
+ final LinkedList<File> spillFilesCreated = new LinkedList<File>();
+ SparkConf conf;
+ final Serializer serializer = new KryoSerializer(new SparkConf());
+ TaskMetrics taskMetrics;
+
+ @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
+ @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
+ @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver;
+ @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
+ @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
+ @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;
+
+ private final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+ @Override
+ public OutputStream apply(OutputStream stream) {
+ if (conf.getBoolean("spark.shuffle.compress", true)) {
+ return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream);
+ } else {
+ return stream;
+ }
+ }
+ }
+
+ @After
+ public void tearDown() {
+ Utils.deleteRecursively(tempDir);
+ final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
+ if (leakedMemory != 0) {
+ fail("Test leaked " + leakedMemory + " bytes of managed memory");
+ }
+ }
+
+ @Before
+ @SuppressWarnings("unchecked")
+ public void setUp() throws IOException {
+ MockitoAnnotations.initMocks(this);
+ tempDir = Utils.createTempDir("test", "test");
+ mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
+ partitionSizesInMergedFile = null;
+ spillFilesCreated.clear();
+ conf = new SparkConf().set("spark.buffer.pageSize", "128m");
+ taskMetrics = new TaskMetrics();
+
+ when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
+ when(shuffleMemoryManager.pageSizeBytes()).thenReturn(128L * 1024 * 1024);
+
+ when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
+ when(blockManager.getDiskWriter(
+ any(BlockId.class),
+ any(File.class),
+ any(SerializerInstance.class),
+ anyInt(),
+ any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
+ @Override
+ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
+ Object[] args = invocationOnMock.getArguments();
+
+ return new DiskBlockObjectWriter(
+ (File) args[1],
+ (SerializerInstance) args[2],
+ (Integer) args[3],
+ new CompressStream(),
+ false,
+ (ShuffleWriteMetrics) args[4]
+ );
+ }
+ });
+ when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer(
+ new Answer<InputStream>() {
+ @Override
+ public InputStream answer(InvocationOnMock invocation) throws Throwable {
+ assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
+ InputStream is = (InputStream) invocation.getArguments()[1];
+ if (conf.getBoolean("spark.shuffle.compress", true)) {
+ return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is);
+ } else {
+ return is;
+ }
+ }
+ }
+ );
+
+ when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer(
+ new Answer<OutputStream>() {
+ @Override
+ public OutputStream answer(InvocationOnMock invocation) throws Throwable {
+ assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
+ OutputStream os = (OutputStream) invocation.getArguments()[1];
+ if (conf.getBoolean("spark.shuffle.compress", true)) {
+ return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os);
+ } else {
+ return os;
+ }
+ }
+ }
+ );
+
+ when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
+ doAnswer(new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+ partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
+ return null;
+ }
+ }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class));
+
+ when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
+ new Answer<Tuple2<TempShuffleBlockId, File>>() {
+ @Override
+ public Tuple2<TempShuffleBlockId, File> answer(
+ InvocationOnMock invocationOnMock) throws Throwable {
+ TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
+ File file = File.createTempFile("spillFile", ".spill", tempDir);
+ spillFilesCreated.add(file);
+ return Tuple2$.MODULE$.apply(blockId, file);
+ }
+ });
+
+ when(taskContext.taskMetrics()).thenReturn(taskMetrics);
+ when(taskContext.internalMetricsToAccumulators()).thenReturn(null);
+
+ when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
+ when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
+ }
+
+ private UnsafeShuffleWriter<Object, Object> createWriter(
+ boolean transferToEnabled) throws IOException {
+ conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
+ return new UnsafeShuffleWriter<Object, Object>(
+ blockManager,
+ shuffleBlockResolver,
+ taskMemoryManager,
+ shuffleMemoryManager,
+ new SerializedShuffleHandle<Object, Object>(0, 1, shuffleDep),
+ 0, // map id
+ taskContext,
+ conf
+ );
+ }
+
+ private void assertSpillFilesWereCleanedUp() {
+ for (File spillFile : spillFilesCreated) {
+ assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
+ spillFile.exists());
+ }
+ }
+
+ private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
+ final ArrayList<Tuple2<Object, Object>> recordsList = new ArrayList<Tuple2<Object, Object>>();
+ long startOffset = 0;
+ for (int i = 0; i < NUM_PARTITITONS; i++) {
+ final long partitionSize = partitionSizesInMergedFile[i];
+ if (partitionSize > 0) {
+ InputStream in = new FileInputStream(mergedOutputFile);
+ ByteStreams.skipFully(in, startOffset);
+ in = new LimitedInputStream(in, partitionSize);
+ if (conf.getBoolean("spark.shuffle.compress", true)) {
+ in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
+ }
+ DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in);
+ Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator();
+ while (records.hasNext()) {
+ Tuple2<Object, Object> record = records.next();
+ assertEquals(i, hashPartitioner.getPartition(record._1()));
+ recordsList.add(record);
+ }
+ recordsStream.close();
+ startOffset += partitionSize;
+ }
+ }
+ return recordsList;
+ }
+
+ @Test(expected=IllegalStateException.class)
+ public void mustCallWriteBeforeSuccessfulStop() throws IOException {
+ createWriter(false).stop(true);
+ }
+
+ @Test
+ public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException {
+ createWriter(false).stop(false);
+ }
+
+ class PandaException extends RuntimeException {
+ }
+
+ @Test(expected=PandaException.class)
+ public void writeFailurePropagates() throws Exception {
+ class BadRecords extends scala.collection.AbstractIterator<Product2<Object, Object>> {
+ @Override public boolean hasNext() {
+ throw new PandaException();
+ }
+ @Override public Product2<Object, Object> next() {
+ return null;
+ }
+ }
+ final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+ writer.write(new BadRecords());
+ }
+
+ @Test
+ public void writeEmptyIterator() throws Exception {
+ final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+ writer.write(Iterators.<Product2<Object, Object>>emptyIterator());
+ final Option<MapStatus> mapStatus = writer.stop(true);
+ assertTrue(mapStatus.isDefined());
+ assertTrue(mergedOutputFile.exists());
+ assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile);
+ assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleRecordsWritten());
+ assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleBytesWritten());
+ assertEquals(0, taskMetrics.diskBytesSpilled());
+ assertEquals(0, taskMetrics.memoryBytesSpilled());
+ }
+
+ @Test
+ public void writeWithoutSpilling() throws Exception {
+ // In this example, each partition should have exactly one record:
+ final ArrayList<Product2<Object, Object>> dataToWrite =
+ new ArrayList<Product2<Object, Object>>();
+ for (int i = 0; i < NUM_PARTITITONS; i++) {
+ dataToWrite.add(new Tuple2<Object, Object>(i, i));
+ }
+ final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+ writer.write(dataToWrite.iterator());
+ final Option<MapStatus> mapStatus = writer.stop(true);
+ assertTrue(mapStatus.isDefined());
+ assertTrue(mergedOutputFile.exists());
+
+ long sumOfPartitionSizes = 0;
+ for (long size: partitionSizesInMergedFile) {
+ // All partitions should be the same size:
+ assertEquals(partitionSizesInMergedFile[0], size);
+ sumOfPartitionSizes += size;
+ }
+ assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
+ assertEquals(
+ HashMultiset.create(dataToWrite),
+ HashMultiset.create(readRecordsFromFile()));
+ assertSpillFilesWereCleanedUp();
+ ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+ assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+ assertEquals(0, taskMetrics.diskBytesSpilled());
+ assertEquals(0, taskMetrics.memoryBytesSpilled());
+ assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+ }
+
+ private void testMergingSpills(
+ boolean transferToEnabled,
+ String compressionCodecName) throws IOException {
+ if (compressionCodecName != null) {
+ conf.set("spark.shuffle.compress", "true");
+ conf.set("spark.io.compression.codec", compressionCodecName);
+ } else {
+ conf.set("spark.shuffle.compress", "false");
+ }
+ final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
+ final ArrayList<Product2<Object, Object>> dataToWrite =
+ new ArrayList<Product2<Object, Object>>();
+ for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
+ dataToWrite.add(new Tuple2<Object, Object>(i, i));
+ }
+ writer.insertRecordIntoSorter(dataToWrite.get(0));
+ writer.insertRecordIntoSorter(dataToWrite.get(1));
+ writer.insertRecordIntoSorter(dataToWrite.get(2));
+ writer.insertRecordIntoSorter(dataToWrite.get(3));
+ writer.forceSorterToSpill();
+ writer.insertRecordIntoSorter(dataToWrite.get(4));
+ writer.insertRecordIntoSorter(dataToWrite.get(5));
+ writer.closeAndWriteOutput();
+ final Option<MapStatus> mapStatus = writer.stop(true);
+ assertTrue(mapStatus.isDefined());
+ assertTrue(mergedOutputFile.exists());
+ assertEquals(2, spillFilesCreated.size());
+
+ long sumOfPartitionSizes = 0;
+ for (long size: partitionSizesInMergedFile) {
+ sumOfPartitionSizes += size;
+ }
+ assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
+
+ assertEquals(
+ HashMultiset.create(dataToWrite),
+ HashMultiset.create(readRecordsFromFile()));
+ assertSpillFilesWereCleanedUp();
+ ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+ assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+ assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+ assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+ assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+ assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+ }
+
+ @Test
+ public void mergeSpillsWithTransferToAndLZF() throws Exception {
+ testMergingSpills(true, LZFCompressionCodec.class.getName());
+ }
+
+ @Test
+ public void mergeSpillsWithFileStreamAndLZF() throws Exception {
+ testMergingSpills(false, LZFCompressionCodec.class.getName());
+ }
+
+ @Test
+ public void mergeSpillsWithTransferToAndLZ4() throws Exception {
+ testMergingSpills(true, LZ4CompressionCodec.class.getName());
+ }
+
+ @Test
+ public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
+ testMergingSpills(false, LZ4CompressionCodec.class.getName());
+ }
+
+ @Test
+ public void mergeSpillsWithTransferToAndSnappy() throws Exception {
+ testMergingSpills(true, SnappyCompressionCodec.class.getName());
+ }
+
+ @Test
+ public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
+ testMergingSpills(false, SnappyCompressionCodec.class.getName());
+ }
+
+ @Test
+ public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
+ testMergingSpills(true, null);
+ }
+
+ @Test
+ public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
+ testMergingSpills(false, null);
+ }
+
+ @Test
+ public void writeEnoughDataToTriggerSpill() throws Exception {
+ when(shuffleMemoryManager.tryToAcquire(anyLong()))
+ .then(returnsFirstArg()) // Allocate initial sort buffer
+ .then(returnsFirstArg()) // Allocate initial data page
+ .thenReturn(0L) // Deny request to allocate new data page
+ .then(returnsFirstArg()); // Grant new sort buffer and data page.
+ final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+ final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+ final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128];
+ for (int i = 0; i < 128 + 1; i++) {
+ dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray));
+ }
+ writer.write(dataToWrite.iterator());
+ verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+ assertEquals(2, spillFilesCreated.size());
+ writer.stop(true);
+ readRecordsFromFile();
+ assertSpillFilesWereCleanedUp();
+ ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+ assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+ assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+ assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+ assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+ assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+ }
+
+ @Test
+ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
+ when(shuffleMemoryManager.tryToAcquire(anyLong()))
+ .then(returnsFirstArg()) // Allocate initial sort buffer
+ .then(returnsFirstArg()) // Allocate initial data page
+ .thenReturn(0L) // Deny request to grow sort buffer
+ .then(returnsFirstArg()); // Grant new sort buffer and data page.
+ final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+ final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+ for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) {
+ dataToWrite.add(new Tuple2<Object, Object>(i, i));
+ }
+ writer.write(dataToWrite.iterator());
+ verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+ assertEquals(2, spillFilesCreated.size());
+ writer.stop(true);
+ readRecordsFromFile();
+ assertSpillFilesWereCleanedUp();
+ ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+ assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+ assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+ assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+ assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+ assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+ }
+
+ @Test
+ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception {
+ final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+ final ArrayList<Product2<Object, Object>> dataToWrite =
+ new ArrayList<Product2<Object, Object>>();
+ final byte[] bytes = new byte[(int) (ShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
+ new Random(42).nextBytes(bytes);
+ dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(bytes)));
+ writer.write(dataToWrite.iterator());
+ writer.stop(true);
+ assertEquals(
+ HashMultiset.create(dataToWrite),
+ HashMultiset.create(readRecordsFromFile()));
+ assertSpillFilesWereCleanedUp();
+ }
+
+ @Test
+ public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception {
+ final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+ final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+ dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(new byte[1])));
+ // We should be able to write a record that's right _at_ the max record size
+ final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()];
+ new Random(42).nextBytes(atMaxRecordSize);
+ dataToWrite.add(new Tuple2<Object, Object>(2, ByteBuffer.wrap(atMaxRecordSize)));
+ // We should also be able to write a record that's larger than the max record size:
+ final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1];
+ new Random(42).nextBytes(exceedsMaxRecordSize);
+ dataToWrite.add(new Tuple2<Object, Object>(3, ByteBuffer.wrap(exceedsMaxRecordSize)));
+ writer.write(dataToWrite.iterator());
+ writer.stop(true);
+ assertEquals(
+ HashMultiset.create(dataToWrite),
+ HashMultiset.create(readRecordsFromFile()));
+ assertSpillFilesWereCleanedUp();
+ }
+
+ @Test
+ public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
+ final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+ writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+ writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+ writer.forceSorterToSpill();
+ writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+ writer.stop(false);
+ assertSpillFilesWereCleanedUp();
+ }
+
+ @Test
+ public void testPeakMemoryUsed() throws Exception {
+ final long recordLengthBytes = 8;
+ final long pageSizeBytes = 256;
+ final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
+ when(shuffleMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes);
+ final UnsafeShuffleWriter<Object, Object> writer =
+ new UnsafeShuffleWriter<Object, Object>(
+ blockManager,
+ shuffleBlockResolver,
+ taskMemoryManager,
+ shuffleMemoryManager,
+ new SerializedShuffleHandle<>(0, 1, shuffleDep),
+ 0, // map id
+ taskContext,
+ conf);
+
+ // Peak memory should be monotonically increasing. More specifically, every time
+ // we allocate a new page it should increase by exactly the size of the page.
+ long previousPeakMemory = writer.getPeakMemoryUsedBytes();
+ long newPeakMemory;
+ try {
+ for (int i = 0; i < numRecordsPerPage * 10; i++) {
+ writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+ newPeakMemory = writer.getPeakMemoryUsedBytes();
+ if (i % numRecordsPerPage == 0 && i != 0) {
+ // The first page is allocated in constructor, another page will be allocated after
+ // every numRecordsPerPage records (peak memory should change).
+ assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
+ } else {
+ assertEquals(previousPeakMemory, newPeakMemory);
+ }
+ previousPeakMemory = newPeakMemory;
+ }
+
+ // Spilling should not change peak memory
+ writer.forceSorterToSpill();
+ newPeakMemory = writer.getPeakMemoryUsedBytes();
+ assertEquals(previousPeakMemory, newPeakMemory);
+ for (int i = 0; i < numRecordsPerPage; i++) {
+ writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+ }
+ newPeakMemory = writer.getPeakMemoryUsedBytes();
+ assertEquals(previousPeakMemory, newPeakMemory);
+
+ // Closing the writer should not change peak memory
+ writer.closeAndWriteOutput();
+ newPeakMemory = writer.getPeakMemoryUsedBytes();
+ assertEquals(previousPeakMemory, newPeakMemory);
+ } finally {
+ writer.stop(false);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
deleted file mode 100644
index 934b7e0..0000000
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-import org.junit.Test;
-import static org.junit.Assert.*;
-
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
-import static org.apache.spark.shuffle.unsafe.PackedRecordPointer.*;
-
-public class PackedRecordPointerSuite {
-
- @Test
- public void heap() {
- final TaskMemoryManager memoryManager =
- new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
- final MemoryBlock page0 = memoryManager.allocatePage(128);
- final MemoryBlock page1 = memoryManager.allocatePage(128);
- final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
- page1.getBaseOffset() + 42);
- PackedRecordPointer packedPointer = new PackedRecordPointer();
- packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
- assertEquals(360, packedPointer.getPartitionId());
- final long recordPointer = packedPointer.getRecordPointer();
- assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
- assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
- assertEquals(addressInPage1, recordPointer);
- memoryManager.cleanUpAllAllocatedMemory();
- }
-
- @Test
- public void offHeap() {
- final TaskMemoryManager memoryManager =
- new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
- final MemoryBlock page0 = memoryManager.allocatePage(128);
- final MemoryBlock page1 = memoryManager.allocatePage(128);
- final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
- page1.getBaseOffset() + 42);
- PackedRecordPointer packedPointer = new PackedRecordPointer();
- packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
- assertEquals(360, packedPointer.getPartitionId());
- final long recordPointer = packedPointer.getRecordPointer();
- assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
- assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
- assertEquals(addressInPage1, recordPointer);
- memoryManager.cleanUpAllAllocatedMemory();
- }
-
- @Test
- public void maximumPartitionIdCanBeEncoded() {
- PackedRecordPointer packedPointer = new PackedRecordPointer();
- packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID));
- assertEquals(MAXIMUM_PARTITION_ID, packedPointer.getPartitionId());
- }
-
- @Test
- public void partitionIdsGreaterThanMaximumPartitionIdWillOverflowOrTriggerError() {
- PackedRecordPointer packedPointer = new PackedRecordPointer();
- try {
- // Pointers greater than the maximum partition ID will overflow or trigger an assertion error
- packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1));
- assertFalse(MAXIMUM_PARTITION_ID + 1 == packedPointer.getPartitionId());
- } catch (AssertionError e ) {
- // pass
- }
- }
-
- @Test
- public void maximumOffsetInPageCanBeEncoded() {
- PackedRecordPointer packedPointer = new PackedRecordPointer();
- long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES - 1);
- packedPointer.set(PackedRecordPointer.packPointer(address, 0));
- assertEquals(address, packedPointer.getRecordPointer());
- }
-
- @Test
- public void offsetsPastMaxOffsetInPageWillOverflow() {
- PackedRecordPointer packedPointer = new PackedRecordPointer();
- long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES);
- packedPointer.set(PackedRecordPointer.packPointer(address, 0));
- assertEquals(0, packedPointer.getRecordPointer());
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
deleted file mode 100644
index 40fefe2..0000000
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
+++ /dev/null
@@ -1,124 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-import java.util.Arrays;
-import java.util.Random;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-import org.apache.spark.HashPartitioner;
-import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
-
-public class UnsafeShuffleInMemorySorterSuite {
-
- private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) {
- final byte[] strBytes = new byte[strLength];
- Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength);
- return new String(strBytes);
- }
-
- @Test
- public void testSortingEmptyInput() {
- final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(100);
- final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
- assert(!iter.hasNext());
- }
-
- @Test
- public void testBasicSorting() throws Exception {
- final String[] dataToSort = new String[] {
- "Boba",
- "Pearls",
- "Tapioca",
- "Taho",
- "Condensed Milk",
- "Jasmine",
- "Milk Tea",
- "Lychee",
- "Mango"
- };
- final TaskMemoryManager memoryManager =
- new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
- final MemoryBlock dataPage = memoryManager.allocatePage(2048);
- final Object baseObject = dataPage.getBaseObject();
- final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
- final HashPartitioner hashPartitioner = new HashPartitioner(4);
-
- // Write the records into the data page and store pointers into the sorter
- long position = dataPage.getBaseOffset();
- for (String str : dataToSort) {
- final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position);
- final byte[] strBytes = str.getBytes("utf-8");
- Platform.putInt(baseObject, position, strBytes.length);
- position += 4;
- Platform.copyMemory(
- strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length);
- position += strBytes.length;
- sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str));
- }
-
- // Sort the records
- final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
- int prevPartitionId = -1;
- Arrays.sort(dataToSort);
- for (int i = 0; i < dataToSort.length; i++) {
- Assert.assertTrue(iter.hasNext());
- iter.loadNext();
- final int partitionId = iter.packedRecordPointer.getPartitionId();
- Assert.assertTrue(partitionId >= 0 && partitionId <= 3);
- Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId,
- partitionId >= prevPartitionId);
- final long recordAddress = iter.packedRecordPointer.getRecordPointer();
- final int recordLength = Platform.getInt(
- memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress));
- final String str = getStringFromDataPage(
- memoryManager.getPage(recordAddress),
- memoryManager.getOffsetInPage(recordAddress) + 4, // skip over record length
- recordLength);
- Assert.assertTrue(Arrays.binarySearch(dataToSort, str) != -1);
- }
- Assert.assertFalse(iter.hasNext());
- }
-
- @Test
- public void testSortingManyNumbers() throws Exception {
- UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
- int[] numbersToSort = new int[128000];
- Random random = new Random(16);
- for (int i = 0; i < numbersToSort.length; i++) {
- numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1);
- sorter.insertRecord(0, numbersToSort[i]);
- }
- Arrays.sort(numbersToSort);
- int[] sorterResult = new int[numbersToSort.length];
- UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
- int j = 0;
- while (iter.hasNext()) {
- iter.loadNext();
- sorterResult[j] = iter.packedRecordPointer.getPartitionId();
- j += 1;
- }
- Assert.assertArrayEquals(numbersToSort, sorterResult);
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
deleted file mode 100644
index d218344..0000000
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ /dev/null
@@ -1,560 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-import java.io.*;
-import java.nio.ByteBuffer;
-import java.util.*;
-
-import scala.*;
-import scala.collection.Iterator;
-import scala.reflect.ClassTag;
-import scala.runtime.AbstractFunction1;
-
-import com.google.common.collect.Iterators;
-import com.google.common.collect.HashMultiset;
-import com.google.common.io.ByteStreams;
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.greaterThan;
-import static org.hamcrest.Matchers.lessThan;
-import static org.junit.Assert.*;
-import static org.mockito.AdditionalAnswers.returnsFirstArg;
-import static org.mockito.Answers.RETURNS_SMART_NULLS;
-import static org.mockito.Mockito.*;
-
-import org.apache.spark.*;
-import org.apache.spark.io.CompressionCodec$;
-import org.apache.spark.io.LZ4CompressionCodec;
-import org.apache.spark.io.LZFCompressionCodec;
-import org.apache.spark.io.SnappyCompressionCodec;
-import org.apache.spark.executor.ShuffleWriteMetrics;
-import org.apache.spark.executor.TaskMetrics;
-import org.apache.spark.network.util.LimitedInputStream;
-import org.apache.spark.serializer.*;
-import org.apache.spark.scheduler.MapStatus;
-import org.apache.spark.shuffle.IndexShuffleBlockResolver;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
-import org.apache.spark.storage.*;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
-import org.apache.spark.util.Utils;
-
-public class UnsafeShuffleWriterSuite {
-
- static final int NUM_PARTITITONS = 4;
- final TaskMemoryManager taskMemoryManager =
- new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
- final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS);
- File mergedOutputFile;
- File tempDir;
- long[] partitionSizesInMergedFile;
- final LinkedList<File> spillFilesCreated = new LinkedList<File>();
- SparkConf conf;
- final Serializer serializer = new KryoSerializer(new SparkConf());
- TaskMetrics taskMetrics;
-
- @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
- @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
- @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver;
- @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
- @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
- @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;
-
- private final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
- @Override
- public OutputStream apply(OutputStream stream) {
- if (conf.getBoolean("spark.shuffle.compress", true)) {
- return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream);
- } else {
- return stream;
- }
- }
- }
-
- @After
- public void tearDown() {
- Utils.deleteRecursively(tempDir);
- final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
- if (leakedMemory != 0) {
- fail("Test leaked " + leakedMemory + " bytes of managed memory");
- }
- }
-
- @Before
- @SuppressWarnings("unchecked")
- public void setUp() throws IOException {
- MockitoAnnotations.initMocks(this);
- tempDir = Utils.createTempDir("test", "test");
- mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
- partitionSizesInMergedFile = null;
- spillFilesCreated.clear();
- conf = new SparkConf().set("spark.buffer.pageSize", "128m");
- taskMetrics = new TaskMetrics();
-
- when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
- when(shuffleMemoryManager.pageSizeBytes()).thenReturn(128L * 1024 * 1024);
-
- when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
- when(blockManager.getDiskWriter(
- any(BlockId.class),
- any(File.class),
- any(SerializerInstance.class),
- anyInt(),
- any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
- @Override
- public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
- Object[] args = invocationOnMock.getArguments();
-
- return new DiskBlockObjectWriter(
- (File) args[1],
- (SerializerInstance) args[2],
- (Integer) args[3],
- new CompressStream(),
- false,
- (ShuffleWriteMetrics) args[4]
- );
- }
- });
- when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer(
- new Answer<InputStream>() {
- @Override
- public InputStream answer(InvocationOnMock invocation) throws Throwable {
- assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
- InputStream is = (InputStream) invocation.getArguments()[1];
- if (conf.getBoolean("spark.shuffle.compress", true)) {
- return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is);
- } else {
- return is;
- }
- }
- }
- );
-
- when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer(
- new Answer<OutputStream>() {
- @Override
- public OutputStream answer(InvocationOnMock invocation) throws Throwable {
- assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
- OutputStream os = (OutputStream) invocation.getArguments()[1];
- if (conf.getBoolean("spark.shuffle.compress", true)) {
- return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os);
- } else {
- return os;
- }
- }
- }
- );
-
- when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
- doAnswer(new Answer<Void>() {
- @Override
- public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
- partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
- return null;
- }
- }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class));
-
- when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
- new Answer<Tuple2<TempShuffleBlockId, File>>() {
- @Override
- public Tuple2<TempShuffleBlockId, File> answer(
- InvocationOnMock invocationOnMock) throws Throwable {
- TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
- File file = File.createTempFile("spillFile", ".spill", tempDir);
- spillFilesCreated.add(file);
- return Tuple2$.MODULE$.apply(blockId, file);
- }
- });
-
- when(taskContext.taskMetrics()).thenReturn(taskMetrics);
- when(taskContext.internalMetricsToAccumulators()).thenReturn(null);
-
- when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
- when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
- }
-
- private UnsafeShuffleWriter<Object, Object> createWriter(
- boolean transferToEnabled) throws IOException {
- conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
- return new UnsafeShuffleWriter<Object, Object>(
- blockManager,
- shuffleBlockResolver,
- taskMemoryManager,
- shuffleMemoryManager,
- new UnsafeShuffleHandle<Object, Object>(0, 1, shuffleDep),
- 0, // map id
- taskContext,
- conf
- );
- }
-
- private void assertSpillFilesWereCleanedUp() {
- for (File spillFile : spillFilesCreated) {
- assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
- spillFile.exists());
- }
- }
-
- private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
- final ArrayList<Tuple2<Object, Object>> recordsList = new ArrayList<Tuple2<Object, Object>>();
- long startOffset = 0;
- for (int i = 0; i < NUM_PARTITITONS; i++) {
- final long partitionSize = partitionSizesInMergedFile[i];
- if (partitionSize > 0) {
- InputStream in = new FileInputStream(mergedOutputFile);
- ByteStreams.skipFully(in, startOffset);
- in = new LimitedInputStream(in, partitionSize);
- if (conf.getBoolean("spark.shuffle.compress", true)) {
- in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
- }
- DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in);
- Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator();
- while (records.hasNext()) {
- Tuple2<Object, Object> record = records.next();
- assertEquals(i, hashPartitioner.getPartition(record._1()));
- recordsList.add(record);
- }
- recordsStream.close();
- startOffset += partitionSize;
- }
- }
- return recordsList;
- }
-
- @Test(expected=IllegalStateException.class)
- public void mustCallWriteBeforeSuccessfulStop() throws IOException {
- createWriter(false).stop(true);
- }
-
- @Test
- public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException {
- createWriter(false).stop(false);
- }
-
- class PandaException extends RuntimeException {
- }
-
- @Test(expected=PandaException.class)
- public void writeFailurePropagates() throws Exception {
- class BadRecords extends scala.collection.AbstractIterator<Product2<Object, Object>> {
- @Override public boolean hasNext() {
- throw new PandaException();
- }
- @Override public Product2<Object, Object> next() {
- return null;
- }
- }
- final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
- writer.write(new BadRecords());
- }
-
- @Test
- public void writeEmptyIterator() throws Exception {
- final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
- writer.write(Iterators.<Product2<Object, Object>>emptyIterator());
- final Option<MapStatus> mapStatus = writer.stop(true);
- assertTrue(mapStatus.isDefined());
- assertTrue(mergedOutputFile.exists());
- assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile);
- assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleRecordsWritten());
- assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleBytesWritten());
- assertEquals(0, taskMetrics.diskBytesSpilled());
- assertEquals(0, taskMetrics.memoryBytesSpilled());
- }
-
- @Test
- public void writeWithoutSpilling() throws Exception {
- // In this example, each partition should have exactly one record:
- final ArrayList<Product2<Object, Object>> dataToWrite =
- new ArrayList<Product2<Object, Object>>();
- for (int i = 0; i < NUM_PARTITITONS; i++) {
- dataToWrite.add(new Tuple2<Object, Object>(i, i));
- }
- final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
- writer.write(dataToWrite.iterator());
- final Option<MapStatus> mapStatus = writer.stop(true);
- assertTrue(mapStatus.isDefined());
- assertTrue(mergedOutputFile.exists());
-
- long sumOfPartitionSizes = 0;
- for (long size: partitionSizesInMergedFile) {
- // All partitions should be the same size:
- assertEquals(partitionSizesInMergedFile[0], size);
- sumOfPartitionSizes += size;
- }
- assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
- assertEquals(
- HashMultiset.create(dataToWrite),
- HashMultiset.create(readRecordsFromFile()));
- assertSpillFilesWereCleanedUp();
- ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
- assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
- assertEquals(0, taskMetrics.diskBytesSpilled());
- assertEquals(0, taskMetrics.memoryBytesSpilled());
- assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
- }
-
- private void testMergingSpills(
- boolean transferToEnabled,
- String compressionCodecName) throws IOException {
- if (compressionCodecName != null) {
- conf.set("spark.shuffle.compress", "true");
- conf.set("spark.io.compression.codec", compressionCodecName);
- } else {
- conf.set("spark.shuffle.compress", "false");
- }
- final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
- final ArrayList<Product2<Object, Object>> dataToWrite =
- new ArrayList<Product2<Object, Object>>();
- for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
- dataToWrite.add(new Tuple2<Object, Object>(i, i));
- }
- writer.insertRecordIntoSorter(dataToWrite.get(0));
- writer.insertRecordIntoSorter(dataToWrite.get(1));
- writer.insertRecordIntoSorter(dataToWrite.get(2));
- writer.insertRecordIntoSorter(dataToWrite.get(3));
- writer.forceSorterToSpill();
- writer.insertRecordIntoSorter(dataToWrite.get(4));
- writer.insertRecordIntoSorter(dataToWrite.get(5));
- writer.closeAndWriteOutput();
- final Option<MapStatus> mapStatus = writer.stop(true);
- assertTrue(mapStatus.isDefined());
- assertTrue(mergedOutputFile.exists());
- assertEquals(2, spillFilesCreated.size());
-
- long sumOfPartitionSizes = 0;
- for (long size: partitionSizesInMergedFile) {
- sumOfPartitionSizes += size;
- }
- assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
-
- assertEquals(
- HashMultiset.create(dataToWrite),
- HashMultiset.create(readRecordsFromFile()));
- assertSpillFilesWereCleanedUp();
- ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
- assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
- assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
- assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
- assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
- assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
- }
-
- @Test
- public void mergeSpillsWithTransferToAndLZF() throws Exception {
- testMergingSpills(true, LZFCompressionCodec.class.getName());
- }
-
- @Test
- public void mergeSpillsWithFileStreamAndLZF() throws Exception {
- testMergingSpills(false, LZFCompressionCodec.class.getName());
- }
-
- @Test
- public void mergeSpillsWithTransferToAndLZ4() throws Exception {
- testMergingSpills(true, LZ4CompressionCodec.class.getName());
- }
-
- @Test
- public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
- testMergingSpills(false, LZ4CompressionCodec.class.getName());
- }
-
- @Test
- public void mergeSpillsWithTransferToAndSnappy() throws Exception {
- testMergingSpills(true, SnappyCompressionCodec.class.getName());
- }
-
- @Test
- public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
- testMergingSpills(false, SnappyCompressionCodec.class.getName());
- }
-
- @Test
- public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
- testMergingSpills(true, null);
- }
-
- @Test
- public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
- testMergingSpills(false, null);
- }
-
- @Test
- public void writeEnoughDataToTriggerSpill() throws Exception {
- when(shuffleMemoryManager.tryToAcquire(anyLong()))
- .then(returnsFirstArg()) // Allocate initial sort buffer
- .then(returnsFirstArg()) // Allocate initial data page
- .thenReturn(0L) // Deny request to allocate new data page
- .then(returnsFirstArg()); // Grant new sort buffer and data page.
- final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
- final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
- final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128];
- for (int i = 0; i < 128 + 1; i++) {
- dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray));
- }
- writer.write(dataToWrite.iterator());
- verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
- assertEquals(2, spillFilesCreated.size());
- writer.stop(true);
- readRecordsFromFile();
- assertSpillFilesWereCleanedUp();
- ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
- assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
- assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
- assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
- assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
- assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
- }
-
- @Test
- public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
- when(shuffleMemoryManager.tryToAcquire(anyLong()))
- .then(returnsFirstArg()) // Allocate initial sort buffer
- .then(returnsFirstArg()) // Allocate initial data page
- .thenReturn(0L) // Deny request to grow sort buffer
- .then(returnsFirstArg()); // Grant new sort buffer and data page.
- final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
- final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
- for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) {
- dataToWrite.add(new Tuple2<Object, Object>(i, i));
- }
- writer.write(dataToWrite.iterator());
- verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
- assertEquals(2, spillFilesCreated.size());
- writer.stop(true);
- readRecordsFromFile();
- assertSpillFilesWereCleanedUp();
- ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
- assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
- assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
- assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
- assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
- assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
- }
-
- @Test
- public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception {
- final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
- final ArrayList<Product2<Object, Object>> dataToWrite =
- new ArrayList<Product2<Object, Object>>();
- final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
- new Random(42).nextBytes(bytes);
- dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(bytes)));
- writer.write(dataToWrite.iterator());
- writer.stop(true);
- assertEquals(
- HashMultiset.create(dataToWrite),
- HashMultiset.create(readRecordsFromFile()));
- assertSpillFilesWereCleanedUp();
- }
-
- @Test
- public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception {
- final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
- final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
- dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(new byte[1])));
- // We should be able to write a record that's right _at_ the max record size
- final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()];
- new Random(42).nextBytes(atMaxRecordSize);
- dataToWrite.add(new Tuple2<Object, Object>(2, ByteBuffer.wrap(atMaxRecordSize)));
- // Inserting a record that's larger than the max record size
- final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1];
- new Random(42).nextBytes(exceedsMaxRecordSize);
- dataToWrite.add(new Tuple2<Object, Object>(3, ByteBuffer.wrap(exceedsMaxRecordSize)));
- writer.write(dataToWrite.iterator());
- writer.stop(true);
- assertEquals(
- HashMultiset.create(dataToWrite),
- HashMultiset.create(readRecordsFromFile()));
- assertSpillFilesWereCleanedUp();
- }
-
- @Test
- public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
- final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
- writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
- writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
- writer.forceSorterToSpill();
- writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
- writer.stop(false);
- assertSpillFilesWereCleanedUp();
- }
-
- @Test
- public void testPeakMemoryUsed() throws Exception {
- final long recordLengthBytes = 8;
- final long pageSizeBytes = 256;
- final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
- when(shuffleMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes);
- final UnsafeShuffleWriter<Object, Object> writer =
- new UnsafeShuffleWriter<Object, Object>(
- blockManager,
- shuffleBlockResolver,
- taskMemoryManager,
- shuffleMemoryManager,
- new UnsafeShuffleHandle<>(0, 1, shuffleDep),
- 0, // map id
- taskContext,
- conf);
-
- // Peak memory should be monotonically increasing. More specifically, every time
- // we allocate a new page it should increase by exactly the size of the page.
- long previousPeakMemory = writer.getPeakMemoryUsedBytes();
- long newPeakMemory;
- try {
- for (int i = 0; i < numRecordsPerPage * 10; i++) {
- writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
- newPeakMemory = writer.getPeakMemoryUsedBytes();
- if (i % numRecordsPerPage == 0 && i != 0) {
- // The first page is allocated in constructor, another page will be allocated after
- // every numRecordsPerPage records (peak memory should change).
- assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
- } else {
- assertEquals(previousPeakMemory, newPeakMemory);
- }
- previousPeakMemory = newPeakMemory;
- }
-
- // Spilling should not change peak memory
- writer.forceSorterToSpill();
- newPeakMemory = writer.getPeakMemoryUsedBytes();
- assertEquals(previousPeakMemory, newPeakMemory);
- for (int i = 0; i < numRecordsPerPage; i++) {
- writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
- }
- newPeakMemory = writer.getPeakMemoryUsedBytes();
- assertEquals(previousPeakMemory, newPeakMemory);
-
- // Closing the writer should not change peak memory
- writer.closeAndWriteOutput();
- newPeakMemory = writer.getPeakMemoryUsedBytes();
- assertEquals(previousPeakMemory, newPeakMemory);
- } finally {
- writer.stop(false);
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
index 6335817..b8ab227 100644
--- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
@@ -17,13 +17,78 @@
package org.apache.spark
+import java.io.File
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.FileUtils
+import org.apache.commons.io.filefilter.TrueFileFilter
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.util.Utils
+
class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
// This test suite should run all tests in ShuffleSuite with sort-based shuffle.
+ private var tempDir: File = _
+
override def beforeAll() {
conf.set("spark.shuffle.manager", "sort")
}
+
+ override def beforeEach(): Unit = {
+ tempDir = Utils.createTempDir()
+ conf.set("spark.local.dir", tempDir.getAbsolutePath)
+ }
+
+ override def afterEach(): Unit = {
+ try {
+ Utils.deleteRecursively(tempDir)
+ } finally {
+ super.afterEach()
+ }
+ }
+
+ test("SortShuffleManager properly cleans up files for shuffles that use the serialized path") {
+ sc = new SparkContext("local", "test", conf)
+ // Create a shuffled RDD and verify that it actually uses the new serialized map output path
+ val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+ val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+ .setSerializer(new KryoSerializer(conf))
+ val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+ assert(SortShuffleManager.canUseSerializedShuffle(shuffleDep))
+ ensureFilesAreCleanedUp(shuffledRdd)
+ }
+
+ test("SortShuffleManager properly cleans up files for shuffles that use the deserialized path") {
+ sc = new SparkContext("local", "test", conf)
+ // Create a shuffled RDD and verify that it actually uses the old deserialized map output path
+ val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+ val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+ .setSerializer(new JavaSerializer(conf))
+ val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+ assert(!SortShuffleManager.canUseSerializedShuffle(shuffleDep))
+ ensureFilesAreCleanedUp(shuffledRdd)
+ }
+
+ private def ensureFilesAreCleanedUp(shuffledRdd: ShuffledRDD[_, _, _]): Unit = {
+ def getAllFiles: Set[File] =
+ FileUtils.listFiles(tempDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+ val filesBeforeShuffle = getAllFiles
+ // Force the shuffle to be performed
+ shuffledRdd.count()
+ // Ensure that the shuffle actually created files that will need to be cleaned up
+ val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+ assert(filesCreatedByShuffle.map(_.getName) ===
+ Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+ // Check that the cleanup actually removes the files
+ sc.env.blockManager.master.removeShuffle(0, blocking = true)
+ for (file <- filesCreatedByShuffle) {
+ assert (!file.exists(), s"Shuffle file $file was not cleaned up")
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 5b01ddb..3816b8c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -1062,10 +1062,10 @@ class DAGSchedulerSuite
*/
test("don't submit stage until its dependencies map outputs are registered (SPARK-5259)") {
val firstRDD = new MyRDD(sc, 3, Nil)
- val firstShuffleDep = new ShuffleDependency(firstRDD, null)
+ val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2))
val firstShuffleId = firstShuffleDep.shuffleId
val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep))
- val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
val reduceRdd = new MyRDD(sc, 1, List(shuffleDep))
submit(reduceRdd, Array(0))
@@ -1175,7 +1175,7 @@ class DAGSchedulerSuite
*/
test("register map outputs correctly after ExecutorLost and task Resubmitted") {
val firstRDD = new MyRDD(sc, 3, Nil)
- val firstShuffleDep = new ShuffleDependency(firstRDD, null)
+ val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2))
val reduceRdd = new MyRDD(sc, 5, List(firstShuffleDep))
submit(reduceRdd, Array(0))