You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2014/07/14 19:40:52 UTC
git commit: move some test file to match src code
Repository: spark
Updated Branches:
refs/heads/master aab534966 -> 38ccd6ebd
move some test file to match src code
Just move some test suites to their corresponding packages
Author: Daoyuan <da...@intel.com>
Closes #1401 from adrian-wang/movetestfiles and squashes the following commits:
d1a6803 [Daoyuan] move some test file to match src code
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/38ccd6eb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/38ccd6eb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/38ccd6eb
Branch: refs/heads/master
Commit: 38ccd6ebd412cfbf82ae9d8a0998ff697db11455
Parents: aab5349
Author: Daoyuan <da...@intel.com>
Authored: Mon Jul 14 10:40:44 2014 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Mon Jul 14 10:40:44 2014 -0700
----------------------------------------------------------------------
.../scala/org/apache/spark/AkkaUtilsSuite.scala | 211 -------------
.../scala/org/apache/spark/BroadcastSuite.scala | 315 -------------------
.../apache/spark/ConnectionManagerSuite.scala | 230 --------------
.../scala/org/apache/spark/PipedRDDSuite.scala | 196 ------------
.../apache/spark/ZippedPartitionsSuite.scala | 41 ---
.../apache/spark/broadcast/BroadcastSuite.scala | 313 ++++++++++++++++++
.../spark/network/ConnectionManagerSuite.scala | 229 ++++++++++++++
.../org/apache/spark/rdd/PipedRDDSuite.scala | 192 +++++++++++
.../spark/rdd/ZippedPartitionsSuite.scala | 42 +++
.../org/apache/spark/util/AkkaUtilsSuite.scala | 211 +++++++++++++
10 files changed, 987 insertions(+), 993 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/38ccd6eb/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
deleted file mode 100644
index 4ab870e..0000000
--- a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
+++ /dev/null
@@ -1,211 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import org.scalatest.FunSuite
-
-import akka.actor._
-import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util.AkkaUtils
-import scala.concurrent.Await
-
-/**
- * Test the AkkaUtils with various security settings.
- */
-class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
-
- test("remote fetch security bad password") {
- val conf = new SparkConf
- conf.set("spark.authenticate", "true")
- conf.set("spark.authenticate.secret", "good")
-
- val securityManager = new SecurityManager(conf);
- val hostname = "localhost"
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
- conf = conf, securityManager = securityManager)
- System.setProperty("spark.hostPort", hostname + ":" + boundPort)
- assert(securityManager.isAuthenticationEnabled() === true)
-
- val masterTracker = new MapOutputTrackerMaster(conf)
- masterTracker.trackerActor = actorSystem.actorOf(
- Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
-
- val badconf = new SparkConf
- badconf.set("spark.authenticate", "true")
- badconf.set("spark.authenticate.secret", "bad")
- val securityManagerBad = new SecurityManager(badconf)
-
- assert(securityManagerBad.isAuthenticationEnabled() === true)
-
- val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
- conf = conf, securityManager = securityManagerBad)
- val slaveTracker = new MapOutputTrackerWorker(conf)
- val selection = slaveSystem.actorSelection(
- s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
- val timeout = AkkaUtils.lookupTimeout(conf)
- intercept[akka.actor.ActorNotFound] {
- slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
- }
-
- actorSystem.shutdown()
- slaveSystem.shutdown()
- }
-
- test("remote fetch security off") {
- val conf = new SparkConf
- conf.set("spark.authenticate", "false")
- conf.set("spark.authenticate.secret", "bad")
- val securityManager = new SecurityManager(conf);
-
- val hostname = "localhost"
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
- conf = conf, securityManager = securityManager)
- System.setProperty("spark.hostPort", hostname + ":" + boundPort)
-
- assert(securityManager.isAuthenticationEnabled() === false)
-
- val masterTracker = new MapOutputTrackerMaster(conf)
- masterTracker.trackerActor = actorSystem.actorOf(
- Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
-
- val badconf = new SparkConf
- badconf.set("spark.authenticate", "false")
- badconf.set("spark.authenticate.secret", "good")
- val securityManagerBad = new SecurityManager(badconf);
-
- val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
- conf = badconf, securityManager = securityManagerBad)
- val slaveTracker = new MapOutputTrackerWorker(conf)
- val selection = slaveSystem.actorSelection(
- s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
- val timeout = AkkaUtils.lookupTimeout(conf)
- slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
-
- assert(securityManagerBad.isAuthenticationEnabled() === false)
-
- masterTracker.registerShuffle(10, 1)
- masterTracker.incrementEpoch()
- slaveTracker.updateEpoch(masterTracker.getEpoch)
-
- val compressedSize1000 = MapOutputTracker.compressSize(1000L)
- val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
- masterTracker.registerMapOutput(10, 0, new MapStatus(
- BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
- masterTracker.incrementEpoch()
- slaveTracker.updateEpoch(masterTracker.getEpoch)
-
- // this should succeed since security off
- assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
- Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
-
- actorSystem.shutdown()
- slaveSystem.shutdown()
- }
-
- test("remote fetch security pass") {
- val conf = new SparkConf
- conf.set("spark.authenticate", "true")
- conf.set("spark.authenticate.secret", "good")
- val securityManager = new SecurityManager(conf);
-
- val hostname = "localhost"
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
- conf = conf, securityManager = securityManager)
- System.setProperty("spark.hostPort", hostname + ":" + boundPort)
-
- assert(securityManager.isAuthenticationEnabled() === true)
-
- val masterTracker = new MapOutputTrackerMaster(conf)
- masterTracker.trackerActor = actorSystem.actorOf(
- Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
-
- val goodconf = new SparkConf
- goodconf.set("spark.authenticate", "true")
- goodconf.set("spark.authenticate.secret", "good")
- val securityManagerGood = new SecurityManager(goodconf);
-
- assert(securityManagerGood.isAuthenticationEnabled() === true)
-
- val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
- conf = goodconf, securityManager = securityManagerGood)
- val slaveTracker = new MapOutputTrackerWorker(conf)
- val selection = slaveSystem.actorSelection(
- s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
- val timeout = AkkaUtils.lookupTimeout(conf)
- slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
-
- masterTracker.registerShuffle(10, 1)
- masterTracker.incrementEpoch()
- slaveTracker.updateEpoch(masterTracker.getEpoch)
-
- val compressedSize1000 = MapOutputTracker.compressSize(1000L)
- val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
- masterTracker.registerMapOutput(10, 0, new MapStatus(
- BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
- masterTracker.incrementEpoch()
- slaveTracker.updateEpoch(masterTracker.getEpoch)
-
- // this should succeed since security on and passwords match
- assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
- Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
-
- actorSystem.shutdown()
- slaveSystem.shutdown()
- }
-
- test("remote fetch security off client") {
- val conf = new SparkConf
- conf.set("spark.authenticate", "true")
- conf.set("spark.authenticate.secret", "good")
-
- val securityManager = new SecurityManager(conf);
-
- val hostname = "localhost"
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
- conf = conf, securityManager = securityManager)
- System.setProperty("spark.hostPort", hostname + ":" + boundPort)
-
- assert(securityManager.isAuthenticationEnabled() === true)
-
- val masterTracker = new MapOutputTrackerMaster(conf)
- masterTracker.trackerActor = actorSystem.actorOf(
- Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
-
- val badconf = new SparkConf
- badconf.set("spark.authenticate", "false")
- badconf.set("spark.authenticate.secret", "bad")
- val securityManagerBad = new SecurityManager(badconf);
-
- assert(securityManagerBad.isAuthenticationEnabled() === false)
-
- val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
- conf = badconf, securityManager = securityManagerBad)
- val slaveTracker = new MapOutputTrackerWorker(conf)
- val selection = slaveSystem.actorSelection(
- s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
- val timeout = AkkaUtils.lookupTimeout(conf)
- intercept[akka.actor.ActorNotFound] {
- slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
- }
-
- actorSystem.shutdown()
- slaveSystem.shutdown()
- }
-
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/38ccd6eb/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
deleted file mode 100644
index c993625..0000000
--- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
+++ /dev/null
@@ -1,315 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import org.scalatest.FunSuite
-
-import org.apache.spark.storage._
-import org.apache.spark.broadcast.{Broadcast, HttpBroadcast}
-import org.apache.spark.storage.BroadcastBlockId
-
-class BroadcastSuite extends FunSuite with LocalSparkContext {
-
- private val httpConf = broadcastConf("HttpBroadcastFactory")
- private val torrentConf = broadcastConf("TorrentBroadcastFactory")
-
- test("Using HttpBroadcast locally") {
- sc = new SparkContext("local", "test", httpConf)
- val list = List[Int](1, 2, 3, 4)
- val broadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum))
- assert(results.collect().toSet === Set((1, 10), (2, 10)))
- }
-
- test("Accessing HttpBroadcast variables from multiple threads") {
- sc = new SparkContext("local[10]", "test", httpConf)
- val list = List[Int](1, 2, 3, 4)
- val broadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum))
- assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet)
- }
-
- test("Accessing HttpBroadcast variables in a local cluster") {
- val numSlaves = 4
- sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", httpConf)
- val list = List[Int](1, 2, 3, 4)
- val broadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
- assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
- }
-
- test("Using TorrentBroadcast locally") {
- sc = new SparkContext("local", "test", torrentConf)
- val list = List[Int](1, 2, 3, 4)
- val broadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum))
- assert(results.collect().toSet === Set((1, 10), (2, 10)))
- }
-
- test("Accessing TorrentBroadcast variables from multiple threads") {
- sc = new SparkContext("local[10]", "test", torrentConf)
- val list = List[Int](1, 2, 3, 4)
- val broadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum))
- assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet)
- }
-
- test("Accessing TorrentBroadcast variables in a local cluster") {
- val numSlaves = 4
- sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", torrentConf)
- val list = List[Int](1, 2, 3, 4)
- val broadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
- assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
- }
-
- test("Unpersisting HttpBroadcast on executors only in local mode") {
- testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false)
- }
-
- test("Unpersisting HttpBroadcast on executors and driver in local mode") {
- testUnpersistHttpBroadcast(distributed = false, removeFromDriver = true)
- }
-
- test("Unpersisting HttpBroadcast on executors only in distributed mode") {
- testUnpersistHttpBroadcast(distributed = true, removeFromDriver = false)
- }
-
- test("Unpersisting HttpBroadcast on executors and driver in distributed mode") {
- testUnpersistHttpBroadcast(distributed = true, removeFromDriver = true)
- }
-
- test("Unpersisting TorrentBroadcast on executors only in local mode") {
- testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = false)
- }
-
- test("Unpersisting TorrentBroadcast on executors and driver in local mode") {
- testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = true)
- }
-
- test("Unpersisting TorrentBroadcast on executors only in distributed mode") {
- testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = false)
- }
-
- test("Unpersisting TorrentBroadcast on executors and driver in distributed mode") {
- testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = true)
- }
- /**
- * Verify the persistence of state associated with an HttpBroadcast in either local mode or
- * local-cluster mode (when distributed = true).
- *
- * This test creates a broadcast variable, uses it on all executors, and then unpersists it.
- * In between each step, this test verifies that the broadcast blocks and the broadcast file
- * are present only on the expected nodes.
- */
- private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
- val numSlaves = if (distributed) 2 else 0
-
- def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id))
-
- // Verify that the broadcast file is created, and blocks are persisted only on the driver
- def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- assert(blockIds.size === 1)
- val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
- assert(statuses.size === 1)
- statuses.head match { case (bm, status) =>
- assert(bm.executorId === "<driver>", "Block should only be on the driver")
- assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
- assert(status.memSize > 0, "Block should be in memory store on the driver")
- assert(status.diskSize === 0, "Block should not be in disk store on the driver")
- }
- if (distributed) {
- // this file is only generated in distributed mode
- assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!")
- }
- }
-
- // Verify that blocks are persisted in both the executors and the driver
- def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- assert(blockIds.size === 1)
- val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
- assert(statuses.size === numSlaves + 1)
- statuses.foreach { case (_, status) =>
- assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
- assert(status.memSize > 0, "Block should be in memory store")
- assert(status.diskSize === 0, "Block should not be in disk store")
- }
- }
-
- // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
- // is true. In the latter case, also verify that the broadcast file is deleted on the driver.
- def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- assert(blockIds.size === 1)
- val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
- val expectedNumBlocks = if (removeFromDriver) 0 else 1
- val possiblyNot = if (removeFromDriver) "" else " not"
- assert(statuses.size === expectedNumBlocks,
- "Block should%s be unpersisted on the driver".format(possiblyNot))
- if (distributed && removeFromDriver) {
- // this file is only generated in distributed mode
- assert(!HttpBroadcast.getFile(blockIds.head.broadcastId).exists,
- "Broadcast file should%s be deleted".format(possiblyNot))
- }
- }
-
- testUnpersistBroadcast(distributed, numSlaves, httpConf, getBlockIds, afterCreation,
- afterUsingBroadcast, afterUnpersist, removeFromDriver)
- }
-
- /**
- * Verify the persistence of state associated with an TorrentBroadcast in a local-cluster.
- *
- * This test creates a broadcast variable, uses it on all executors, and then unpersists it.
- * In between each step, this test verifies that the broadcast blocks are present only on the
- * expected nodes.
- */
- private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
- val numSlaves = if (distributed) 2 else 0
-
- def getBlockIds(id: Long) = {
- val broadcastBlockId = BroadcastBlockId(id)
- val metaBlockId = BroadcastBlockId(id, "meta")
- // Assume broadcast value is small enough to fit into 1 piece
- val pieceBlockId = BroadcastBlockId(id, "piece0")
- if (distributed) {
- // the metadata and piece blocks are generated only in distributed mode
- Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId)
- } else {
- Seq[BroadcastBlockId](broadcastBlockId)
- }
- }
-
- // Verify that blocks are persisted only on the driver
- def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- blockIds.foreach { blockId =>
- val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
- assert(statuses.size === 1)
- statuses.head match { case (bm, status) =>
- assert(bm.executorId === "<driver>", "Block should only be on the driver")
- assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
- assert(status.memSize > 0, "Block should be in memory store on the driver")
- assert(status.diskSize === 0, "Block should not be in disk store on the driver")
- }
- }
- }
-
- // Verify that blocks are persisted in both the executors and the driver
- def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- blockIds.foreach { blockId =>
- val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
- if (blockId.field == "meta") {
- // Meta data is only on the driver
- assert(statuses.size === 1)
- statuses.head match { case (bm, _) => assert(bm.executorId === "<driver>") }
- } else {
- // Other blocks are on both the executors and the driver
- assert(statuses.size === numSlaves + 1,
- blockId + " has " + statuses.size + " statuses: " + statuses.mkString(","))
- statuses.foreach { case (_, status) =>
- assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
- assert(status.memSize > 0, "Block should be in memory store")
- assert(status.diskSize === 0, "Block should not be in disk store")
- }
- }
- }
- }
-
- // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
- // is true.
- def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- val expectedNumBlocks = if (removeFromDriver) 0 else 1
- val possiblyNot = if (removeFromDriver) "" else " not"
- blockIds.foreach { blockId =>
- val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
- assert(statuses.size === expectedNumBlocks,
- "Block should%s be unpersisted on the driver".format(possiblyNot))
- }
- }
-
- testUnpersistBroadcast(distributed, numSlaves, torrentConf, getBlockIds, afterCreation,
- afterUsingBroadcast, afterUnpersist, removeFromDriver)
- }
-
- /**
- * This test runs in 4 steps:
- *
- * 1) Create broadcast variable, and verify that all state is persisted on the driver.
- * 2) Use the broadcast variable on all executors, and verify that all state is persisted
- * on both the driver and the executors.
- * 3) Unpersist the broadcast, and verify that all state is removed where they should be.
- * 4) [Optional] If removeFromDriver is false, we verify that the broadcast is re-usable.
- */
- private def testUnpersistBroadcast(
- distributed: Boolean,
- numSlaves: Int, // used only when distributed = true
- broadcastConf: SparkConf,
- getBlockIds: Long => Seq[BroadcastBlockId],
- afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
- afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
- afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
- removeFromDriver: Boolean) {
-
- sc = if (distributed) {
- new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf)
- } else {
- new SparkContext("local", "test", broadcastConf)
- }
- val blockManagerMaster = sc.env.blockManager.master
- val list = List[Int](1, 2, 3, 4)
-
- // Create broadcast variable
- val broadcast = sc.broadcast(list)
- val blocks = getBlockIds(broadcast.id)
- afterCreation(blocks, blockManagerMaster)
-
- // Use broadcast variable on all executors
- val partitions = 10
- assert(partitions > numSlaves)
- val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
- assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
- afterUsingBroadcast(blocks, blockManagerMaster)
-
- // Unpersist broadcast
- if (removeFromDriver) {
- broadcast.destroy(blocking = true)
- } else {
- broadcast.unpersist(blocking = true)
- }
- afterUnpersist(blocks, blockManagerMaster)
-
- // If the broadcast is removed from driver, all subsequent uses of the broadcast variable
- // should throw SparkExceptions. Otherwise, the result should be the same as before.
- if (removeFromDriver) {
- // Using this variable on the executors crashes them, which hangs the test.
- // Instead, crash the driver by directly accessing the broadcast value.
- intercept[SparkException] { broadcast.value }
- intercept[SparkException] { broadcast.unpersist() }
- intercept[SparkException] { broadcast.destroy(blocking = true) }
- } else {
- val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
- assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
- }
- }
-
- /** Helper method to create a SparkConf that uses the given broadcast factory. */
- private def broadcastConf(factoryName: String): SparkConf = {
- val conf = new SparkConf
- conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName))
- conf
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/38ccd6eb/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala
deleted file mode 100644
index df6b260..0000000
--- a/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala
+++ /dev/null
@@ -1,230 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import org.scalatest.FunSuite
-
-import java.nio._
-
-import org.apache.spark.network.{ConnectionManager, Message, ConnectionManagerId}
-import scala.concurrent.Await
-import scala.concurrent.TimeoutException
-import scala.concurrent.duration._
-import scala.language.postfixOps
-
-/**
- * Test the ConnectionManager with various security settings.
- */
-class ConnectionManagerSuite extends FunSuite {
-
- test("security default off") {
- val conf = new SparkConf
- val securityManager = new SecurityManager(conf)
- val manager = new ConnectionManager(0, conf, securityManager)
- var receivedMessage = false
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- receivedMessage = true
- None
- })
-
- val size = 10 * 1024 * 1024
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
-
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- manager.sendMessageReliablySync(manager.id, bufferMessage)
-
- assert(receivedMessage == true)
-
- manager.stop()
- }
-
- test("security on same password") {
- val conf = new SparkConf
- conf.set("spark.authenticate", "true")
- conf.set("spark.authenticate.secret", "good")
- val securityManager = new SecurityManager(conf)
- val manager = new ConnectionManager(0, conf, securityManager)
- var numReceivedMessages = 0
-
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- numReceivedMessages += 1
- None
- })
- val managerServer = new ConnectionManager(0, conf, securityManager)
- var numReceivedServerMessages = 0
- managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- numReceivedServerMessages += 1
- None
- })
-
- val size = 10 * 1024 * 1024
- val count = 10
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
-
- (0 until count).map(i => {
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- manager.sendMessageReliablySync(managerServer.id, bufferMessage)
- })
-
- assert(numReceivedServerMessages == 10)
- assert(numReceivedMessages == 0)
-
- manager.stop()
- managerServer.stop()
- }
-
- test("security mismatch password") {
- val conf = new SparkConf
- conf.set("spark.authenticate", "true")
- conf.set("spark.authenticate.secret", "good")
- val securityManager = new SecurityManager(conf)
- val manager = new ConnectionManager(0, conf, securityManager)
- var numReceivedMessages = 0
-
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- numReceivedMessages += 1
- None
- })
-
- val badconf = new SparkConf
- badconf.set("spark.authenticate", "true")
- badconf.set("spark.authenticate.secret", "bad")
- val badsecurityManager = new SecurityManager(badconf)
- val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
- var numReceivedServerMessages = 0
-
- managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- numReceivedServerMessages += 1
- None
- })
-
- val size = 10 * 1024 * 1024
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- manager.sendMessageReliablySync(managerServer.id, bufferMessage)
-
- assert(numReceivedServerMessages == 0)
- assert(numReceivedMessages == 0)
-
- manager.stop()
- managerServer.stop()
- }
-
- test("security mismatch auth off") {
- val conf = new SparkConf
- conf.set("spark.authenticate", "false")
- conf.set("spark.authenticate.secret", "good")
- val securityManager = new SecurityManager(conf)
- val manager = new ConnectionManager(0, conf, securityManager)
- var numReceivedMessages = 0
-
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- numReceivedMessages += 1
- None
- })
-
- val badconf = new SparkConf
- badconf.set("spark.authenticate", "true")
- badconf.set("spark.authenticate.secret", "good")
- val badsecurityManager = new SecurityManager(badconf)
- val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
- var numReceivedServerMessages = 0
- managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- numReceivedServerMessages += 1
- None
- })
-
- val size = 10 * 1024 * 1024
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- (0 until 1).map(i => {
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- manager.sendMessageReliably(managerServer.id, bufferMessage)
- }).foreach(f => {
- try {
- val g = Await.result(f, 1 second)
- assert(false)
- } catch {
- case e: TimeoutException => {
- // we should timeout here since the client can't do the negotiation
- assert(true)
- }
- }
- })
-
- assert(numReceivedServerMessages == 0)
- assert(numReceivedMessages == 0)
- manager.stop()
- managerServer.stop()
- }
-
- test("security auth off") {
- val conf = new SparkConf
- conf.set("spark.authenticate", "false")
- val securityManager = new SecurityManager(conf)
- val manager = new ConnectionManager(0, conf, securityManager)
- var numReceivedMessages = 0
-
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- numReceivedMessages += 1
- None
- })
-
- val badconf = new SparkConf
- badconf.set("spark.authenticate", "false")
- val badsecurityManager = new SecurityManager(badconf)
- val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
- var numReceivedServerMessages = 0
-
- managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- numReceivedServerMessages += 1
- None
- })
-
- val size = 10 * 1024 * 1024
- val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
- buffer.flip
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- (0 until 10).map(i => {
- val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- manager.sendMessageReliably(managerServer.id, bufferMessage)
- }).foreach(f => {
- try {
- val g = Await.result(f, 1 second)
- if (!g.isDefined) assert(false) else assert(true)
- } catch {
- case e: Exception => {
- assert(false)
- }
- }
- })
- assert(numReceivedServerMessages == 10)
- assert(numReceivedMessages == 0)
-
- manager.stop()
- managerServer.stop()
- }
-
-
-
-}
-
http://git-wip-us.apache.org/repos/asf/spark/blob/38ccd6eb/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
deleted file mode 100644
index db56a4a..0000000
--- a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
+++ /dev/null
@@ -1,196 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import java.io.File
-
-import org.scalatest.FunSuite
-
-import org.apache.spark.rdd.{HadoopRDD, PipedRDD, HadoopPartition}
-import org.apache.hadoop.mapred.{JobConf, TextInputFormat, FileSplit}
-import org.apache.hadoop.fs.Path
-
-import scala.collection.Map
-import scala.language.postfixOps
-import scala.sys.process._
-import scala.util.Try
-
-import org.apache.hadoop.io.{Text, LongWritable}
-
-import org.apache.spark.executor.TaskMetrics
-
-class PipedRDDSuite extends FunSuite with SharedSparkContext {
-
- test("basic pipe") {
- if (testCommandAvailable("cat")) {
- val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
-
- val piped = nums.pipe(Seq("cat"))
-
- val c = piped.collect()
- assert(c.size === 4)
- assert(c(0) === "1")
- assert(c(1) === "2")
- assert(c(2) === "3")
- assert(c(3) === "4")
- } else {
- assert(true)
- }
- }
-
- test("advanced pipe") {
- if (testCommandAvailable("cat")) {
- val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
- val bl = sc.broadcast(List("0"))
-
- val piped = nums.pipe(Seq("cat"),
- Map[String, String](),
- (f: String => Unit) => {
- bl.value.map(f(_)); f("\u0001")
- },
- (i: Int, f: String => Unit) => f(i + "_"))
-
- val c = piped.collect()
-
- assert(c.size === 8)
- assert(c(0) === "0")
- assert(c(1) === "\u0001")
- assert(c(2) === "1_")
- assert(c(3) === "2_")
- assert(c(4) === "0")
- assert(c(5) === "\u0001")
- assert(c(6) === "3_")
- assert(c(7) === "4_")
-
- val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2)
- val d = nums1.groupBy(str => str.split("\t")(0)).
- pipe(Seq("cat"),
- Map[String, String](),
- (f: String => Unit) => {
- bl.value.map(f(_)); f("\u0001")
- },
- (i: Tuple2[String, Iterable[String]], f: String => Unit) => {
- for (e <- i._2) {
- f(e + "_")
- }
- }).collect()
- assert(d.size === 8)
- assert(d(0) === "0")
- assert(d(1) === "\u0001")
- assert(d(2) === "b\t2_")
- assert(d(3) === "b\t4_")
- assert(d(4) === "0")
- assert(d(5) === "\u0001")
- assert(d(6) === "a\t1_")
- assert(d(7) === "a\t3_")
- } else {
- assert(true)
- }
- }
-
- test("pipe with env variable") {
- if (testCommandAvailable("printenv")) {
- val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
- val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
- val c = piped.collect()
- assert(c.size === 2)
- assert(c(0) === "LALALA")
- assert(c(1) === "LALALA")
- } else {
- assert(true)
- }
- }
-
- test("pipe with non-zero exit status") {
- if (testCommandAvailable("cat")) {
- val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
- val piped = nums.pipe(Seq("cat nonexistent_file", "2>", "/dev/null"))
- intercept[SparkException] {
- piped.collect()
- }
- } else {
- assert(true)
- }
- }
-
- test("basic pipe with separate working directory") {
- if (testCommandAvailable("cat")) {
- val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
- val piped = nums.pipe(Seq("cat"), separateWorkingDir = true)
- val c = piped.collect()
- assert(c.size === 4)
- assert(c(0) === "1")
- assert(c(1) === "2")
- assert(c(2) === "3")
- assert(c(3) === "4")
- val pipedPwd = nums.pipe(Seq("pwd"), separateWorkingDir = true)
- val collectPwd = pipedPwd.collect()
- assert(collectPwd(0).contains("tasks/"))
- val pipedLs = nums.pipe(Seq("ls"), separateWorkingDir = true).collect()
- // make sure symlinks were created
- assert(pipedLs.length > 0)
- // clean up top level tasks directory
- new File("tasks").delete()
- } else {
- assert(true)
- }
- }
-
- test("test pipe exports map_input_file") {
- testExportInputFile("map_input_file")
- }
-
- test("test pipe exports mapreduce_map_input_file") {
- testExportInputFile("mapreduce_map_input_file")
- }
-
- def testCommandAvailable(command: String): Boolean = {
- Try(Process(command) !!).isSuccess
- }
-
- def testExportInputFile(varName: String) {
- if (testCommandAvailable("printenv")) {
- val nums = new HadoopRDD(sc, new JobConf(), classOf[TextInputFormat], classOf[LongWritable],
- classOf[Text], 2) {
- override def getPartitions: Array[Partition] = Array(generateFakeHadoopPartition())
-
- override val getDependencies = List[Dependency[_]]()
-
- override def compute(theSplit: Partition, context: TaskContext) = {
- new InterruptibleIterator[(LongWritable, Text)](context, Iterator((new LongWritable(1),
- new Text("b"))))
- }
- }
- val hadoopPart1 = generateFakeHadoopPartition()
- val pipedRdd = new PipedRDD(nums, "printenv " + varName)
- val tContext = new TaskContext(0, 0, 0)
- val rddIter = pipedRdd.compute(hadoopPart1, tContext)
- val arr = rddIter.toArray
- assert(arr(0) == "/some/path")
- } else {
- // printenv isn't available so just pass the test
- }
- }
-
- def generateFakeHadoopPartition(): HadoopPartition = {
- val split = new FileSplit(new Path("/some/path"), 0, 1,
- Array[String]("loc1", "loc2", "loc3", "loc4", "loc5"))
- new HadoopPartition(sc.newRddId(), 1, split)
- }
-
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/38ccd6eb/core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala b/core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala
deleted file mode 100644
index 4f87fd8..0000000
--- a/core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import org.scalatest.FunSuite
-
-object ZippedPartitionsSuite {
- def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]) : Iterator[Int] = {
- Iterator(i.toArray.size, s.toArray.size, d.toArray.size)
- }
-}
-
-class ZippedPartitionsSuite extends FunSuite with SharedSparkContext {
- test("print sizes") {
- val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2)
- val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2)
- val data3 = sc.makeRDD(Array(1.0, 2.0), 2)
-
- val zippedRDD = data1.zipPartitions(data2, data3)(ZippedPartitionsSuite.procZippedData)
-
- val obtainedSizes = zippedRDD.collect()
- val expectedSizes = Array(2, 3, 1, 2, 3, 1)
- assert(obtainedSizes.size == 6)
- assert(obtainedSizes.zip(expectedSizes).forall(x => x._1 == x._2))
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/38ccd6eb/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
new file mode 100644
index 0000000..7c3d020
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -0,0 +1,313 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import org.apache.spark.storage.{BroadcastBlockId, _}
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
+import org.scalatest.FunSuite
+
+class BroadcastSuite extends FunSuite with LocalSparkContext {
+
+ private val httpConf = broadcastConf("HttpBroadcastFactory")
+ private val torrentConf = broadcastConf("TorrentBroadcastFactory")
+
+ test("Using HttpBroadcast locally") {
+ sc = new SparkContext("local", "test", httpConf)
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === Set((1, 10), (2, 10)))
+ }
+
+ test("Accessing HttpBroadcast variables from multiple threads") {
+ sc = new SparkContext("local[10]", "test", httpConf)
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet)
+ }
+
+ test("Accessing HttpBroadcast variables in a local cluster") {
+ val numSlaves = 4
+ sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", httpConf)
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+ }
+
+ test("Using TorrentBroadcast locally") {
+ sc = new SparkContext("local", "test", torrentConf)
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === Set((1, 10), (2, 10)))
+ }
+
+ test("Accessing TorrentBroadcast variables from multiple threads") {
+ sc = new SparkContext("local[10]", "test", torrentConf)
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet)
+ }
+
+ test("Accessing TorrentBroadcast variables in a local cluster") {
+ val numSlaves = 4
+ sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", torrentConf)
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+ }
+
+ test("Unpersisting HttpBroadcast on executors only in local mode") {
+ testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false)
+ }
+
+ test("Unpersisting HttpBroadcast on executors and driver in local mode") {
+ testUnpersistHttpBroadcast(distributed = false, removeFromDriver = true)
+ }
+
+ test("Unpersisting HttpBroadcast on executors only in distributed mode") {
+ testUnpersistHttpBroadcast(distributed = true, removeFromDriver = false)
+ }
+
+ test("Unpersisting HttpBroadcast on executors and driver in distributed mode") {
+ testUnpersistHttpBroadcast(distributed = true, removeFromDriver = true)
+ }
+
+ test("Unpersisting TorrentBroadcast on executors only in local mode") {
+ testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = false)
+ }
+
+ test("Unpersisting TorrentBroadcast on executors and driver in local mode") {
+ testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = true)
+ }
+
+ test("Unpersisting TorrentBroadcast on executors only in distributed mode") {
+ testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = false)
+ }
+
+ test("Unpersisting TorrentBroadcast on executors and driver in distributed mode") {
+ testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = true)
+ }
+ /**
+ * Verify the persistence of state associated with an HttpBroadcast in either local mode or
+ * local-cluster mode (when distributed = true).
+ *
+ * This test creates a broadcast variable, uses it on all executors, and then unpersists it.
+ * In between each step, this test verifies that the broadcast blocks and the broadcast file
+ * are present only on the expected nodes.
+ */
+ private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
+ val numSlaves = if (distributed) 2 else 0
+
+ def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id))
+
+ // Verify that the broadcast file is created, and blocks are persisted only on the driver
+ def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+ assert(blockIds.size === 1)
+ val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ assert(statuses.size === 1)
+ statuses.head match { case (bm, status) =>
+ assert(bm.executorId === "<driver>", "Block should only be on the driver")
+ assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
+ assert(status.memSize > 0, "Block should be in memory store on the driver")
+ assert(status.diskSize === 0, "Block should not be in disk store on the driver")
+ }
+ if (distributed) {
+ // this file is only generated in distributed mode
+ assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!")
+ }
+ }
+
+ // Verify that blocks are persisted in both the executors and the driver
+ def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+ assert(blockIds.size === 1)
+ val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ assert(statuses.size === numSlaves + 1)
+ statuses.foreach { case (_, status) =>
+ assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
+ assert(status.memSize > 0, "Block should be in memory store")
+ assert(status.diskSize === 0, "Block should not be in disk store")
+ }
+ }
+
+ // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
+ // is true. In the latter case, also verify that the broadcast file is deleted on the driver.
+ def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+ assert(blockIds.size === 1)
+ val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ val expectedNumBlocks = if (removeFromDriver) 0 else 1
+ val possiblyNot = if (removeFromDriver) "" else " not"
+ assert(statuses.size === expectedNumBlocks,
+ "Block should%s be unpersisted on the driver".format(possiblyNot))
+ if (distributed && removeFromDriver) {
+ // this file is only generated in distributed mode
+ assert(!HttpBroadcast.getFile(blockIds.head.broadcastId).exists,
+ "Broadcast file should%s be deleted".format(possiblyNot))
+ }
+ }
+
+ testUnpersistBroadcast(distributed, numSlaves, httpConf, getBlockIds, afterCreation,
+ afterUsingBroadcast, afterUnpersist, removeFromDriver)
+ }
+
+ /**
+ * Verify the persistence of state associated with an TorrentBroadcast in a local-cluster.
+ *
+ * This test creates a broadcast variable, uses it on all executors, and then unpersists it.
+ * In between each step, this test verifies that the broadcast blocks are present only on the
+ * expected nodes.
+ */
+ private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
+ val numSlaves = if (distributed) 2 else 0
+
+ def getBlockIds(id: Long) = {
+ val broadcastBlockId = BroadcastBlockId(id)
+ val metaBlockId = BroadcastBlockId(id, "meta")
+ // Assume broadcast value is small enough to fit into 1 piece
+ val pieceBlockId = BroadcastBlockId(id, "piece0")
+ if (distributed) {
+ // the metadata and piece blocks are generated only in distributed mode
+ Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId)
+ } else {
+ Seq[BroadcastBlockId](broadcastBlockId)
+ }
+ }
+
+ // Verify that each block in blockIds (not just the first) is persisted only on the driver
+ def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+   blockIds.foreach { blockId =>
+     val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+     assert(statuses.size === 1)
+     statuses.head match { case (bm, status) =>
+       assert(bm.executorId === "<driver>", "Block should only be on the driver")
+       assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
+       assert(status.memSize > 0, "Block should be in memory store on the driver")
+       assert(status.diskSize === 0, "Block should not be in disk store on the driver")
+     }
+   }
+ }
+
+ // Verify that blocks are persisted in both the executors and the driver
+ def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+ blockIds.foreach { blockId =>
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ if (blockId.field == "meta") {
+ // Meta data is only on the driver
+ assert(statuses.size === 1)
+ statuses.head match { case (bm, _) => assert(bm.executorId === "<driver>") }
+ } else {
+ // Other blocks are on both the executors and the driver
+ assert(statuses.size === numSlaves + 1,
+ blockId + " has " + statuses.size + " statuses: " + statuses.mkString(","))
+ statuses.foreach { case (_, status) =>
+ assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
+ assert(status.memSize > 0, "Block should be in memory store")
+ assert(status.diskSize === 0, "Block should not be in disk store")
+ }
+ }
+ }
+ }
+
+ // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
+ // is true.
+ def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+ val expectedNumBlocks = if (removeFromDriver) 0 else 1
+ val possiblyNot = if (removeFromDriver) "" else " not"
+ blockIds.foreach { blockId =>
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === expectedNumBlocks,
+ "Block should%s be unpersisted on the driver".format(possiblyNot))
+ }
+ }
+
+ testUnpersistBroadcast(distributed, numSlaves, torrentConf, getBlockIds, afterCreation,
+ afterUsingBroadcast, afterUnpersist, removeFromDriver)
+ }
+
+ /**
+ * This test runs in 4 steps:
+ *
+ * 1) Create broadcast variable, and verify that all state is persisted on the driver.
+ * 2) Use the broadcast variable on all executors, and verify that all state is persisted
+ * on both the driver and the executors.
+ * 3) Unpersist the broadcast, and verify that all state is removed where they should be.
+ * 4) [Optional] If removeFromDriver is false, we verify that the broadcast is re-usable.
+ */
+ private def testUnpersistBroadcast(
+ distributed: Boolean,
+ numSlaves: Int, // used only when distributed = true
+ broadcastConf: SparkConf,
+ getBlockIds: Long => Seq[BroadcastBlockId],
+ afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+ afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+ afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+ removeFromDriver: Boolean) {
+
+ sc = if (distributed) {
+ new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf)
+ } else {
+ new SparkContext("local", "test", broadcastConf)
+ }
+ val blockManagerMaster = sc.env.blockManager.master
+ val list = List[Int](1, 2, 3, 4)
+
+ // Create broadcast variable
+ val broadcast = sc.broadcast(list)
+ val blocks = getBlockIds(broadcast.id)
+ afterCreation(blocks, blockManagerMaster)
+
+ // Use broadcast variable on all executors
+ val partitions = 10
+ assert(partitions > numSlaves)
+ val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
+ afterUsingBroadcast(blocks, blockManagerMaster)
+
+ // Unpersist broadcast
+ if (removeFromDriver) {
+ broadcast.destroy(blocking = true)
+ } else {
+ broadcast.unpersist(blocking = true)
+ }
+ afterUnpersist(blocks, blockManagerMaster)
+
+ // If the broadcast is removed from driver, all subsequent uses of the broadcast variable
+ // should throw SparkExceptions. Otherwise, the result should be the same as before.
+ if (removeFromDriver) {
+ // Using this variable on the executors crashes them, which hangs the test.
+ // Instead, crash the driver by directly accessing the broadcast value.
+ intercept[SparkException] { broadcast.value }
+ intercept[SparkException] { broadcast.unpersist() }
+ intercept[SparkException] { broadcast.destroy(blocking = true) }
+ } else {
+ val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
+ }
+ }
+
+ /** Helper method to create a SparkConf that uses the given broadcast factory. */
+ private def broadcastConf(factoryName: String): SparkConf = {
+ val conf = new SparkConf
+ conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName))
+ conf
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/38ccd6eb/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala
new file mode 100644
index 0000000..415ad8c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.nio._
+
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.scalatest.FunSuite
+
+import scala.concurrent.{Await, TimeoutException}
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+/**
+ * Test the ConnectionManager with various security settings.
+ */
+class ConnectionManagerSuite extends FunSuite {
+
+ // With no security settings, a manager must be able to deliver a message to itself.
+ test("security default off") {
+   val conf = new SparkConf
+   val manager = new ConnectionManager(0, conf, new SecurityManager(conf))
+   var receivedMessage = false
+   manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+     receivedMessage = true
+     None
+   })
+
+   try {
+     val size = 10 * 1024 * 1024
+     val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+     buffer.flip()
+     manager.sendMessageReliablySync(manager.id, Message.createBufferMessage(buffer.duplicate))
+     assert(receivedMessage, "the loopback message was never delivered")
+   } finally {
+     // Always shut the manager down, even when the assertion above fails.
+     manager.stop()
+   }
+ }
+
+ // Two managers sharing one secret: every message must reach the server, none the client.
+ test("security on same password") {
+   val conf = new SparkConf
+   conf.set("spark.authenticate", "true")
+   conf.set("spark.authenticate.secret", "good")
+   val securityManager = new SecurityManager(conf)
+   val manager = new ConnectionManager(0, conf, securityManager)
+   var numReceivedMessages = 0
+   manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+     numReceivedMessages += 1
+     None
+   })
+   val managerServer = new ConnectionManager(0, conf, securityManager)
+   var numReceivedServerMessages = 0
+   managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+     numReceivedServerMessages += 1
+     None
+   })
+
+   try {
+     val size = 10 * 1024 * 1024
+     val count = 10
+     val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+     buffer.flip()
+     (0 until count).foreach { _ =>
+       val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+       manager.sendMessageReliablySync(managerServer.id, bufferMessage)
+     }
+     assert(numReceivedServerMessages === count)
+     assert(numReceivedMessages === 0)
+   } finally {
+     manager.stop()
+     managerServer.stop()
+   }
+ }
+
+ // Client and server configured with different secrets: no message may be delivered.
+ test("security mismatch password") {
+   val conf = new SparkConf
+   conf.set("spark.authenticate", "true")
+   conf.set("spark.authenticate.secret", "good")
+   val manager = new ConnectionManager(0, conf, new SecurityManager(conf))
+   var numReceivedMessages = 0
+   manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+     numReceivedMessages += 1
+     None
+   })
+
+   val badconf = new SparkConf
+   badconf.set("spark.authenticate", "true")
+   badconf.set("spark.authenticate.secret", "bad")
+   val managerServer = new ConnectionManager(0, badconf, new SecurityManager(badconf))
+   var numReceivedServerMessages = 0
+   managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+     numReceivedServerMessages += 1
+     None
+   })
+
+   try {
+     val size = 10 * 1024 * 1024
+     val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+     buffer.flip()
+     val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+     // The reply is deliberately ignored; only the receive counters are checked below.
+     manager.sendMessageReliablySync(managerServer.id, bufferMessage)
+     assert(numReceivedServerMessages === 0)
+     assert(numReceivedMessages === 0)
+   } finally {
+     // Stop both managers even if an assertion fails.
+     manager.stop()
+     managerServer.stop()
+   }
+ }
+
+ // The client has authentication off while the server requires it, so the send must not
+ // complete: the reply future is expected to time out and neither side sees a message.
+ test("security mismatch auth off") {
+   val conf = new SparkConf
+   conf.set("spark.authenticate", "false")
+   conf.set("spark.authenticate.secret", "good")
+   val securityManager = new SecurityManager(conf)
+   val manager = new ConnectionManager(0, conf, securityManager)
+   var numReceivedMessages = 0
+
+   manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+     numReceivedMessages += 1
+     None
+   })
+
+   val badconf = new SparkConf
+   badconf.set("spark.authenticate", "true")
+   badconf.set("spark.authenticate.secret", "good")
+   val badsecurityManager = new SecurityManager(badconf)
+   val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
+   var numReceivedServerMessages = 0
+
+   managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+     numReceivedServerMessages += 1
+     None
+   })
+
+   try {
+     val size = 10 * 1024 * 1024
+     val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+     buffer.flip()
+     val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+     val reply = manager.sendMessageReliably(managerServer.id, bufferMessage)
+
+     // intercept fails the test if the future completes normally or fails with anything
+     // other than a TimeoutException; the client cannot finish the negotiation.
+     intercept[TimeoutException] {
+       Await.result(reply, 1 second)
+     }
+
+     assert(numReceivedServerMessages === 0)
+     assert(numReceivedMessages === 0)
+   } finally {
+     // Stop both managers even if an assertion fails.
+     manager.stop()
+     managerServer.stop()
+   }
+ }
+
+ // Authentication disabled on both sides: every send must complete and reach the server.
+ test("security auth off") {
+   val conf = new SparkConf
+   conf.set("spark.authenticate", "false")
+   val securityManager = new SecurityManager(conf)
+   val manager = new ConnectionManager(0, conf, securityManager)
+   var numReceivedMessages = 0
+
+   manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+     numReceivedMessages += 1
+     None
+   })
+
+   val badconf = new SparkConf
+   badconf.set("spark.authenticate", "false")
+   val badsecurityManager = new SecurityManager(badconf)
+   val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
+   var numReceivedServerMessages = 0
+   managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+     numReceivedServerMessages += 1
+     None
+   })
+
+   try {
+     val size = 10 * 1024 * 1024
+     val count = 10
+     val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+     buffer.flip()
+     // Start every send first, then wait on each reply in turn.
+     val replies = (0 until count).map { _ =>
+       val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+       manager.sendMessageReliably(managerServer.id, bufferMessage)
+     }
+     // An exception from Await.result (e.g. a timeout) propagates and fails the test with
+     // its own stack trace instead of being collapsed into a bare assert(false).
+     replies.foreach { f =>
+       assert(Await.result(f, 1 second).isDefined, "reliable send completed with None")
+     }
+
+     assert(numReceivedServerMessages === count)
+     assert(numReceivedMessages === 0)
+   } finally {
+     manager.stop()
+     managerServer.stop()
+   }
+ }
+
+
+
+}
+
http://git-wip-us.apache.org/repos/asf/spark/blob/38ccd6eb/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
new file mode 100644
index 0000000..be972c5
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.io.File
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.{LongWritable, Text}
+import org.apache.hadoop.mapred.{FileSplit, JobConf, TextInputFormat}
+import org.apache.spark._
+import org.scalatest.FunSuite
+
+import scala.collection.Map
+import scala.language.postfixOps
+import scala.sys.process._
+import scala.util.Try
+
+ /**
+  * Tests for piping RDD elements through external commands (`RDD.pipe` and `PipedRDD`).
+  *
+  * Most tests depend on a command-line tool (`cat`, `printenv`); when `testCommandAvailable`
+  * reports the tool as unavailable the test passes without checking anything.
+  */
+ class PipedRDDSuite extends FunSuite with SharedSparkContext {
+
+   test("basic pipe") {
+     if (testCommandAvailable("cat")) {
+       val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+
+       val piped = nums.pipe(Seq("cat"))
+
+       val c = piped.collect()
+       assert(c.size === 4)
+       assert(c(0) === "1")
+       assert(c(1) === "2")
+       assert(c(2) === "3")
+       assert(c(3) === "4")
+     } else {
+       // cat isn't available so just pass the test
+       assert(true)
+     }
+   }
+
+   test("advanced pipe") {
+     if (testCommandAvailable("cat")) {
+       val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+       val bl = sc.broadcast(List("0"))
+
+       // The third argument emits the broadcast values followed by a \u0001 marker; the
+       // expected output below shows that pair once per partition, ahead of the elements.
+       // The fourth argument controls how each element is written ("<element>_").
+       val piped = nums.pipe(Seq("cat"),
+         Map[String, String](),
+         (f: String => Unit) => {
+           // foreach, not map: f is called only for its side effect
+           bl.value.foreach(f); f("\u0001")
+         },
+         (i: Int, f: String => Unit) => f(i + "_"))
+
+       val c = piped.collect()
+
+       assert(c.size === 8)
+       assert(c(0) === "0")
+       assert(c(1) === "\u0001")
+       assert(c(2) === "1_")
+       assert(c(3) === "2_")
+       assert(c(4) === "0")
+       assert(c(5) === "\u0001")
+       assert(c(6) === "3_")
+       assert(c(7) === "4_")
+
+       // Same check on a grouped RDD: every value of a group is written as "<value>_".
+       val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2)
+       val d = nums1.groupBy(str => str.split("\t")(0)).
+         pipe(Seq("cat"),
+           Map[String, String](),
+           (f: String => Unit) => {
+             bl.value.foreach(f); f("\u0001")
+           },
+           (i: Tuple2[String, Iterable[String]], f: String => Unit) => {
+             for (e <- i._2) {
+               f(e + "_")
+             }
+           }).collect()
+       assert(d.size === 8)
+       assert(d(0) === "0")
+       assert(d(1) === "\u0001")
+       assert(d(2) === "b\t2_")
+       assert(d(3) === "b\t4_")
+       assert(d(4) === "0")
+       assert(d(5) === "\u0001")
+       assert(d(6) === "a\t1_")
+       assert(d(7) === "a\t3_")
+     } else {
+       // cat isn't available so just pass the test
+       assert(true)
+     }
+   }
+
+   test("pipe with env variable") {
+     if (testCommandAvailable("printenv")) {
+       val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+       val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
+       val c = piped.collect()
+       // one line of output per partition
+       assert(c.size === 2)
+       assert(c(0) === "LALALA")
+       assert(c(1) === "LALALA")
+     } else {
+       // printenv isn't available so just pass the test
+       assert(true)
+     }
+   }
+
+   test("pipe with non-zero exit status") {
+     if (testCommandAvailable("cat")) {
+       val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+       // NOTE(review): the first element is the single string "cat nonexistent_file", so the
+       // program name itself contains a space. The SparkException may therefore come from the
+       // command failing to start rather than from a non-zero exit status -- confirm intent.
+       val piped = nums.pipe(Seq("cat nonexistent_file", "2>", "/dev/null"))
+       intercept[SparkException] {
+         piped.collect()
+       }
+     } else {
+       // cat isn't available so just pass the test
+       assert(true)
+     }
+   }
+
+   test("basic pipe with separate working directory") {
+     if (testCommandAvailable("cat")) {
+       val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+       val piped = nums.pipe(Seq("cat"), separateWorkingDir = true)
+       val c = piped.collect()
+       assert(c.size === 4)
+       assert(c(0) === "1")
+       assert(c(1) === "2")
+       assert(c(2) === "3")
+       assert(c(3) === "4")
+       // the command must have been run from inside a per-task directory
+       val pipedPwd = nums.pipe(Seq("pwd"), separateWorkingDir = true)
+       val collectPwd = pipedPwd.collect()
+       assert(collectPwd(0).contains("tasks/"))
+       val pipedLs = nums.pipe(Seq("ls"), separateWorkingDir = true).collect()
+       // make sure symlinks were created
+       assert(pipedLs.length > 0)
+       // clean up top level tasks directory
+       new File("tasks").delete()
+     } else {
+       // cat isn't available so just pass the test
+       assert(true)
+     }
+   }
+
+   test("test pipe exports map_input_file") {
+     testExportInputFile("map_input_file")
+   }
+
+   test("test pipe exports mapreduce_map_input_file") {
+     testExportInputFile("mapreduce_map_input_file")
+   }
+
+   /**
+    * Reports whether `command` can be run on this machine.
+    *
+    * @param command command line to execute, without further arguments
+    * @return true if the process ran and `!!` did not throw (it throws on a non-zero exit code)
+    */
+   def testCommandAvailable(command: String): Boolean = {
+     Try(Process(command).!!).isSuccess
+   }
+
+   /**
+    * Runs `printenv <varName>` through a PipedRDD built on a stubbed HadoopRDD and checks that
+    * the first output line is the path of the fake input split ("/some/path").
+    *
+    * @param varName name of the environment variable expected to carry the input file path
+    */
+   def testExportInputFile(varName: String): Unit = {
+     if (testCommandAvailable("printenv")) {
+       // HadoopRDD stub: a single fake partition, no dependencies and one fixed record, so no
+       // real input file is read.
+       val nums = new HadoopRDD(sc, new JobConf(), classOf[TextInputFormat], classOf[LongWritable],
+         classOf[Text], 2) {
+         override def getPartitions: Array[Partition] = Array(generateFakeHadoopPartition())
+
+         override val getDependencies = List[Dependency[_]]()
+
+         override def compute(theSplit: Partition, context: TaskContext) = {
+           new InterruptibleIterator[(LongWritable, Text)](context, Iterator((new LongWritable(1),
+             new Text("b"))))
+         }
+       }
+       val hadoopPart1 = generateFakeHadoopPartition()
+       val pipedRdd = new PipedRDD(nums, "printenv " + varName)
+       val tContext = new TaskContext(0, 0, 0)
+       // compute the partition directly instead of running a job
+       val rddIter = pipedRdd.compute(hadoopPart1, tContext)
+       val arr = rddIter.toArray
+       assert(arr(0) === "/some/path")
+     } else {
+       // printenv isn't available so just pass the test
+     }
+   }
+
+   /** Builds a HadoopPartition wrapping a FileSplit over "/some/path" with five fake hosts. */
+   def generateFakeHadoopPartition(): HadoopPartition = {
+     val split = new FileSplit(new Path("/some/path"), 0, 1,
+       Array[String]("loc1", "loc2", "loc3", "loc4", "loc5"))
+     new HadoopPartition(sc.newRddId(), 1, split)
+   }
+
+ }
http://git-wip-us.apache.org/repos/asf/spark/blob/38ccd6eb/core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala
new file mode 100644
index 0000000..72596e8
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.SharedSparkContext
+import org.scalatest.FunSuite
+
+ object ZippedPartitionsSuite {
+   /**
+    * Counts the elements of each of the three iterators (consuming them) and returns the
+    * three counts in argument order.
+    */
+   def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]) : Iterator[Int] = {
+     val intCount = i.toArray.length
+     val stringCount = s.toArray.length
+     val doubleCount = d.toArray.length
+     Iterator(intCount, stringCount, doubleCount)
+   }
+ }
+
+ /** Checks `zipPartitions` over three RDDs that share the same number of partitions. */
+ class ZippedPartitionsSuite extends FunSuite with SharedSparkContext {
+   test("print sizes") {
+     // three RDDs of different lengths, each split into two partitions
+     val ints = sc.makeRDD(Array(1, 2, 3, 4), 2)
+     val strings = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2)
+     val doubles = sc.makeRDD(Array(1.0, 2.0), 2)
+
+     val actual = ints
+       .zipPartitions(strings, doubles)(ZippedPartitionsSuite.procZippedData)
+       .collect()
+
+     // per partition: (number of ints, number of strings, number of doubles)
+     val expected = Array(2, 3, 1, 2, 3, 1)
+     assert(actual.size == 6)
+     assert(actual.zip(expected).forall { case (got, want) => got == want })
+   }
+ }
http://git-wip-us.apache.org/repos/asf/spark/blob/38ccd6eb/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
new file mode 100644
index 0000000..c4765e5
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import akka.actor._
+import org.apache.spark._
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.storage.BlockManagerId
+import org.scalatest.FunSuite
+
+import scala.concurrent.Await
+
+/**
+ * Test the AkkaUtils with various security settings.
+ */
+class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
+
+ test("remote fetch security bad password") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+
+ val securityManager = new SecurityManager(conf);
+ val hostname = "localhost"
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
+ conf = conf, securityManager = securityManager)
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+ assert(securityManager.isAuthenticationEnabled() === true)
+
+ val masterTracker = new MapOutputTrackerMaster(conf)
+ masterTracker.trackerActor = actorSystem.actorOf(
+ Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "true")
+ badconf.set("spark.authenticate.secret", "bad")
+ val securityManagerBad = new SecurityManager(badconf)
+
+ assert(securityManagerBad.isAuthenticationEnabled() === true)
+
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
+ conf = conf, securityManager = securityManagerBad)
+ val slaveTracker = new MapOutputTrackerWorker(conf)
+ val selection = slaveSystem.actorSelection(
+ s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ intercept[akka.actor.ActorNotFound] {
+ slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+ }
+
+ actorSystem.shutdown()
+ slaveSystem.shutdown()
+ }
+
+ test("remote fetch security off") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "false")
+ conf.set("spark.authenticate.secret", "bad")
+ val securityManager = new SecurityManager(conf);
+
+ val hostname = "localhost"
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
+ conf = conf, securityManager = securityManager)
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+
+ assert(securityManager.isAuthenticationEnabled() === false)
+
+ val masterTracker = new MapOutputTrackerMaster(conf)
+ masterTracker.trackerActor = actorSystem.actorOf(
+ Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "false")
+ badconf.set("spark.authenticate.secret", "good")
+ val securityManagerBad = new SecurityManager(badconf);
+
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
+ conf = badconf, securityManager = securityManagerBad)
+ val slaveTracker = new MapOutputTrackerWorker(conf)
+ val selection = slaveSystem.actorSelection(
+ s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+
+ assert(securityManagerBad.isAuthenticationEnabled() === false)
+
+ masterTracker.registerShuffle(10, 1)
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+ val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+ val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+ masterTracker.registerMapOutput(10, 0, new MapStatus(
+ BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+ // this should succeed since security off
+ assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
+ Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+
+ actorSystem.shutdown()
+ slaveSystem.shutdown()
+ }
+
+ test("remote fetch security pass") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+ val securityManager = new SecurityManager(conf);
+
+ val hostname = "localhost"
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
+ conf = conf, securityManager = securityManager)
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+
+ assert(securityManager.isAuthenticationEnabled() === true)
+
+ val masterTracker = new MapOutputTrackerMaster(conf)
+ masterTracker.trackerActor = actorSystem.actorOf(
+ Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+
+ val goodconf = new SparkConf
+ goodconf.set("spark.authenticate", "true")
+ goodconf.set("spark.authenticate.secret", "good")
+ val securityManagerGood = new SecurityManager(goodconf);
+
+ assert(securityManagerGood.isAuthenticationEnabled() === true)
+
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
+ conf = goodconf, securityManager = securityManagerGood)
+ val slaveTracker = new MapOutputTrackerWorker(conf)
+ val selection = slaveSystem.actorSelection(
+ s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+
+ masterTracker.registerShuffle(10, 1)
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+ val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+ val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+ masterTracker.registerMapOutput(10, 0, new MapStatus(
+ BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+ // this should succeed since security on and passwords match
+ assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
+ Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+
+ actorSystem.shutdown()
+ slaveSystem.shutdown()
+ }
+
+ test("remote fetch security off client") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+
+ val securityManager = new SecurityManager(conf);
+
+ val hostname = "localhost"
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
+ conf = conf, securityManager = securityManager)
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+
+ assert(securityManager.isAuthenticationEnabled() === true)
+
+ val masterTracker = new MapOutputTrackerMaster(conf)
+ masterTracker.trackerActor = actorSystem.actorOf(
+ Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "false")
+ badconf.set("spark.authenticate.secret", "bad")
+ val securityManagerBad = new SecurityManager(badconf);
+
+ assert(securityManagerBad.isAuthenticationEnabled() === false)
+
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
+ conf = badconf, securityManager = securityManagerBad)
+ val slaveTracker = new MapOutputTrackerWorker(conf)
+ val selection = slaveSystem.actorSelection(
+ s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ intercept[akka.actor.ActorNotFound] {
+ slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+ }
+
+ actorSystem.shutdown()
+ slaveSystem.shutdown()
+ }
+
+}