You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2020/04/17 18:28:06 UTC
[spark] branch master updated: Apply appropriate RPC handler to receive, receiveStream when auth enabled
This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 61b7d44 Apply appropriate RPC handler to receive, receiveStream when auth enabled
61b7d44 is described below
commit 61b7d446b37cecc45e6d274bbfdde3b745bf068f
Author: Sean Owen <sr...@gmail.com>
AuthorDate: Fri Apr 17 13:25:12 2020 -0500
Apply appropriate RPC handler to receive, receiveStream when auth enabled
---
.../spark/network/crypto/AuthRpcHandler.java | 73 +++-----------
.../apache/spark/network/sasl/SaslRpcHandler.java | 60 +++---------
.../network/server/AbstractAuthRpcHandler.java | 107 +++++++++++++++++++++
.../spark/network/crypto/AuthIntegrationSuite.java | 12 +--
.../apache/spark/network/sasl/SparkSaslSuite.java | 3 +-
5 files changed, 142 insertions(+), 113 deletions(-)
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
index 821cc7a..dd31c95 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
@@ -29,12 +29,11 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.RpcResponseCallback;
-import org.apache.spark.network.client.StreamCallbackWithID;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.sasl.SaslRpcHandler;
+import org.apache.spark.network.server.AbstractAuthRpcHandler;
import org.apache.spark.network.server.RpcHandler;
-import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.util.TransportConf;
/**
@@ -46,7 +45,7 @@ import org.apache.spark.network.util.TransportConf;
* The delegate will only receive messages if the given connection has been successfully
* authenticated. A connection may be authenticated at most once.
*/
-class AuthRpcHandler extends RpcHandler {
+class AuthRpcHandler extends AbstractAuthRpcHandler {
private static final Logger LOG = LoggerFactory.getLogger(AuthRpcHandler.class);
/** Transport configuration. */
@@ -55,36 +54,31 @@ class AuthRpcHandler extends RpcHandler {
/** The client channel. */
private final Channel channel;
- /**
- * RpcHandler we will delegate to for authenticated connections. When falling back to SASL
- * this will be replaced with the SASL RPC handler.
- */
- @VisibleForTesting
- RpcHandler delegate;
-
/** Class which provides secret keys which are shared by server and client on a per-app basis. */
private final SecretKeyHolder secretKeyHolder;
- /** Whether auth is done and future calls should be delegated. */
+ /** RPC handler for auth handshake when falling back to SASL auth. */
@VisibleForTesting
- boolean doDelegate;
+ SaslRpcHandler saslHandler;
AuthRpcHandler(
TransportConf conf,
Channel channel,
RpcHandler delegate,
SecretKeyHolder secretKeyHolder) {
+ super(delegate);
this.conf = conf;
this.channel = channel;
- this.delegate = delegate;
this.secretKeyHolder = secretKeyHolder;
}
@Override
- public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
- if (doDelegate) {
- delegate.receive(client, message, callback);
- return;
+ protected boolean doAuthChallenge(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
+ if (saslHandler != null) {
+ return saslHandler.doAuthChallenge(client, message, callback);
}
int position = message.position();
@@ -98,18 +92,17 @@ class AuthRpcHandler extends RpcHandler {
if (conf.saslFallback()) {
LOG.warn("Failed to parse new auth challenge, reverting to SASL for client {}.",
channel.remoteAddress());
- delegate = new SaslRpcHandler(conf, channel, delegate, secretKeyHolder);
+ saslHandler = new SaslRpcHandler(conf, channel, null, secretKeyHolder);
message.position(position);
message.limit(limit);
- delegate.receive(client, message, callback);
- doDelegate = true;
+ return saslHandler.doAuthChallenge(client, message, callback);
} else {
LOG.debug("Unexpected challenge message from client {}, closing channel.",
channel.remoteAddress());
callback.onFailure(new IllegalArgumentException("Unknown challenge message."));
channel.close();
}
- return;
+ return false;
}
// Here we have the client challenge, so perform the new auth protocol and set up the channel.
@@ -131,7 +124,7 @@ class AuthRpcHandler extends RpcHandler {
LOG.debug("Authentication failed for client {}, closing channel.", channel.remoteAddress());
callback.onFailure(new IllegalArgumentException("Authentication failed."));
channel.close();
- return;
+ return false;
} finally {
if (engine != null) {
try {
@@ -143,40 +136,6 @@ class AuthRpcHandler extends RpcHandler {
}
LOG.debug("Authorization successful for client {}.", channel.remoteAddress());
- doDelegate = true;
- }
-
- @Override
- public void receive(TransportClient client, ByteBuffer message) {
- delegate.receive(client, message);
- }
-
- @Override
- public StreamCallbackWithID receiveStream(
- TransportClient client,
- ByteBuffer message,
- RpcResponseCallback callback) {
- return delegate.receiveStream(client, message, callback);
+ return true;
}
-
- @Override
- public StreamManager getStreamManager() {
- return delegate.getStreamManager();
- }
-
- @Override
- public void channelActive(TransportClient client) {
- delegate.channelActive(client);
- }
-
- @Override
- public void channelInactive(TransportClient client) {
- delegate.channelInactive(client);
- }
-
- @Override
- public void exceptionCaught(Throwable cause, TransportClient client) {
- delegate.exceptionCaught(cause, client);
- }
-
}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index 355a3de..cc9e88f 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -28,10 +28,9 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.RpcResponseCallback;
-import org.apache.spark.network.client.StreamCallbackWithID;
import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.AbstractAuthRpcHandler;
import org.apache.spark.network.server.RpcHandler;
-import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.TransportConf;
@@ -43,7 +42,7 @@ import org.apache.spark.network.util.TransportConf;
* Note that the authentication process consists of multiple challenge-response pairs, each of
* which are individual RPCs.
*/
-public class SaslRpcHandler extends RpcHandler {
+public class SaslRpcHandler extends AbstractAuthRpcHandler {
private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
/** Transport configuration. */
@@ -52,37 +51,28 @@ public class SaslRpcHandler extends RpcHandler {
/** The client channel. */
private final Channel channel;
- /** RpcHandler we will delegate to for authenticated connections. */
- private final RpcHandler delegate;
-
/** Class which provides secret keys which are shared by server and client on a per-app basis. */
private final SecretKeyHolder secretKeyHolder;
private SparkSaslServer saslServer;
- private boolean isComplete;
- private boolean isAuthenticated;
public SaslRpcHandler(
TransportConf conf,
Channel channel,
RpcHandler delegate,
SecretKeyHolder secretKeyHolder) {
+ super(delegate);
this.conf = conf;
this.channel = channel;
- this.delegate = delegate;
this.secretKeyHolder = secretKeyHolder;
this.saslServer = null;
- this.isComplete = false;
- this.isAuthenticated = false;
}
@Override
- public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
- if (isComplete) {
- // Authentication complete, delegate to base handler.
- delegate.receive(client, message, callback);
- return;
- }
+ public boolean doAuthChallenge(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
if (saslServer == null || !saslServer.isComplete()) {
ByteBuf nettyBuf = Unpooled.wrappedBuffer(message);
SaslMessage saslMessage;
@@ -118,43 +108,21 @@ public class SaslRpcHandler extends RpcHandler {
if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) {
logger.debug("SASL authentication successful for channel {}", client);
complete(true);
- return;
+ return true;
}
logger.debug("Enabling encryption for channel {}", client);
SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
complete(false);
- return;
+ return true;
}
- }
-
- @Override
- public void receive(TransportClient client, ByteBuffer message) {
- delegate.receive(client, message);
- }
-
- @Override
- public StreamCallbackWithID receiveStream(
- TransportClient client,
- ByteBuffer message,
- RpcResponseCallback callback) {
- return delegate.receiveStream(client, message, callback);
- }
-
- @Override
- public StreamManager getStreamManager() {
- return delegate.getStreamManager();
- }
-
- @Override
- public void channelActive(TransportClient client) {
- delegate.channelActive(client);
+ return false;
}
@Override
public void channelInactive(TransportClient client) {
try {
- delegate.channelInactive(client);
+ super.channelInactive(client);
} finally {
if (saslServer != null) {
saslServer.dispose();
@@ -162,11 +130,6 @@ public class SaslRpcHandler extends RpcHandler {
}
}
- @Override
- public void exceptionCaught(Throwable cause, TransportClient client) {
- delegate.exceptionCaught(cause, client);
- }
-
private void complete(boolean dispose) {
if (dispose) {
try {
@@ -177,7 +140,6 @@ public class SaslRpcHandler extends RpcHandler {
}
saslServer = null;
- isComplete = true;
}
}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java
new file mode 100644
index 0000000..92eb886
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.nio.ByteBuffer;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.StreamCallbackWithID;
+import org.apache.spark.network.client.TransportClient;
+
+/**
+ * RPC Handler which performs authentication, and when it's successful, delegates further
+ * calls to another RPC handler. The authentication handshake itself should be implemented
+ * by subclasses.
+ */
+public abstract class AbstractAuthRpcHandler extends RpcHandler {
+ /** RpcHandler we will delegate to for authenticated connections. */
+ private final RpcHandler delegate;
+
+ private boolean isAuthenticated;
+
+ protected AbstractAuthRpcHandler(RpcHandler delegate) {
+ this.delegate = delegate;
+ }
+
+ /**
+ * Responds to an authentication challenge.
+ *
+ * @return Whether the client is authenticated.
+ */
+ protected abstract boolean doAuthChallenge(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback);
+
+ @Override
+ public final void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
+ if (isAuthenticated) {
+ delegate.receive(client, message, callback);
+ } else {
+ isAuthenticated = doAuthChallenge(client, message, callback);
+ }
+ }
+
+ @Override
+ public final void receive(TransportClient client, ByteBuffer message) {
+ if (isAuthenticated) {
+ delegate.receive(client, message);
+ } else {
+ throw new SecurityException("Unauthenticated call to receive().");
+ }
+ }
+
+ @Override
+ public final StreamCallbackWithID receiveStream(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
+ if (isAuthenticated) {
+ return delegate.receiveStream(client, message, callback);
+ } else {
+ throw new SecurityException("Unauthenticated call to receiveStream().");
+ }
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return delegate.getStreamManager();
+ }
+
+ @Override
+ public void channelActive(TransportClient client) {
+ delegate.channelActive(client);
+ }
+
+ @Override
+ public void channelInactive(TransportClient client) {
+ delegate.channelInactive(client);
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause, TransportClient client) {
+ delegate.exceptionCaught(cause, client);
+ }
+
+ public boolean isAuthenticated() {
+ return isAuthenticated;
+ }
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
index 2f9dd62..a87a6aa 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
@@ -34,7 +34,6 @@ import org.apache.spark.network.TransportContext;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
-import org.apache.spark.network.sasl.SaslRpcHandler;
import org.apache.spark.network.sasl.SaslServerBootstrap;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.server.RpcHandler;
@@ -65,8 +64,7 @@ public class AuthIntegrationSuite {
ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
assertEquals("Pong", JavaUtils.bytesToString(reply));
- assertTrue(ctx.authRpcHandler.doDelegate);
- assertFalse(ctx.authRpcHandler.delegate instanceof SaslRpcHandler);
+ assertNull(ctx.authRpcHandler.saslHandler);
}
@Test
@@ -78,7 +76,7 @@ public class AuthIntegrationSuite {
ctx.createClient("client");
fail("Should have failed to create client.");
} catch (Exception e) {
- assertFalse(ctx.authRpcHandler.doDelegate);
+ assertFalse(ctx.authRpcHandler.isAuthenticated());
assertFalse(ctx.serverChannel.isActive());
}
}
@@ -91,6 +89,8 @@ public class AuthIntegrationSuite {
ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
assertEquals("Pong", JavaUtils.bytesToString(reply));
+ assertNotNull(ctx.authRpcHandler.saslHandler);
+ assertTrue(ctx.authRpcHandler.isAuthenticated());
}
@Test
@@ -120,7 +120,7 @@ public class AuthIntegrationSuite {
ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
fail("Should have failed unencrypted RPC.");
} catch (Exception e) {
- assertTrue(ctx.authRpcHandler.doDelegate);
+ assertTrue(ctx.authRpcHandler.isAuthenticated());
}
}
@@ -151,7 +151,7 @@ public class AuthIntegrationSuite {
ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
fail("Should have failed unencrypted RPC.");
} catch (Exception e) {
- assertTrue(ctx.authRpcHandler.doDelegate);
+ assertTrue(ctx.authRpcHandler.isAuthenticated());
assertTrue(e.getMessage() + " is not an expected error", e.getMessage().contains("DDDDD"));
// Verify we receive the complete error message
int messageStart = e.getMessage().indexOf("DDDDD");
diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index cf2d72f..ecaeec9 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -357,7 +357,8 @@ public class SparkSaslSuite {
public void testDelegates() throws Exception {
Method[] rpcHandlerMethods = RpcHandler.class.getDeclaredMethods();
for (Method m : rpcHandlerMethods) {
- SaslRpcHandler.class.getDeclaredMethod(m.getName(), m.getParameterTypes());
+ Method delegate = SaslRpcHandler.class.getMethod(m.getName(), m.getParameterTypes());
+ assertNotEquals(delegate.getDeclaringClass(), RpcHandler.class);
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org