You are viewing a plain-text version of this content. The link to the canonical version was a hyperlink and is not preserved in this plain-text rendering.
Posted to commits@spark.apache.org by va...@apache.org on 2015/09/02 21:53:31 UTC

spark git commit: [SPARK-10004] [SHUFFLE] Perform auth checks when clients read shuffle data.

Repository: spark
Updated Branches:
  refs/heads/master fc4830779 -> 2da3a9e98


[SPARK-10004] [SHUFFLE] Perform auth checks when clients read shuffle data.

To correctly isolate applications, the shuffle service must perform proper
authorization checks when it receives requests to read shuffle data. This
change makes sure that only the application that created the shuffle data
can read from it, and that a client can only register executors for its
own application.

These checks are only performed when "spark.authenticate" is enabled;
otherwise there is no secure way to verify that the client really is
who it claims to be.

Author: Marcelo Vanzin <va...@cloudera.com>

Closes #8218 from vanzin/SPARK-10004.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2da3a9e9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2da3a9e9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2da3a9e9

Branch: refs/heads/master
Commit: 2da3a9e98e5d129d4507b5db01bba5ee9558d28e
Parents: fc48307
Author: Marcelo Vanzin <va...@cloudera.com>
Authored: Wed Sep 2 12:53:24 2015 -0700
Committer: Marcelo Vanzin <va...@cloudera.com>
Committed: Wed Sep 2 12:53:24 2015 -0700

----------------------------------------------------------------------
 .../network/netty/NettyBlockRpcServer.scala     |   3 +-
 .../netty/NettyBlockTransferService.scala       |   2 +-
 network/common/pom.xml                          |   4 +
 .../spark/network/client/TransportClient.java   |  22 +++
 .../spark/network/sasl/SaslClientBootstrap.java |   2 +
 .../spark/network/sasl/SaslRpcHandler.java      |   1 +
 .../network/server/OneForOneStreamManager.java  |  31 +++-
 .../spark/network/server/StreamManager.java     |   9 +
 .../network/server/TransportRequestHandler.java |   1 +
 .../shuffle/ExternalShuffleBlockHandler.java    |  16 +-
 .../network/sasl/SaslIntegrationSuite.java      | 163 ++++++++++++++++---
 .../ExternalShuffleBlockHandlerSuite.java       |   2 +-
 project/MimaExcludes.scala                      |   1 +
 13 files changed, 221 insertions(+), 36 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2da3a9e9/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index 7c170a7..7696824 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -38,6 +38,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel}
  * is equivalent to one Spark-level shuffle block.
  */
 class NettyBlockRpcServer(
+    appId: String,
     serializer: Serializer,
     blockManager: BlockDataManager)
   extends RpcHandler with Logging {
@@ -55,7 +56,7 @@ class NettyBlockRpcServer(
       case openBlocks: OpenBlocks =>
         val blocks: Seq[ManagedBuffer] =
           openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
-        val streamId = streamManager.registerStream(blocks.iterator.asJava)
+        val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
         logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
         responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2da3a9e9/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index ff8aae9..d5ad2c9 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -49,7 +49,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
   private[this] var appId: String = _
 
   override def init(blockDataManager: BlockDataManager): Unit = {
-    val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
+    val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager)
     var serverBootstrap: Option[TransportServerBootstrap] = None
     var clientBootstrap: Option[TransportClientBootstrap] = None
     if (authEnabled) {

http://git-wip-us.apache.org/repos/asf/spark/blob/2da3a9e9/network/common/pom.xml
----------------------------------------------------------------------
diff --git a/network/common/pom.xml b/network/common/pom.xml
index 7dc3068..4141fcb 100644
--- a/network/common/pom.xml
+++ b/network/common/pom.xml
@@ -48,6 +48,10 @@
       <artifactId>slf4j-api</artifactId>
       <scope>provided</scope>
     </dependency>
+    <dependency>
+      <groupId>com.google.code.findbugs</groupId>
+      <artifactId>jsr305</artifactId>
+    </dependency>
     <!--
       Promote Guava to "compile" so that maven-shade-plugin picks it up (for packaging the Optional
       class exposed in the Java API). The plugin will then remove this dependency from the published

http://git-wip-us.apache.org/repos/asf/spark/blob/2da3a9e9/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index e8e7f06..df84128 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -23,6 +23,7 @@ import java.net.SocketAddress;
 import java.util.UUID;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
 
 import com.google.common.base.Objects;
 import com.google.common.base.Preconditions;
@@ -70,6 +71,7 @@ public class TransportClient implements Closeable {
 
   private final Channel channel;
   private final TransportResponseHandler handler;
+  @Nullable private String clientId;
 
   public TransportClient(Channel channel, TransportResponseHandler handler) {
     this.channel = Preconditions.checkNotNull(channel);
@@ -85,6 +87,25 @@ public class TransportClient implements Closeable {
   }
 
   /**
+   * Returns the ID used by the client to authenticate itself when authentication is enabled.
+   *
+   * @return The client ID, or null if authentication is disabled.
+   */
+  public String getClientId() {
+    return clientId;
+  }
+
+  /**
+   * Sets the authenticated client ID. This is meant to be used by the authentication layer.
+   *
+   * Trying to set a different client ID after it's been set will result in an exception.
+   */
+  public void setClientId(String id) {
+    Preconditions.checkState(clientId == null, "Client ID has already been set.");
+    this.clientId = id;
+  }
+
+  /**
    * Requests a single chunk from the remote side, from the pre-negotiated streamId.
    *
    * Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though
@@ -207,6 +228,7 @@ public class TransportClient implements Closeable {
   public String toString() {
     return Objects.toStringHelper(this)
       .add("remoteAdress", channel.remoteAddress())
+      .add("clientId", clientId)
       .add("isActive", isActive())
       .toString();
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/2da3a9e9/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
index 185ba2e..6992376 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -77,6 +77,8 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
         payload = saslClient.response(response);
       }
 
+      client.setClientId(appId);
+
       if (encrypt) {
         if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) {
           throw new RuntimeException(

http://git-wip-us.apache.org/repos/asf/spark/blob/2da3a9e9/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index be6165c..3f2ebe3 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -81,6 +81,7 @@ class SaslRpcHandler extends RpcHandler {
 
     if (saslServer == null) {
       // First message in the handshake, setup the necessary state.
+      client.setClientId(saslMessage.appId);
       saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
         conf.saslServerAlwaysEncrypt());
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/2da3a9e9/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index c95e64e..e671854 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -24,13 +24,13 @@ import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicLong;
 
+import com.google.common.base.Preconditions;
 import io.netty.channel.Channel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.buffer.ManagedBuffer;
-
-import com.google.common.base.Preconditions;
+import org.apache.spark.network.client.TransportClient;
 
 /**
  * StreamManager which allows registration of an Iterator&lt;ManagedBuffer&gt;, which are individually
@@ -44,6 +44,7 @@ public class OneForOneStreamManager extends StreamManager {
 
   /** State of a single stream. */
   private static class StreamState {
+    final String appId;
     final Iterator<ManagedBuffer> buffers;
 
     // The channel associated to the stream
@@ -53,7 +54,8 @@ public class OneForOneStreamManager extends StreamManager {
     // that the caller only requests each chunk one at a time, in order.
     int curChunk = 0;
 
-    StreamState(Iterator<ManagedBuffer> buffers) {
+    StreamState(String appId, Iterator<ManagedBuffer> buffers) {
+      this.appId = appId;
       this.buffers = Preconditions.checkNotNull(buffers);
     }
   }
@@ -109,15 +111,34 @@ public class OneForOneStreamManager extends StreamManager {
     }
   }
 
+  @Override
+  public void checkAuthorization(TransportClient client, long streamId) {
+    if (client.getClientId() != null) {
+      StreamState state = streams.get(streamId);
+      Preconditions.checkArgument(state != null, "Unknown stream ID.");
+      if (!client.getClientId().equals(state.appId)) {
+        throw new SecurityException(String.format(
+          "Client %s not authorized to read stream %d (app %s).",
+          client.getClientId(),
+          streamId,
+          state.appId));
+      }
+    }
+  }
+
   /**
    * Registers a stream of ManagedBuffers which are served as individual chunks one at a time to
    * callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a
    * client connection is closed before the iterator is fully drained, then the remaining buffers
    * will all be release()'d.
+   *
+   * If an app ID is provided, only callers who've authenticated with the given app ID will be
+   * allowed to fetch from this stream.
    */
-  public long registerStream(Iterator<ManagedBuffer> buffers) {
+  public long registerStream(String appId, Iterator<ManagedBuffer> buffers) {
     long myStreamId = nextStreamId.getAndIncrement();
-    streams.put(myStreamId, new StreamState(buffers));
+    streams.put(myStreamId, new StreamState(appId, buffers));
     return myStreamId;
   }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2da3a9e9/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
index 929f789..aaa677c 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -20,6 +20,7 @@ package org.apache.spark.network.server;
 import io.netty.channel.Channel;
 
 import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.TransportClient;
 
 /**
  * The StreamManager is used to fetch individual chunks from a stream. This is used in
@@ -60,4 +61,12 @@ public abstract class StreamManager {
    * to read from the associated streams again, so any state can be cleaned up.
    */
   public void connectionTerminated(Channel channel) { }
+
+  /**
+   * Verify that the client is authorized to read from the given stream.
+   *
+   * @throws SecurityException If client is not authorized.
+   */
+  public void checkAuthorization(TransportClient client, long streamId) { }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2da3a9e9/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index e5159ab..df60278 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -97,6 +97,7 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
 
     ManagedBuffer buf;
     try {
+      streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
       streamManager.registerChannel(channel, req.streamChunkId.streamId);
       buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
     } catch (Exception e) {

http://git-wip-us.apache.org/repos/asf/spark/blob/2da3a9e9/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index 0df1dd6..3ddf5c3 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -58,7 +58,7 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
 
   /** Enables mocking out the StreamManager and BlockManager. */
   @VisibleForTesting
-  ExternalShuffleBlockHandler(
+  public ExternalShuffleBlockHandler(
       OneForOneStreamManager streamManager,
       ExternalShuffleBlockResolver blockManager) {
     this.streamManager = streamManager;
@@ -77,17 +77,19 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
       RpcResponseCallback callback) {
     if (msgObj instanceof OpenBlocks) {
       OpenBlocks msg = (OpenBlocks) msgObj;
-      List<ManagedBuffer> blocks = Lists.newArrayList();
+      checkAuth(client, msg.appId);
 
+      List<ManagedBuffer> blocks = Lists.newArrayList();
       for (String blockId : msg.blockIds) {
         blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId));
       }
-      long streamId = streamManager.registerStream(blocks.iterator());
+      long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator());
       logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length);
       callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray());
 
     } else if (msgObj instanceof RegisterExecutor) {
       RegisterExecutor msg = (RegisterExecutor) msgObj;
+      checkAuth(client, msg.appId);
       blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo);
       callback.onSuccess(new byte[0]);
 
@@ -126,4 +128,12 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
   public void close() {
     blockManager.close();
   }
+
+  private void checkAuth(TransportClient client, String appId) {
+    if (client.getClientId() != null && !client.getClientId().equals(appId)) {
+      throw new SecurityException(String.format(
+        "Client for %s not authorized for application %s.", client.getClientId(), appId));
+    }
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2da3a9e9/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index 382f613..5cb0e4d 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -19,6 +19,7 @@ package org.apache.spark.network.sasl;
 
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.concurrent.atomic.AtomicReference;
 
 import com.google.common.collect.Lists;
 import org.junit.After;
@@ -27,9 +28,12 @@ import org.junit.BeforeClass;
 import org.junit.Test;
 
 import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
 
 import org.apache.spark.network.TestUtils;
 import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.client.TransportClientBootstrap;
@@ -39,44 +43,39 @@ import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.StreamManager;
 import org.apache.spark.network.server.TransportServer;
 import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.shuffle.BlockFetchingListener;
 import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
+import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver;
+import org.apache.spark.network.shuffle.OneForOneBlockFetcher;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
+import org.apache.spark.network.shuffle.protocol.StreamHandle;
 import org.apache.spark.network.util.SystemPropertyConfigProvider;
 import org.apache.spark.network.util.TransportConf;
 
 public class SaslIntegrationSuite {
-  static ExternalShuffleBlockHandler handler;
   static TransportServer server;
   static TransportConf conf;
   static TransportContext context;
+  static SecretKeyHolder secretKeyHolder;
 
   TransportClientFactory clientFactory;
 
-  /** Provides a secret key holder which always returns the given secret key. */
-  static class TestSecretKeyHolder implements SecretKeyHolder {
-
-    private final String secretKey;
-
-    TestSecretKeyHolder(String secretKey) {
-      this.secretKey = secretKey;
-    }
-
-    @Override
-    public String getSaslUser(String appId) {
-      return "user";
-    }
-    @Override
-    public String getSecretKey(String appId) {
-      return secretKey;
-    }
-  }
-
-
   @BeforeClass
   public static void beforeAll() throws IOException {
-    SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key");
     conf = new TransportConf(new SystemPropertyConfigProvider());
     context = new TransportContext(conf, new TestRpcHandler());
 
+    secretKeyHolder = mock(SecretKeyHolder.class);
+    when(secretKeyHolder.getSaslUser(eq("app-1"))).thenReturn("app-1");
+    when(secretKeyHolder.getSecretKey(eq("app-1"))).thenReturn("app-1");
+    when(secretKeyHolder.getSaslUser(eq("app-2"))).thenReturn("app-2");
+    when(secretKeyHolder.getSecretKey(eq("app-2"))).thenReturn("app-2");
+    when(secretKeyHolder.getSaslUser(anyString())).thenReturn("other-app");
+    when(secretKeyHolder.getSecretKey(anyString())).thenReturn("correct-password");
+
     TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
     server = context.createServer(Arrays.asList(bootstrap));
   }
@@ -99,7 +98,7 @@ public class SaslIntegrationSuite {
   public void testGoodClient() throws IOException {
     clientFactory = context.createClientFactory(
       Lists.<TransportClientBootstrap>newArrayList(
-        new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key"))));
+        new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
 
     TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
     String msg = "Hello, World!";
@@ -109,13 +108,17 @@ public class SaslIntegrationSuite {
 
   @Test
   public void testBadClient() {
+    SecretKeyHolder badKeyHolder = mock(SecretKeyHolder.class);
+    when(badKeyHolder.getSaslUser(anyString())).thenReturn("other-app");
+    when(badKeyHolder.getSecretKey(anyString())).thenReturn("wrong-password");
     clientFactory = context.createClientFactory(
       Lists.<TransportClientBootstrap>newArrayList(
-        new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key"))));
+        new SaslClientBootstrap(conf, "unknown-app", badKeyHolder)));
 
     try {
       // Bootstrap should fail on startup.
       clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+      fail("Connection should have failed.");
     } catch (Exception e) {
       assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
     }
@@ -149,7 +152,7 @@ public class SaslIntegrationSuite {
     TransportContext context = new TransportContext(conf, handler);
     clientFactory = context.createClientFactory(
       Lists.<TransportClientBootstrap>newArrayList(
-        new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("key"))));
+        new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
     TransportServer server = context.createServer();
     try {
       clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
@@ -160,6 +163,110 @@ public class SaslIntegrationSuite {
     }
   }
 
+  /**
+   * This test is not actually testing SASL behavior, but testing that the shuffle service
+   * performs correct authorization checks based on the SASL authentication data.
+   */
+  @Test
+  public void testAppIsolation() throws Exception {
+    // Start a new server with the correct RPC handler to serve block data.
+    ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class);
+    ExternalShuffleBlockHandler blockHandler = new ExternalShuffleBlockHandler(
+      new OneForOneStreamManager(), blockResolver);
+    TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
+    TransportContext blockServerContext = new TransportContext(conf, blockHandler);
+    TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
+
+    TransportClient client1 = null;
+    TransportClient client2 = null;
+    TransportClientFactory clientFactory2 = null;
+    try {
+      // Create a client, and make a request to fetch blocks from a different app.
+      clientFactory = blockServerContext.createClientFactory(
+        Lists.<TransportClientBootstrap>newArrayList(
+          new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
+      client1 = clientFactory.createClient(TestUtils.getLocalHost(),
+        blockServer.getPort());
+
+      final AtomicReference<Throwable> exception = new AtomicReference<>();
+
+      BlockFetchingListener listener = new BlockFetchingListener() {
+        @Override
+        public synchronized void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
+          notifyAll();
+        }
+
+        @Override
+        public synchronized void onBlockFetchFailure(String blockId, Throwable t) {
+          exception.set(t);
+          notifyAll();
+        }
+      };
+
+      String[] blockIds = new String[] { "shuffle_2_3_4", "shuffle_6_7_8" };
+      OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client1, "app-2", "0",
+        blockIds, listener);
+      synchronized (listener) {
+        fetcher.start();
+        listener.wait();
+      }
+      checkSecurityException(exception.get());
+
+      // Register an executor so that the next steps work.
+      ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo(
+        new String[] { System.getProperty("java.io.tmpdir") }, 1,
+        "org.apache.spark.shuffle.sort.SortShuffleManager");
+      RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
+      client1.sendRpcSync(regmsg.toByteArray(), 10000);
+
+      // Make a successful request to fetch blocks, which creates a new stream. But do not actually
+      // fetch any blocks, to keep the stream open.
+      OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
+      byte[] response = client1.sendRpcSync(openMessage.toByteArray(), 10000);
+      StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response);
+      long streamId = stream.streamId;
+
+      // Create a second client, authenticated with a different app ID, and try to read from
+      // the stream created for the previous app.
+      clientFactory2 = blockServerContext.createClientFactory(
+        Lists.<TransportClientBootstrap>newArrayList(
+          new SaslClientBootstrap(conf, "app-2", secretKeyHolder)));
+      client2 = clientFactory2.createClient(TestUtils.getLocalHost(),
+        blockServer.getPort());
+
+      ChunkReceivedCallback callback = new ChunkReceivedCallback() {
+        @Override
+        public synchronized void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+          notifyAll();
+        }
+
+        @Override
+        public synchronized void onFailure(int chunkIndex, Throwable t) {
+          exception.set(t);
+          notifyAll();
+        }
+      };
+
+      exception.set(null);
+      synchronized (callback) {
+        client2.fetchChunk(streamId, 0, callback);
+        callback.wait();
+      }
+      checkSecurityException(exception.get());
+    } finally {
+      if (client1 != null) {
+        client1.close();
+      }
+      if (client2 != null) {
+        client2.close();
+      }
+      if (clientFactory2 != null) {
+        clientFactory2.close();
+      }
+      blockServer.close();
+    }
+  }
+
   /** RPC handler which simply responds with the message it received. */
   public static class TestRpcHandler extends RpcHandler {
     @Override
@@ -172,4 +279,10 @@ public class SaslIntegrationSuite {
       return new OneForOneStreamManager();
     }
   }
+
+  private void checkSecurityException(Throwable t) {
+    assertNotNull("No exception was caught.", t);
+    assertTrue("Expected SecurityException.",
+      t.getMessage().contains(SecurityException.class.getName()));
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2da3a9e9/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index 1d19749..e61390c 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -93,7 +93,7 @@ public class ExternalShuffleBlockHandlerSuite {
     @SuppressWarnings("unchecked")
     ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>)
         (ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class);
-    verify(streamManager, times(1)).registerStream(stream.capture());
+    verify(streamManager, times(1)).registerStream(anyString(), stream.capture());
     Iterator<ManagedBuffer> buffers = stream.getValue();
     assertEquals(block0Marker, buffers.next());
     assertEquals(block1Marker, buffers.next());

http://git-wip-us.apache.org/repos/asf/spark/blob/2da3a9e9/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 88745dc..714ce3c 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -37,6 +37,7 @@ object MimaExcludes {
         case v if v.startsWith("1.5") =>
           Seq(
             MimaBuild.excludeSparkPackage("deploy"),
+            MimaBuild.excludeSparkPackage("network"),
             // These are needed if checking against the sbt build, since they are part of
             // the maven-generated artifacts in 1.3.
             excludePackage("org.spark-project.jetty"),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org