You are viewing a plain text version of this content. The canonical link for it was provided as a hyperlink on the original archive page; that link was lost in this plain-text copy.
Posted to commits@spark.apache.org by sr...@apache.org on 2020/04/17 18:28:06 UTC

[spark] branch master updated: Apply appropriate RPC handler to receive, receiveStream when auth enabled

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 61b7d44  Apply appropriate RPC handler to receive, receiveStream when auth enabled
61b7d44 is described below

commit 61b7d446b37cecc45e6d274bbfdde3b745bf068f
Author: Sean Owen <sr...@gmail.com>
AuthorDate: Fri Apr 17 13:25:12 2020 -0500

    Apply appropriate RPC handler to receive, receiveStream when auth enabled
---
 .../spark/network/crypto/AuthRpcHandler.java       |  73 +++-----------
 .../apache/spark/network/sasl/SaslRpcHandler.java  |  60 +++---------
 .../network/server/AbstractAuthRpcHandler.java     | 107 +++++++++++++++++++++
 .../spark/network/crypto/AuthIntegrationSuite.java |  12 +--
 .../apache/spark/network/sasl/SparkSaslSuite.java  |   3 +-
 5 files changed, 142 insertions(+), 113 deletions(-)

diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
index 821cc7a..dd31c95 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
@@ -29,12 +29,11 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.client.RpcResponseCallback;
-import org.apache.spark.network.client.StreamCallbackWithID;
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.sasl.SecretKeyHolder;
 import org.apache.spark.network.sasl.SaslRpcHandler;
+import org.apache.spark.network.server.AbstractAuthRpcHandler;
 import org.apache.spark.network.server.RpcHandler;
-import org.apache.spark.network.server.StreamManager;
 import org.apache.spark.network.util.TransportConf;
 
 /**
@@ -46,7 +45,7 @@ import org.apache.spark.network.util.TransportConf;
  * The delegate will only receive messages if the given connection has been successfully
  * authenticated. A connection may be authenticated at most once.
  */
-class AuthRpcHandler extends RpcHandler {
+class AuthRpcHandler extends AbstractAuthRpcHandler {
   private static final Logger LOG = LoggerFactory.getLogger(AuthRpcHandler.class);
 
   /** Transport configuration. */
@@ -55,36 +54,31 @@ class AuthRpcHandler extends RpcHandler {
   /** The client channel. */
   private final Channel channel;
 
-  /**
-   * RpcHandler we will delegate to for authenticated connections. When falling back to SASL
-   * this will be replaced with the SASL RPC handler.
-   */
-  @VisibleForTesting
-  RpcHandler delegate;
-
   /** Class which provides secret keys which are shared by server and client on a per-app basis. */
   private final SecretKeyHolder secretKeyHolder;
 
-  /** Whether auth is done and future calls should be delegated. */
+  /** RPC handler for auth handshake when falling back to SASL auth. */
   @VisibleForTesting
-  boolean doDelegate;
+  SaslRpcHandler saslHandler;
 
   AuthRpcHandler(
       TransportConf conf,
       Channel channel,
       RpcHandler delegate,
       SecretKeyHolder secretKeyHolder) {
+    super(delegate);
     this.conf = conf;
     this.channel = channel;
-    this.delegate = delegate;
     this.secretKeyHolder = secretKeyHolder;
   }
 
   @Override
-  public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
-    if (doDelegate) {
-      delegate.receive(client, message, callback);
-      return;
+  protected boolean doAuthChallenge(
+      TransportClient client,
+      ByteBuffer message,
+      RpcResponseCallback callback) {
+    if (saslHandler != null) {
+      return saslHandler.doAuthChallenge(client, message, callback);
     }
 
     int position = message.position();
@@ -98,18 +92,17 @@ class AuthRpcHandler extends RpcHandler {
       if (conf.saslFallback()) {
         LOG.warn("Failed to parse new auth challenge, reverting to SASL for client {}.",
           channel.remoteAddress());
-        delegate = new SaslRpcHandler(conf, channel, delegate, secretKeyHolder);
+        saslHandler = new SaslRpcHandler(conf, channel, null, secretKeyHolder);
         message.position(position);
         message.limit(limit);
-        delegate.receive(client, message, callback);
-        doDelegate = true;
+        return saslHandler.doAuthChallenge(client, message, callback);
       } else {
         LOG.debug("Unexpected challenge message from client {}, closing channel.",
           channel.remoteAddress());
         callback.onFailure(new IllegalArgumentException("Unknown challenge message."));
         channel.close();
       }
-      return;
+      return false;
     }
 
     // Here we have the client challenge, so perform the new auth protocol and set up the channel.
@@ -131,7 +124,7 @@ class AuthRpcHandler extends RpcHandler {
       LOG.debug("Authentication failed for client {}, closing channel.", channel.remoteAddress());
       callback.onFailure(new IllegalArgumentException("Authentication failed."));
       channel.close();
-      return;
+      return false;
     } finally {
       if (engine != null) {
         try {
@@ -143,40 +136,6 @@ class AuthRpcHandler extends RpcHandler {
     }
 
     LOG.debug("Authorization successful for client {}.", channel.remoteAddress());
-    doDelegate = true;
-  }
-
-  @Override
-  public void receive(TransportClient client, ByteBuffer message) {
-    delegate.receive(client, message);
-  }
-
-  @Override
-  public StreamCallbackWithID receiveStream(
-      TransportClient client,
-      ByteBuffer message,
-      RpcResponseCallback callback) {
-    return delegate.receiveStream(client, message, callback);
+    return true;
   }
-
-  @Override
-  public StreamManager getStreamManager() {
-    return delegate.getStreamManager();
-  }
-
-  @Override
-  public void channelActive(TransportClient client) {
-    delegate.channelActive(client);
-  }
-
-  @Override
-  public void channelInactive(TransportClient client) {
-    delegate.channelInactive(client);
-  }
-
-  @Override
-  public void exceptionCaught(Throwable cause, TransportClient client) {
-    delegate.exceptionCaught(cause, client);
-  }
-
 }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index 355a3de..cc9e88f 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -28,10 +28,9 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.client.RpcResponseCallback;
-import org.apache.spark.network.client.StreamCallbackWithID;
 import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.AbstractAuthRpcHandler;
 import org.apache.spark.network.server.RpcHandler;
-import org.apache.spark.network.server.StreamManager;
 import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.network.util.TransportConf;
 
@@ -43,7 +42,7 @@ import org.apache.spark.network.util.TransportConf;
  * Note that the authentication process consists of multiple challenge-response pairs, each of
  * which are individual RPCs.
  */
-public class SaslRpcHandler extends RpcHandler {
+public class SaslRpcHandler extends AbstractAuthRpcHandler {
   private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
 
   /** Transport configuration. */
@@ -52,37 +51,28 @@ public class SaslRpcHandler extends RpcHandler {
   /** The client channel. */
   private final Channel channel;
 
-  /** RpcHandler we will delegate to for authenticated connections. */
-  private final RpcHandler delegate;
-
   /** Class which provides secret keys which are shared by server and client on a per-app basis. */
   private final SecretKeyHolder secretKeyHolder;
 
   private SparkSaslServer saslServer;
-  private boolean isComplete;
-  private boolean isAuthenticated;
 
   public SaslRpcHandler(
       TransportConf conf,
       Channel channel,
       RpcHandler delegate,
       SecretKeyHolder secretKeyHolder) {
+    super(delegate);
     this.conf = conf;
     this.channel = channel;
-    this.delegate = delegate;
     this.secretKeyHolder = secretKeyHolder;
     this.saslServer = null;
-    this.isComplete = false;
-    this.isAuthenticated = false;
   }
 
   @Override
-  public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
-    if (isComplete) {
-      // Authentication complete, delegate to base handler.
-      delegate.receive(client, message, callback);
-      return;
-    }
+  public boolean doAuthChallenge(
+      TransportClient client,
+      ByteBuffer message,
+      RpcResponseCallback callback) {
     if (saslServer == null || !saslServer.isComplete()) {
       ByteBuf nettyBuf = Unpooled.wrappedBuffer(message);
       SaslMessage saslMessage;
@@ -118,43 +108,21 @@ public class SaslRpcHandler extends RpcHandler {
       if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) {
         logger.debug("SASL authentication successful for channel {}", client);
         complete(true);
-        return;
+        return true;
       }
 
       logger.debug("Enabling encryption for channel {}", client);
       SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
       complete(false);
-      return;
+      return true;
     }
-  }
-
-  @Override
-  public void receive(TransportClient client, ByteBuffer message) {
-    delegate.receive(client, message);
-  }
-
-  @Override
-  public StreamCallbackWithID receiveStream(
-      TransportClient client,
-      ByteBuffer message,
-      RpcResponseCallback callback) {
-    return delegate.receiveStream(client, message, callback);
-  }
-
-  @Override
-  public StreamManager getStreamManager() {
-    return delegate.getStreamManager();
-  }
-
-  @Override
-  public void channelActive(TransportClient client) {
-    delegate.channelActive(client);
+    return false;
   }
 
   @Override
   public void channelInactive(TransportClient client) {
     try {
-      delegate.channelInactive(client);
+      super.channelInactive(client);
     } finally {
       if (saslServer != null) {
         saslServer.dispose();
@@ -162,11 +130,6 @@ public class SaslRpcHandler extends RpcHandler {
     }
   }
 
-  @Override
-  public void exceptionCaught(Throwable cause, TransportClient client) {
-    delegate.exceptionCaught(cause, client);
-  }
-
   private void complete(boolean dispose) {
     if (dispose) {
       try {
@@ -177,7 +140,6 @@ public class SaslRpcHandler extends RpcHandler {
     }
 
     saslServer = null;
-    isComplete = true;
   }
 
 }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java
new file mode 100644
index 0000000..92eb886
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.nio.ByteBuffer;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.StreamCallbackWithID;
+import org.apache.spark.network.client.TransportClient;
+
+/**
+ * RPC Handler which performs authentication, and when it's successful, delegates further
+ * calls to another RPC handler. The authentication handshake itself should be implemented
+ * by subclasses.
+ */
+public abstract class AbstractAuthRpcHandler extends RpcHandler {
+  /** RpcHandler we will delegate to for authenticated connections. */
+  private final RpcHandler delegate;
+
+  private boolean isAuthenticated;
+
+  protected AbstractAuthRpcHandler(RpcHandler delegate) {
+    this.delegate = delegate;
+  }
+
+  /**
+   * Responds to an authentication challenge.
+   *
+   * @return Whether the client is authenticated.
+   */
+  protected abstract boolean doAuthChallenge(
+      TransportClient client,
+      ByteBuffer message,
+      RpcResponseCallback callback);
+
+  @Override
+  public final void receive(
+      TransportClient client,
+      ByteBuffer message,
+      RpcResponseCallback callback) {
+    if (isAuthenticated) {
+      delegate.receive(client, message, callback);
+    } else {
+      isAuthenticated = doAuthChallenge(client, message, callback);
+    }
+  }
+
+  @Override
+  public final void receive(TransportClient client, ByteBuffer message) {
+    if (isAuthenticated) {
+      delegate.receive(client, message);
+    } else {
+      throw new SecurityException("Unauthenticated call to receive().");
+    }
+  }
+
+  @Override
+  public final StreamCallbackWithID receiveStream(
+      TransportClient client,
+      ByteBuffer message,
+      RpcResponseCallback callback) {
+    if (isAuthenticated) {
+      return delegate.receiveStream(client, message, callback);
+    } else {
+      throw new SecurityException("Unauthenticated call to receiveStream().");
+    }
+  }
+
+  @Override
+  public StreamManager getStreamManager() {
+    return delegate.getStreamManager();
+  }
+
+  @Override
+  public void channelActive(TransportClient client) {
+    delegate.channelActive(client);
+  }
+
+  @Override
+  public void channelInactive(TransportClient client) {
+    delegate.channelInactive(client);
+  }
+
+  @Override
+  public void exceptionCaught(Throwable cause, TransportClient client) {
+    delegate.exceptionCaught(cause, client);
+  }
+
+  public boolean isAuthenticated() {
+    return isAuthenticated;
+  }
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
index 2f9dd62..a87a6aa 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
@@ -34,7 +34,6 @@ import org.apache.spark.network.TransportContext;
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.client.TransportClientBootstrap;
-import org.apache.spark.network.sasl.SaslRpcHandler;
 import org.apache.spark.network.sasl.SaslServerBootstrap;
 import org.apache.spark.network.sasl.SecretKeyHolder;
 import org.apache.spark.network.server.RpcHandler;
@@ -65,8 +64,7 @@ public class AuthIntegrationSuite {
 
     ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
     assertEquals("Pong", JavaUtils.bytesToString(reply));
-    assertTrue(ctx.authRpcHandler.doDelegate);
-    assertFalse(ctx.authRpcHandler.delegate instanceof SaslRpcHandler);
+    assertNull(ctx.authRpcHandler.saslHandler);
   }
 
   @Test
@@ -78,7 +76,7 @@ public class AuthIntegrationSuite {
       ctx.createClient("client");
       fail("Should have failed to create client.");
     } catch (Exception e) {
-      assertFalse(ctx.authRpcHandler.doDelegate);
+      assertFalse(ctx.authRpcHandler.isAuthenticated());
       assertFalse(ctx.serverChannel.isActive());
     }
   }
@@ -91,6 +89,8 @@ public class AuthIntegrationSuite {
 
     ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
     assertEquals("Pong", JavaUtils.bytesToString(reply));
+    assertNotNull(ctx.authRpcHandler.saslHandler);
+    assertTrue(ctx.authRpcHandler.isAuthenticated());
   }
 
   @Test
@@ -120,7 +120,7 @@ public class AuthIntegrationSuite {
       ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
       fail("Should have failed unencrypted RPC.");
     } catch (Exception e) {
-      assertTrue(ctx.authRpcHandler.doDelegate);
+      assertTrue(ctx.authRpcHandler.isAuthenticated());
     }
   }
 
@@ -151,7 +151,7 @@ public class AuthIntegrationSuite {
       ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
       fail("Should have failed unencrypted RPC.");
     } catch (Exception e) {
-      assertTrue(ctx.authRpcHandler.doDelegate);
+      assertTrue(ctx.authRpcHandler.isAuthenticated());
       assertTrue(e.getMessage() + " is not an expected error", e.getMessage().contains("DDDDD"));
       // Verify we receive the complete error message
       int messageStart = e.getMessage().indexOf("DDDDD");
diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index cf2d72f..ecaeec9 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -357,7 +357,8 @@ public class SparkSaslSuite {
   public void testDelegates() throws Exception {
     Method[] rpcHandlerMethods = RpcHandler.class.getDeclaredMethods();
     for (Method m : rpcHandlerMethods) {
-      SaslRpcHandler.class.getDeclaredMethod(m.getName(), m.getParameterTypes());
+      Method delegate = SaslRpcHandler.class.getMethod(m.getName(), m.getParameterTypes());
+      assertNotEquals(delegate.getDeclaringClass(), RpcHandler.class);
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org