You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@cassandra.apache.org by yc...@apache.org on 2022/06/13 20:44:28 UTC

[cassandra] branch trunk updated: Adding support to perform certificate based internode authentication

This is an automated email from the ASF dual-hosted git repository.

ycai pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 557b8e9982 Adding support to perform certificate based internode authentication
557b8e9982 is described below

commit 557b8e9982ad0964191abde810ef5c77a536f70a
Author: Jyothsna Konisa <jk...@apple.com>
AuthorDate: Mon Jun 13 11:05:22 2022 -0700

    Adding support to perform certificate based internode authentication
    
    patch by Jyothsna Konisa; reviewed by Jon Meredith, Yifan Cai for CASSANDRA-17661
---
 CHANGES.txt                                        |   1 +
 .../auth/AllowAllInternodeAuthenticator.java       |   4 +-
 .../cassandra/auth/IInternodeAuthenticator.java    |  50 +++-
 .../cassandra/net/InboundConnectionInitiator.java  | 100 ++++---
 .../cassandra/net/InternodeConnectionUtils.java    |  83 ++++++
 .../org/apache/cassandra/net/MessagingService.java |   3 +
 .../cassandra/net/OutboundConnectionInitiator.java |  60 ++++-
 .../apache/cassandra/service/StorageService.java   |   2 +-
 .../test/InternodeEncryptionEnforcementTest.java   | 286 +++++++++++++++++++++
 test/unit/org/apache/cassandra/SchemaLoader.java   |   1 +
 .../apache/cassandra/net/MessagingServiceTest.java |  91 +++++--
 11 files changed, 616 insertions(+), 65 deletions(-)

diff --git a/CHANGES.txt b/CHANGES.txt
index 57733a5438..8e8305aed0 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 4.2
+ * Adding support to perform certificate based internode authentication (CASSANDRA-17661)
  * Option to disable CDC writes of repaired data (CASSANDRA-17666)
  * When a node is bootstrapping it gets the whole gossip state but applies in random order causing some cases where StorageService will fail causing an instance to not show up in TokenMetadata (CASSANDRA-17676)
  * Add CQLSH command SHOW REPLICAS (CASSANDRA-17577)
diff --git a/src/java/org/apache/cassandra/auth/AllowAllInternodeAuthenticator.java b/src/java/org/apache/cassandra/auth/AllowAllInternodeAuthenticator.java
index d0d2d745d7..ac62bfae00 100644
--- a/src/java/org/apache/cassandra/auth/AllowAllInternodeAuthenticator.java
+++ b/src/java/org/apache/cassandra/auth/AllowAllInternodeAuthenticator.java
@@ -20,12 +20,17 @@
 package org.apache.cassandra.auth;
 
 import java.net.InetAddress;
+import java.security.cert.Certificate;
 
 import org.apache.cassandra.exceptions.ConfigurationException;
 
 public class AllowAllInternodeAuthenticator implements IInternodeAuthenticator
 {
-    public boolean authenticate(InetAddress remoteAddress, int remotePort)
+    /**
+     * Accepts every peer unconditionally: the address, port, certificates and connection direction are ignored.
+     */
+    public boolean authenticate(InetAddress remoteAddress, int remotePort,
+                                Certificate[] certificates, InternodeConnectionDirection connectionType)
     {
         return true;
     }
diff --git a/src/java/org/apache/cassandra/auth/IInternodeAuthenticator.java b/src/java/org/apache/cassandra/auth/IInternodeAuthenticator.java
index 8e09b9035f..02745fe925 100644
--- a/src/java/org/apache/cassandra/auth/IInternodeAuthenticator.java
+++ b/src/java/org/apache/cassandra/auth/IInternodeAuthenticator.java
@@ -20,6 +20,7 @@
 package org.apache.cassandra.auth;
 
 import java.net.InetAddress;
+import java.security.cert.Certificate;
 
 import org.apache.cassandra.exceptions.ConfigurationException;
 
@@ -33,7 +34,37 @@ public interface IInternodeAuthenticator
      * @param remotePort port of the connecting node.
      * @return true if the connection should be accepted, false otherwise.
+     * @deprecated override {@link #authenticate(InetAddress, int, Certificate[], InternodeConnectionDirection)}
+     * instead. Unless overridden, this method returns false, so an implementation overriding neither rejects every peer.
      */
-    boolean authenticate(InetAddress remoteAddress, int remotePort);
+    @Deprecated
+    default boolean authenticate(InetAddress remoteAddress, int remotePort)
+    {
+        return false;
+    }
+
+    /**
+     * Decides whether an internode connection with a peer should be accepted.
+     * If this method returns false, the connection will be closed immediately.
+     * <p>
+     * The default implementation delegates to {@link #authenticate(InetAddress, int)}.
+     * <p>
+     * 1. For IP based authentication, implementations may ignore the certificates and connectionType parameters of
+     * this method.
+     * 2. For certificate based authentication such as mTLS, the server's identity for outbound connections is
+     * verified during the SSL handshake against the trusted root certificates in the truststore. In such cases this
+     * method may be overridden to return true when connectionType is OUTBOUND.
+     *
+     * @param remoteAddress  ip address of the peer.
+     * @param remotePort     port of the peer.
+     * @param certificates   peer certificates, or null if none are available (for example, on an unencrypted connection).
+     * @param connectionType whether this node accepted (INBOUND) or initiated (OUTBOUND) the connection.
+     * @return true if the connection should be accepted, false otherwise.
+     */
+    default boolean authenticate(InetAddress remoteAddress, int remotePort,
+                                 Certificate[] certificates, InternodeConnectionDirection connectionType)
+    {
+        return authenticate(remoteAddress, remotePort);
+    }
 
     /**
      * Validates configuration of IInternodeAuthenticator implementation (if configurable).
@@ -41,4 +72,23 @@ public interface IInternodeAuthenticator
      * @throws ConfigurationException when there is a configuration error.
      */
     void validateConfiguration() throws ConfigurationException;
+
+    /**
+     * Setup is called once upon system startup to initialize the IInternodeAuthenticator.
+     * <p>
+     * For example, use this method to create any required keyspaces/column families. The default does nothing.
+     */
+    default void setupInternode()
+    {
+
+    }
+
+    /**
+     * Direction of an internode connection as seen from this node.
+     */
+    enum InternodeConnectionDirection
+    {
+        INBOUND,
+        OUTBOUND
+    }
 }
diff --git a/src/java/org/apache/cassandra/net/InboundConnectionInitiator.java b/src/java/org/apache/cassandra/net/InboundConnectionInitiator.java
index c5ed064259..f3dc28a307 100644
--- a/src/java/org/apache/cassandra/net/InboundConnectionInitiator.java
+++ b/src/java/org/apache/cassandra/net/InboundConnectionInitiator.java
@@ -20,6 +20,7 @@ package org.apache.cassandra.net;
 import java.io.IOException;
 import java.net.InetSocketAddress;
 import java.net.SocketAddress;
+import java.security.cert.Certificate;
 import java.util.List;
 import java.util.NoSuchElementException;
 import java.util.concurrent.Future;
@@ -46,6 +47,7 @@ import io.netty.handler.logging.LogLevel;
 import io.netty.handler.logging.LoggingHandler;
 import io.netty.handler.ssl.SslContext;
 import io.netty.handler.ssl.SslHandler;
+import org.apache.cassandra.auth.IInternodeAuthenticator;
 import org.apache.cassandra.config.DatabaseDescriptor;
 import org.apache.cassandra.config.EncryptionOptions;
 import org.apache.cassandra.exceptions.ConfigurationException;
@@ -60,7 +62,11 @@ import org.apache.cassandra.utils.memory.BufferPools;
 
 import static java.lang.Math.*;
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static org.apache.cassandra.auth.IInternodeAuthenticator.InternodeConnectionDirection.INBOUND;
 import static org.apache.cassandra.concurrent.ExecutorFactory.Global.executorFactory;
+import static org.apache.cassandra.net.InternodeConnectionUtils.DISCARD_HANDLER_NAME;
+import static org.apache.cassandra.net.InternodeConnectionUtils.SSL_HANDLER_NAME;
+import static org.apache.cassandra.net.InternodeConnectionUtils.certificates;
 import static org.apache.cassandra.net.MessagingService.*;
 import static org.apache.cassandra.net.SocketFactory.WIRETRACE;
 import static org.apache.cassandra.net.SocketFactory.newSslHandler;
@@ -102,7 +108,7 @@ public class InboundConnectionInitiator
 
             pipelineInjector.accept(pipeline);
 
-            // order of handlers: ssl -> logger -> handshakeHandler
+            // order of handlers: ssl -> client-authentication -> logger -> handshakeHandler
             // For either unencrypted or transitional modes, allow Ssl optionally.
             switch(settings.encryption.tlsEncryptionPolicy())
             {
@@ -111,14 +117,17 @@ public class InboundConnectionInitiator
                     pipeline.addAfter(PIPELINE_INTERNODE_ERROR_EXCLUSIONS, "rejectssl", new RejectSslHandler());
                     break;
                 case OPTIONAL:
-                    pipeline.addAfter(PIPELINE_INTERNODE_ERROR_EXCLUSIONS, "ssl", new OptionalSslHandler(settings.encryption));
+                    pipeline.addAfter(PIPELINE_INTERNODE_ERROR_EXCLUSIONS, SSL_HANDLER_NAME, new OptionalSslHandler(settings.encryption));
                     break;
                 case ENCRYPTED:
                     SslHandler sslHandler = getSslHandler("creating", channel, settings.encryption);
-                    pipeline.addAfter(PIPELINE_INTERNODE_ERROR_EXCLUSIONS, "ssl", sslHandler);
+                    pipeline.addAfter(PIPELINE_INTERNODE_ERROR_EXCLUSIONS, SSL_HANDLER_NAME, sslHandler);
                     break;
             }
 
+            // Handler that authenticates the connecting peer before its data reaches the handshake handler
+            pipeline.addLast("client-authentication", new ClientAuthenticationHandler(settings.authenticator));
+
             if (WIRETRACE)
                 pipeline.addLast("logger", new LoggingHandler(LogLevel.INFO));
 
@@ -198,6 +207,61 @@ public class InboundConnectionInitiator
         return bind(new Initializer(settings, channelGroup, pipelineInjector));
     }
 
+    /**
+     * Authenticates the peer of an inbound internode connection. It sits after the SSL handler (if any) and runs
+     * on the first inbound bytes, before the messaging handshake starts; on success it removes itself.
+     */
+    private static class ClientAuthenticationHandler extends ByteToMessageDecoder
+    {
+        private final IInternodeAuthenticator authenticator;
+
+        public ClientAuthenticationHandler(IInternodeAuthenticator authenticator)
+        {
+            this.authenticator = authenticator;
+        }
+
+        @Override
+        protected void decode(ChannelHandlerContext channelHandlerContext, ByteBuf byteBuf, List<Object> list) throws Exception
+        {
+            // Extract the peer certificates from the SSL handler (the handler named "ssl"), if there is one.
+            final Certificate[] certificates = certificates(channelHandlerContext.channel());
+            if (!authenticate(channelHandlerContext.channel().remoteAddress(), certificates))
+            {
+                // INFO rather than ERROR: anything that can reach the inbound port can trigger this (see authenticate).
+                logger.info("Unable to authenticate peer {} for internode authentication", channelHandlerContext.channel());
+                // To release all the pending buffered data, replace authentication handler with discard handler.
+                // This prevents pending inbound data from being fired through the rest of the pipeline.
+                channelHandlerContext.pipeline().replace(this, DISCARD_HANDLER_NAME, new InternodeConnectionUtils.ByteBufDiscardHandler());
+                channelHandlerContext.pipeline().close();
+            }
+            else
+            {
+                channelHandlerContext.pipeline().remove(this);
+            }
+        }
+
+        private boolean authenticate(SocketAddress socketAddress, final Certificate[] certificates) throws IOException
+        {
+            if (socketAddress.getClass().getSimpleName().equals("EmbeddedSocketAddress"))
+                return true;
+
+            if (!(socketAddress instanceof InetSocketAddress))
+                throw new IOException(String.format("Unexpected SocketAddress type: %s, %s", socketAddress.getClass(), socketAddress));
+
+            InetSocketAddress addr = (InetSocketAddress) socketAddress;
+            if (!authenticator.authenticate(addr.getAddress(), addr.getPort(), certificates, INBOUND))
+            {
+                // Log at info level as anything that can reach the inbound port could hit this
+                // and trigger a lot of noise.  Failed outbound connections to known cluster endpoints
+                // still fail with an ERROR message and exception to alert operators that aren't watching logs closely.
+                logger.info("Authenticate rejected inbound internode connection from {}", addr);
+                return false;
+            }
+            return true;
+        }
+
+    }
+
     /**
      * 'Server-side' component that negotiates the internode handshake when establishing a new connection.
      * This handler will be the first in the netty channel for each incoming connection (secure socket (TLS) notwithstanding),
@@ -223,8 +287,8 @@ public class InboundConnectionInitiator
         }
 
         /**
-         * On registration, immediately schedule a timeout to kill this connection if it does not handshake promptly,
-         * and authenticate the remote address.
+         * On registration, immediately schedule a timeout to kill this connection if it does not handshake promptly.
+         * The remote peer is authenticated earlier in the pipeline, by ClientAuthenticationHandler.
          */
         public void handlerAdded(ChannelHandlerContext ctx) throws Exception
         {
@@ -232,31 +296,6 @@ public class InboundConnectionInitiator
                 logger.error("Timeout handshaking with {} (on {})", SocketFactory.addressId(initiate.from, (InetSocketAddress) ctx.channel().remoteAddress()), settings.bindAddress);
                 failHandshake(ctx);
             }, HandshakeProtocol.TIMEOUT_MILLIS, MILLISECONDS);
-
-            if (!authenticate(ctx.channel().remoteAddress()))
-            {
-                failHandshake(ctx);
-            }
-        }
-
-        private boolean authenticate(SocketAddress socketAddress) throws IOException
-        {
-            if (socketAddress.getClass().getSimpleName().equals("EmbeddedSocketAddress"))
-                return true;
-
-            if (!(socketAddress instanceof InetSocketAddress))
-                throw new IOException(String.format("Unexpected SocketAddress type: %s, %s", socketAddress.getClass(), socketAddress));
-
-            InetSocketAddress addr = (InetSocketAddress)socketAddress;
-            if (!settings.authenticate(addr.getAddress(), addr.getPort()))
-            {
-                // Log at info level as anything that can reach the inbound port could hit this
-                // and trigger a log of noise.  Failed outbound connections to known cluster endpoints
-                // still fail with an ERROR message and exception to alert operators that aren't watching logs closely.
-                logger.info("Authenticate rejected inbound internode connection from {}", addr);
-                return false;
-            }
-            return true;
         }
 
         @Override
@@ -562,7 +601,8 @@ public class InboundConnectionInitiator
             {
                 // Connection uses SSL/TLS, replace the detection handler with a SslHandler and so use encryption.
                 SslHandler sslHandler = getSslHandler("replacing optional", ctx.channel(), encryptionOptions);
-                ctx.pipeline().replace(this, "ssl", sslHandler);
+                // Keep SSL_HANDLER_NAME: InternodeConnectionUtils.certificates() looks the SSL handler up by that name.
+                ctx.pipeline().replace(this, SSL_HANDLER_NAME, sslHandler);
             }
             else
             {
diff --git a/src/java/org/apache/cassandra/net/InternodeConnectionUtils.java b/src/java/org/apache/cassandra/net/InternodeConnectionUtils.java
new file mode 100644
index 0000000000..39a087960b
--- /dev/null
+++ b/src/java/org/apache/cassandra/net/InternodeConnectionUtils.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net;
+
+import java.security.cert.Certificate;
+import javax.net.ssl.SSLPeerUnverifiedException;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.handler.ssl.SslHandler;
+
+/**
+ * Helpers shared by internode connection pipelines: the pipeline names of the SSL and discard handlers,
+ * extraction of peer certificates, and a handler that drops data received after a failed authentication.
+ */
+final class InternodeConnectionUtils
+{
+    /** Pipeline name of the SSL handler, which {@link #certificates(Channel)} looks up. */
+    public static final String SSL_HANDLER_NAME = "ssl";
+    /** Pipeline name of the {@link ByteBufDiscardHandler} installed when authentication fails. */
+    public static final String DISCARD_HANDLER_NAME = "discard";
+    private static final Logger logger = LoggerFactory.getLogger(InternodeConnectionUtils.class);
+
+    private InternodeConnectionUtils()
+    {
+        // static helpers only; not meant to be instantiated
+    }
+
+    /**
+     * Returns the certificates presented by the remote peer of {@code channel}.
+     * <p>
+     * The handler registered as {@link #SSL_HANDLER_NAME} is looked up by name and used only if it really is an
+     * {@link SslHandler}: with optional encryption that name is initially held by a detection handler.
+     *
+     * @param channel the internode channel whose peer is being authenticated
+     * @return the peer's certificates, or {@code null} if the channel has no SslHandler or the peer's
+     *         identity could not be verified
+     */
+    public static Certificate[] certificates(Channel channel)
+    {
+        final Object handler = channel.pipeline().get(SSL_HANDLER_NAME);
+        Certificate[] certificates = null;
+        if (handler instanceof SslHandler)
+        {
+            try
+            {
+                certificates = ((SslHandler) handler).engine()
+                                                     .getSession()
+                                                     .getPeerCertificates();
+            }
+            catch (SSLPeerUnverifiedException e)
+            {
+                // Unverified peer: note it at debug level and let the authenticator decide with null certificates.
+                logger.debug("Failed to get peer certificates for peer {}", channel.remoteAddress(), e);
+            }
+        }
+        return certificates;
+    }
+
+    /**
+     * Silently releases received {@link ByteBuf}s. When internode authentication fails the channel is closed,
+     * but pending buffered data may still be fired through the pipeline. To avoid handling unauthenticated data
+     * in the following handlers, the authentication handler is replaced with this handler. Messages that are
+     * not {@code ByteBuf}s are passed on to the next handler unchanged.
+     */
+    public static class ByteBufDiscardHandler extends ChannelInboundHandlerAdapter
+    {
+        @Override
+        public void channelRead(ChannelHandlerContext ctx, Object msg)
+        {
+            if (msg instanceof ByteBuf)
+            {
+                ((ByteBuf) msg).release();
+            }
+            else
+            {
+                ctx.fireChannelRead(msg);
+            }
+        }
+    }
+}
diff --git a/src/java/org/apache/cassandra/net/MessagingService.java b/src/java/org/apache/cassandra/net/MessagingService.java
index ea019fd8fb..d968a0ce2b 100644
--- a/src/java/org/apache/cassandra/net/MessagingService.java
+++ b/src/java/org/apache/cassandra/net/MessagingService.java
@@ -475,7 +475,11 @@ public class MessagingService extends MessagingServiceMBeanImpl
     {
         OutboundConnections pool = channelManagers.get(to);
         if (pool != null)
+        {
+            // Callers use this so that connections to this peer have to re-authenticate when they reconnect.
             pool.interrupt();
+            logger.info("Interrupted outbound connections to {}", to);
+        }
     }
 
     /**
diff --git a/src/java/org/apache/cassandra/net/OutboundConnectionInitiator.java b/src/java/org/apache/cassandra/net/OutboundConnectionInitiator.java
index a187068ce7..9565f54846 100644
--- a/src/java/org/apache/cassandra/net/OutboundConnectionInitiator.java
+++ b/src/java/org/apache/cassandra/net/OutboundConnectionInitiator.java
@@ -21,13 +21,16 @@ package org.apache.cassandra.net;
 import java.io.IOException;
 import java.net.InetSocketAddress;
 import java.nio.channels.ClosedChannelException;
+import java.security.cert.Certificate;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicBoolean;
 
+import com.google.common.annotations.VisibleForTesting;
+
 import io.netty.util.concurrent.Future; //checkstyle: permit this import
 import io.netty.util.concurrent.Promise; //checkstyle: permit this import
 import org.apache.cassandra.utils.concurrent.AsyncPromise;
-import org.apache.cassandra.utils.concurrent.ImmediateFuture;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -59,6 +62,10 @@ import org.apache.cassandra.utils.JVMStabilityInspector;
 import org.apache.cassandra.utils.memory.BufferPools;
 
 import static java.util.concurrent.TimeUnit.*;
+import static org.apache.cassandra.auth.IInternodeAuthenticator.InternodeConnectionDirection.OUTBOUND;
+import static org.apache.cassandra.net.InternodeConnectionUtils.DISCARD_HANDLER_NAME;
+import static org.apache.cassandra.net.InternodeConnectionUtils.SSL_HANDLER_NAME;
+import static org.apache.cassandra.net.InternodeConnectionUtils.certificates;
 import static org.apache.cassandra.net.MessagingService.VERSION_40;
 import static org.apache.cassandra.net.HandshakeProtocol.*;
 import static org.apache.cassandra.net.ConnectionType.STREAMING;
@@ -130,13 +137,8 @@ public class OutboundConnectionInitiator<SuccessType extends OutboundConnectionI
         if (logger.isTraceEnabled())
             logger.trace("creating outbound bootstrap to {}, requestVersion: {}", settings, requestMessagingVersion);
 
-        if (!settings.authenticate())
-        {
-            // interrupt other connections, so they must attempt to re-authenticate
-            MessagingService.instance().interruptOutbound(settings.to);
-            return ImmediateFuture.failure(new IOException("authentication failed to " + settings.connectToId()));
-        }
-
+        // The server is authenticated after connecting, by ServerAuthenticationHandler in the channel pipeline.
+
         // this is a bit ugly, but is the easiest way to ensure that if we timeout we can propagate a suitable error message
         // and still guarantee that, if on timing out we raced with success, the successfully created channel is handled
         AtomicBoolean timedout = new AtomicBoolean();
@@ -198,7 +200,7 @@ public class OutboundConnectionInitiator<SuccessType extends OutboundConnectionI
         {
             ChannelPipeline pipeline = channel.pipeline();
 
-            // order of handlers: ssl -> logger -> handshakeHandler
+            // order of handlers: ssl -> server-authentication -> logger -> handshakeHandler
             if (settings.withEncryption())
             {
                 // check if we should actually encrypt this connection
@@ -209,8 +211,10 @@ public class OutboundConnectionInitiator<SuccessType extends OutboundConnectionI
                 InetSocketAddress peer = settings.encryption.require_endpoint_verification ? new InetSocketAddress(address.getAddress(), address.getPort()) : null;
                 SslHandler sslHandler = newSslHandler(channel, sslContext, peer);
                 logger.trace("creating outbound netty SslContext: context={}, engine={}", sslContext.getClass().getName(), sslHandler.engine().getClass().getName());
-                pipeline.addFirst("ssl", sslHandler);
+                pipeline.addFirst(SSL_HANDLER_NAME, sslHandler);
             }
+            // Authenticate the server on its first inbound bytes, before they reach the handshake handler.
+            pipeline.addLast("server-authentication", new ServerAuthenticationHandler(settings));
 
             if (WIRETRACE)
                 pipeline.addLast("logger", new LoggingHandler(LogLevel.INFO));
@@ -220,6 +224,45 @@ public class OutboundConnectionInitiator<SuccessType extends OutboundConnectionI
 
     }
 
+    /**
+     * Authenticates the server before an outbound connection is established. For SSL-based connections the server's
+     * identity is already verified during the SSL handshake using the root certificates in the truststore, so
+     * implementations of IInternodeAuthenticator may choose to skip outbound authentication, or to perform any
+     * additional checks that outbound connections require.
+     */
+    @VisibleForTesting
+    static class ServerAuthenticationHandler extends ByteToMessageDecoder
+    {
+        final OutboundConnectionSettings settings;
+
+        ServerAuthenticationHandler(OutboundConnectionSettings settings)
+        {
+            this.settings = settings;
+        }
+
+        @Override
+        protected void decode(ChannelHandlerContext channelHandlerContext, ByteBuf byteBuf, List<Object> list) throws Exception
+        {
+            // Extract the peer certificates from the SSL handler (the handler named "ssl"), if there is one.
+            final Certificate[] certificates = certificates(channelHandlerContext.channel());
+            if (!settings.authenticator.authenticate(settings.to.getAddress(), settings.to.getPort(), certificates, OUTBOUND))
+            {
+                // interrupt other connections, so they must attempt to re-authenticate
+                MessagingService.instance().interruptOutbound(settings.to);
+                logger.error("authentication failed to {}", settings.connectToId());
+
+                // To release all the pending buffered data, replace authentication handler with discard handler.
+                // This prevents pending inbound data from being fired through the rest of the pipeline.
+                channelHandlerContext.pipeline().replace(this, DISCARD_HANDLER_NAME, new InternodeConnectionUtils.ByteBufDiscardHandler());
+                channelHandlerContext.pipeline().close();
+            }
+            else
+            {
+                channelHandlerContext.pipeline().remove(this);
+            }
+        }
+    }
+
     private class Handler extends ByteToMessageDecoder
     {
         /**
diff --git a/src/java/org/apache/cassandra/service/StorageService.java b/src/java/org/apache/cassandra/service/StorageService.java
index 34fb6acb17..0ba43a7373 100644
--- a/src/java/org/apache/cassandra/service/StorageService.java
+++ b/src/java/org/apache/cassandra/service/StorageService.java
@@ -32,7 +32,6 @@ import java.util.Map.Entry;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
 import java.util.concurrent.ThreadLocalRandom;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
@@ -1238,6 +1237,8 @@ public class StorageService extends NotificationBroadcasterSupport implements IE
 
             DatabaseDescriptor.getRoleManager().setup();
             DatabaseDescriptor.getAuthenticator().setup();
+            // Let the configured internode authenticator initialise itself alongside the client auth components.
+            DatabaseDescriptor.getInternodeAuthenticator().setupInternode();
             DatabaseDescriptor.getAuthorizer().setup();
             DatabaseDescriptor.getNetworkAuthorizer().setup();
             AuthCacheService.initializeAndRegisterCaches();
diff --git a/test/distributed/org/apache/cassandra/distributed/test/InternodeEncryptionEnforcementTest.java b/test/distributed/org/apache/cassandra/distributed/test/InternodeEncryptionEnforcementTest.java
index 969f372456..157aede9b7 100644
--- a/test/distributed/org/apache/cassandra/distributed/test/InternodeEncryptionEnforcementTest.java
+++ b/test/distributed/org/apache/cassandra/distributed/test/InternodeEncryptionEnforcementTest.java
@@ -17,15 +17,29 @@
  */
 package org.apache.cassandra.distributed.test;
 
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.InetAddress;
+import java.security.KeyStore;
+import java.security.cert.Certificate;
 import java.util.HashMap;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import com.google.common.collect.ImmutableMap;
 import org.junit.Test;
 
+import org.apache.cassandra.auth.AllowAllInternodeAuthenticator;
+import org.apache.cassandra.auth.IInternodeAuthenticator;
+import org.apache.cassandra.config.DatabaseDescriptor;
 import org.apache.cassandra.distributed.Cluster;
 import org.apache.cassandra.distributed.api.Feature;
 import org.apache.cassandra.distributed.api.IIsolatedExecutor.SerializableRunnable;
 import org.apache.cassandra.distributed.shared.NetworkTopology;
+import org.apache.cassandra.exceptions.ConfigurationException;
 import org.apache.cassandra.net.InboundMessageHandlers;
 import org.apache.cassandra.net.MessagingService;
 import org.apache.cassandra.net.OutboundConnections;
@@ -40,6 +54,131 @@ import static org.junit.Assert.fail;
 
 public final class InternodeEncryptionEnforcementTest extends TestBaseImpl
 {
+
+    @Test
+    public void testInboundConnectionsAreRejectedWhenAuthFails() throws IOException, TimeoutException
+    {
+        Cluster.Builder builder = createCluster(RejectInboundConnections.class);
+
+        final ExecutorService executorService = Executors.newSingleThreadExecutor();
+        try (Cluster cluster = builder.start())
+        {
+            executorService.submit(() -> openConnections(cluster));
+
+            // Authentication fails in both directions, so neither instance should connect to the other.
+            SerializableRunnable runnable = () ->
+            {
+                // There should be no inbound handlers as authentication fails and we remove handlers.
+                assertEquals(0, MessagingService.instance().messageHandlers.values().size());
+
+                // There should be no outbound connections as authentication fails.
+                OutboundConnections outbound = getOnlyElement(MessagingService.instance().channelManagers.values());
+                assertTrue(!outbound.small.isConnected() && !outbound.large.isConnected() && !outbound.urgent.isConnected());
+
+                // Verify that the failure is due to authentication failure
+                final RejectInboundConnections authenticator = (RejectInboundConnections) DatabaseDescriptor.getInternodeAuthenticator();
+                assertTrue(authenticator.authenticationFailed);
+            };
+
+            // Wait for authentication to fail
+            cluster.get(1).logs().watchFor("Unable to authenticate peer");
+            cluster.get(1).runOnInstance(runnable);
+            cluster.get(2).logs().watchFor("Unable to authenticate peer");
+            cluster.get(2).runOnInstance(runnable);
+        }
+        finally
+        {
+            executorService.shutdown();
+        }
+    }
+
+    @Test
+    public void testOutboundConnectionsAreRejectedWhenAuthFails() throws IOException, TimeoutException
+    {
+        Cluster.Builder builder = createCluster(RejectOutboundAuthenticator.class);
+
+        final ExecutorService executorService = Executors.newSingleThreadExecutor();
+        try (Cluster cluster = builder.start())
+        {
+            executorService.submit(() -> openConnections(cluster));
+
+            // Authentication fails in both directions, so neither instance should connect to the other.
+            SerializableRunnable runnable = () ->
+            {
+                // There should be no inbound connections as authentication fails.
+                InboundMessageHandlers inbound = getOnlyElement(MessagingService.instance().messageHandlers.values());
+                assertEquals(0, inbound.count());
+
+                // There should be no outbound connections as authentication fails.
+                OutboundConnections outbound = getOnlyElement(MessagingService.instance().channelManagers.values());
+                assertTrue(!outbound.small.isConnected() && !outbound.large.isConnected() && !outbound.urgent.isConnected());
+
+                // Verify that the failure is due to authentication failure
+                final RejectOutboundAuthenticator authenticator = (RejectOutboundAuthenticator) DatabaseDescriptor.getInternodeAuthenticator();
+                assertTrue(authenticator.authenticationFailed);
+            };
+
+            // Wait for authentication to fail
+            cluster.get(1).logs().watchFor("authentication failed");
+            cluster.get(1).runOnInstance(runnable);
+            cluster.get(2).logs().watchFor("authentication failed");
+            cluster.get(2).runOnInstance(runnable);
+        }
+        finally
+        {
+            executorService.shutdown();
+        }
+    }
+
+    @Test
+    public void testOutboundConnectionsAreInterruptedWhenAuthFails() throws IOException, TimeoutException
+    {
+        Cluster.Builder builder = createCluster(AllowFirstAndRejectOtherOutboundAuthenticator.class);
+        try (Cluster cluster = builder.start())
+        {
+            try
+            {
+                openConnections(cluster);
+            }
+            catch (RuntimeException ise)
+            {
+                assertThat(ise.getMessage(), containsString("agreement not reached"));
+            }
+
+            // Verify that authentication is failed and Interrupt is called on outbound connections.
+            cluster.get(1).logs().watchFor("authentication failed to");
+            cluster.get(1).logs().watchFor("Interrupted outbound connections to");
+
+            /*
+             * Check if outbound connections are zero
+             */
+            SerializableRunnable runnable = () ->
+            {
+                // Verify that there is only one successful outbound connection
+                final AllowFirstAndRejectOtherOutboundAuthenticator authenticator = (AllowFirstAndRejectOtherOutboundAuthenticator) DatabaseDescriptor.getInternodeAuthenticator();
+                assertEquals(1, authenticator.successfulOutbound.get());
+                assertTrue(authenticator.failedOutbound.get() > 0);
+
+                // There should be no outbound connections as authentication fails.
+                OutboundConnections outbound = getOnlyElement(MessagingService.instance().channelManagers.values());
+                assertTrue(!outbound.small.isConnected() && !outbound.large.isConnected() && !outbound.urgent.isConnected());
+            };
+            cluster.get(1).runOnInstance(runnable);
+        }
+    }
+
+    @Test
+    public void testConnectionsAreAcceptedWhenAuthSucceeds() throws IOException
+    {
+        verifyAuthenticationSucceeds(AllowAllInternodeAuthenticator.class);
+    }
+
+    @Test
+    public void testAuthenticationWithCertificateAuthenticator() throws IOException
+    {
+        verifyAuthenticationSucceeds(CertificateVerifyAuthenticator.class);
+    }
+
     @Test
     public void testConnectionsAreRejectedWithInvalidConfig() throws Throwable
     {
@@ -155,4 +294,151 @@ public final class InternodeEncryptionEnforcementTest extends TestBaseImpl
         cluster.schemaChange("CREATE KEYSPACE test_connections_from_2 " +
                              "WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 2};", false, cluster.get(2));
     }
+
+    /**
+     * Starts a two node cluster configured with the given internode authenticator and opens
+     * connections between the nodes.
+     *
+     * <p>It then asserts, on each node, that the single peer entry has a positive inbound count and
+     * that at least one of its outbound connections (small, large or urgent) is connected.
+     *
+     * @param authenticatorClass internode authenticator implementation configured on every node
+     * @throws IOException propagated from starting or closing the in-JVM cluster
+     */
+    private void verifyAuthenticationSucceeds(final Class<? extends IInternodeAuthenticator> authenticatorClass) throws IOException
+    {
+        Cluster.Builder builder = createCluster(authenticatorClass);
+        try (Cluster cluster = builder.start())
+        {
+            openConnections(cluster);
+
+            /*
+             * instance (1) should connect to instance (2) without any issues;
+             * instance (2) should connect to instance (1) without any issues.
+             */
+
+            SerializableRunnable runnable = () ->
+            {
+                // There should be inbound connections as authentication succeeds.
+                InboundMessageHandlers inbound = getOnlyElement(MessagingService.instance().messageHandlers.values());
+                assertTrue(inbound.count() > 0);
+
+                // There should be outbound connections as authentication succeeds.
+                OutboundConnections outbound = getOnlyElement(MessagingService.instance().channelManagers.values());
+                assertTrue(outbound.small.isConnected() || outbound.large.isConnected() || outbound.urgent.isConnected());
+            };
+
+            // The same checks must hold on both nodes, since each one is both client and server.
+            cluster.get(1).runOnInstance(runnable);
+            cluster.get(2).runOnInstance(runnable);
+        }
+    }
+
+    /**
+     * Builds, but does not start, a two node cluster with one node in each of dc1 and dc2.
+     *
+     * <p>Internode encryption is set to {@code all}, client certificate authentication is required,
+     * and {@code authenticatorClass} is configured as the internode authenticator.
+     *
+     * @param authenticatorClass internode authenticator implementation configured on every node
+     * @return a cluster builder ready to be started
+     */
+    private Cluster.Builder createCluster(final Class<? extends IInternodeAuthenticator> authenticatorClass)
+    {
+        return builder()
+        .withNodes(2)
+        .withConfig(c ->
+                    {
+                        c.with(Feature.NETWORK);
+                        c.with(Feature.NATIVE_PROTOCOL);
+
+                        // Both nodes use the same test keystore and truststore.
+                        HashMap<String, Object> encryption = new HashMap<>();
+                        encryption.put("keystore", "test/conf/cassandra_ssl_test.keystore");
+                        encryption.put("keystore_password", "cassandra");
+                        encryption.put("truststore", "test/conf/cassandra_ssl_test.truststore");
+                        encryption.put("truststore_password", "cassandra");
+                        encryption.put("internode_encryption", "all");
+                        // require_client_auth asks connecting peers for their certificates (mutual TLS).
+                        encryption.put("require_client_auth", "true");
+                        c.set("server_encryption_options", encryption);
+                        c.set("internode_authenticator", authenticatorClass.getName());
+                    })
+        .withNodeIdTopology(ImmutableMap.of(1, NetworkTopology.dcAndRack("dc1", "r1a"),
+                                            2, NetworkTopology.dcAndRack("dc2", "r2a")));
+    }
+
+    /**
+     * Test authenticator that accepts a peer only when the first certificate it presents equals the
+     * certificate stored under alias {@code cassandra_ssl_test} in
+     * {@code test/conf/cassandra_ssl_test.keystore}, the keystore the cluster is configured with.
+     */
+    public static class CertificateVerifyAuthenticator implements IInternodeAuthenticator
+    {
+        @Override
+        public boolean authenticate(InetAddress remoteAddress, int remotePort, Certificate[] certificates, InternodeConnectionDirection connectionType)
+        {
+            // A peer that presents no certificate can never match; no need to read the keystore.
+            if (certificates == null || certificates.length == 0)
+            {
+                return false;
+            }
+
+            // Check that the presented certificate is the one in the keystore configured in
+            // cassandra.yaml. try-with-resources closes the keystore file on every call.
+            try (InputStream keyStoreData = new FileInputStream("test/conf/cassandra_ssl_test.keystore"))
+            {
+                KeyStore keyStore = KeyStore.getInstance("JKS");
+                char[] keyStorePassword = "cassandra".toCharArray();
+                keyStore.load(keyStoreData, keyStorePassword);
+                Certificate expected = keyStore.getCertificate("cassandra_ssl_test");
+                return expected != null && expected.equals(certificates[0]);
+            }
+            catch (Exception e)
+            {
+                // Any failure to open, read or parse the keystore means the peer cannot be verified: reject it.
+                return false;
+            }
+        }
+
+        @Override
+        public void validateConfiguration() throws ConfigurationException
+        {
+            // Nothing to validate: the keystore location and password are fixed test values.
+        }
+    }
+
+    /**
+     * Test authenticator that rejects every connection attempt and remembers that it did so.
+     */
+    public static class RejectConnectionsAuthenticator implements IInternodeAuthenticator
+    {
+        // Becomes true once authenticate() has been called, i.e. once a connection was rejected.
+        boolean authenticationFailed;
+
+        @Override
+        public void validateConfiguration() throws ConfigurationException
+        {
+            // No configuration to validate.
+        }
+
+        @Override
+        public boolean authenticate(InetAddress remoteAddress, int remotePort, Certificate[] certificates, InternodeConnectionDirection connectionType)
+        {
+            authenticationFailed = true;
+            return false;
+        }
+    }
+
+    /**
+     * Rejects inbound connection attempts through {@link RejectConnectionsAuthenticator} and accepts
+     * attempts in every other direction.
+     */
+    public static class RejectInboundConnections extends RejectConnectionsAuthenticator
+    {
+        @Override
+        public boolean authenticate(InetAddress remoteAddress, int remotePort, Certificate[] certificates, InternodeConnectionDirection connectionType)
+        {
+            // The rejecting parent is consulted only for inbound attempts; short-circuiting skips it otherwise.
+            return connectionType != InternodeConnectionDirection.INBOUND
+                   || super.authenticate(remoteAddress, remotePort, certificates, connectionType);
+        }
+    }
+
+    /**
+     * Rejects outbound connection attempts through {@link RejectConnectionsAuthenticator} and accepts
+     * attempts in every other direction.
+     */
+    public static class RejectOutboundAuthenticator extends RejectConnectionsAuthenticator
+    {
+        @Override
+        public boolean authenticate(InetAddress remoteAddress, int remotePort, Certificate[] certificates, InternodeConnectionDirection connectionType)
+        {
+            // The rejecting parent is consulted only for outbound attempts; short-circuiting skips it otherwise.
+            return connectionType != InternodeConnectionDirection.OUTBOUND
+                   || super.authenticate(remoteAddress, remotePort, certificates, connectionType);
+        }
+    }
+
+    /**
+     * Accepts exactly one outbound connection attempt and rejects every later one, counting both
+     * outcomes. Attempts in any other direction are always accepted.
+     */
+    public static class AllowFirstAndRejectOtherOutboundAuthenticator extends RejectOutboundAuthenticator
+    {
+        // Read by the test (via runOnInstance) to check how many outbound attempts succeeded / failed.
+        final AtomicInteger successfulOutbound = new AtomicInteger();
+        final AtomicInteger failedOutbound = new AtomicInteger();
+
+        @Override
+        public boolean authenticate(InetAddress remoteAddress, int remotePort, Certificate[] certificates, InternodeConnectionDirection connectionType)
+        {
+            if (connectionType != InternodeConnectionDirection.OUTBOUND)
+            {
+                return true;
+            }
+
+            // compareAndSet makes "first outbound attempt wins" atomic. With a separate get() == 0 check
+            // followed by incrementAndGet(), two concurrent handshakes could both see 0 and both be accepted.
+            if (successfulOutbound.compareAndSet(0, 1))
+            {
+                return true;
+            }
+
+            failedOutbound.incrementAndGet();
+            return false;
+        }
+    }
 }
diff --git a/test/unit/org/apache/cassandra/SchemaLoader.java b/test/unit/org/apache/cassandra/SchemaLoader.java
index 327949074b..e2ef487ddf 100644
--- a/test/unit/org/apache/cassandra/SchemaLoader.java
+++ b/test/unit/org/apache/cassandra/SchemaLoader.java
@@ -286,6 +286,7 @@ public class SchemaLoader
         SchemaTestUtil.announceNewKeyspace(AuthKeyspace.metadata());
         DatabaseDescriptor.getRoleManager().setup();
         DatabaseDescriptor.getAuthenticator().setup();
+        DatabaseDescriptor.getInternodeAuthenticator().setupInternode();
         DatabaseDescriptor.getAuthorizer().setup();
         DatabaseDescriptor.getNetworkAuthorizer().setup();
         Schema.instance.registerListener(new AuthSchemaChangeListener());
diff --git a/test/unit/org/apache/cassandra/net/MessagingServiceTest.java b/test/unit/org/apache/cassandra/net/MessagingServiceTest.java
index 349d8652cf..32d5050388 100644
--- a/test/unit/org/apache/cassandra/net/MessagingServiceTest.java
+++ b/test/unit/org/apache/cassandra/net/MessagingServiceTest.java
@@ -25,6 +25,7 @@ import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.UnknownHostException;
 import java.nio.channels.AsynchronousSocketChannel;
+import java.security.cert.Certificate;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashSet;
@@ -35,31 +36,35 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
-import java.util.regex.*;
 import java.util.regex.Matcher;
+import java.util.regex.Pattern;
 
 import com.google.common.net.InetAddresses;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
 
 import com.codahale.metrics.Timer;
-
 import org.apache.cassandra.auth.IInternodeAuthenticator;
 import org.apache.cassandra.config.DatabaseDescriptor;
 import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions;
 import org.apache.cassandra.db.commitlog.CommitLog;
-import org.apache.cassandra.metrics.MessagingMetrics;
 import org.apache.cassandra.exceptions.ConfigurationException;
 import org.apache.cassandra.locator.InetAddressAndPort;
+import org.apache.cassandra.metrics.MessagingMetrics;
+import org.apache.cassandra.utils.ByteBufferUtil;
 import org.apache.cassandra.utils.FBUtilities;
 import org.awaitility.Awaitility;
 import org.caffinitas.ohc.histo.EstimatedHistogram;
-import org.junit.After;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.BeforeClass;
-import org.junit.Test;
 
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
 
 public class MessagingServiceTest
 {
@@ -67,7 +72,8 @@ public class MessagingServiceTest
     public static AtomicInteger rejectedConnections = new AtomicInteger();
     public static final IInternodeAuthenticator ALLOW_NOTHING_AUTHENTICATOR = new IInternodeAuthenticator()
     {
-        public boolean authenticate(InetAddress remoteAddress, int remotePort)
+        public boolean authenticate(InetAddress remoteAddress, int remotePort,
+                                    Certificate[] certificates, InternodeConnectionDirection connectionType)
         {
             rejectedConnections.incrementAndGet();
             return false;
@@ -78,6 +84,25 @@ public class MessagingServiceTest
 
         }
     };
+
+    /**
+     * Rejects outbound connection attempts, counting each one in {@code rejectedConnections}, and
+     * accepts attempts in every other direction.
+     */
+    public static final IInternodeAuthenticator REJECT_OUTBOUND_AUTHENTICATOR = new IInternodeAuthenticator()
+    {
+        @Override
+        public boolean authenticate(InetAddress remoteAddress, int remotePort,
+                                    Certificate[] certificates, InternodeConnectionDirection connectionType)
+        {
+            // Anything that is not an outbound attempt passes straight through.
+            if (connectionType != InternodeConnectionDirection.OUTBOUND)
+            {
+                return true;
+            }
+            rejectedConnections.incrementAndGet();
+            return false;
+        }
+
+        @Override
+        public void validateConfiguration() throws ConfigurationException
+        {
+            // Nothing to validate.
+        }
+    };
     private static IInternodeAuthenticator originalAuthenticator;
     private static ServerEncryptionOptions originalServerEncryptionOptions;
     private static InetAddressAndPort originalListenAddress;
@@ -228,19 +253,38 @@ public class MessagingServiceTest
     @Test
     public void testFailedOutboundInternodeAuth() throws Exception
     {
-        MessagingService ms = MessagingService.instance();
-        DatabaseDescriptor.setInternodeAuthenticator(ALLOW_NOTHING_AUTHENTICATOR);
-        InetAddressAndPort address = InetAddressAndPort.getByName("127.0.0.250");
+        // Sends to a local listening socket while REJECT_OUTBOUND_AUTHENTICATOR is installed.
+        // The test waits until the resulting outbound connection attempt has been rejected
+        // (rejectedConnections grows), then checks that MessagingService tolerates closeOutbound
+        // and a further send.
+        //
+        // Listen on the server side for inbound (unencrypted) connections.
+        ServerEncryptionOptions serverEncryptionOptions = new ServerEncryptionOptions()
+        .withInternodeEncryption(ServerEncryptionOptions.InternodeEncryption.none);
+
+        DatabaseDescriptor.setInternodeAuthenticator(REJECT_OUTBOUND_AUTHENTICATOR);
+        InetAddress listenAddress = FBUtilities.getJustLocalAddress();
 
-        //Should return null
-        int rejectedBefore = rejectedConnections.get();
-        Message<?> messageOut = Message.out(Verb.ECHO_REQ, NoPayload.noPayload);
-        ms.send(messageOut, address);
-        Awaitility.await().atMost(10, TimeUnit.SECONDS).until(() -> rejectedConnections.get() > rejectedBefore);
+        InboundConnectionSettings settings = new InboundConnectionSettings().withEncryption(serverEncryptionOptions);
+        InboundSockets connections = new InboundSockets(settings);
 
-        //Should tolerate null
-        ms.closeOutbound(address);
-        ms.send(messageOut, address);
+        try
+        {
+            connections.open().await();
+            Assert.assertTrue(connections.isListening());
+
+            MessagingService ms = MessagingService.instance();
+            // Sending opens an outbound connection to the local listener; the authenticator rejects it,
+            // which is observed as an increase of rejectedConnections.
+            int rejectedBefore = rejectedConnections.get();
+            Message<?> messageOut = Message.out(Verb.ECHO_REQ, NoPayload.noPayload);
+            InetAddressAndPort address = InetAddressAndPort.getByAddress(listenAddress);
+            ms.send(messageOut, address);
+            Awaitility.await().atMost(10, TimeUnit.SECONDS).until(() -> rejectedConnections.get() > rejectedBefore);
+
+            // After the rejection, closing the outbound connections and sending again must not throw.
+            ms.closeOutbound(address);
+            ms.send(messageOut, address);
+        }
+        finally
+        {
+            // Always release the listening socket so later tests can bind the same address.
+            connections.close().await();
+            Assert.assertFalse(connections.isListening());
+        }
     }
 
     @Test
@@ -262,6 +306,11 @@ public class MessagingServiceTest
 
             int rejectedBefore = rejectedConnections.get();
             Future<Void> connectFuture = testChannel.connect(new InetSocketAddress(listenAddress, DatabaseDescriptor.getStoragePort()));
+            Awaitility.await().atMost(10, TimeUnit.SECONDS).until(connectFuture::isDone);
+
+            // Since authentication doesn't happen during connect, try writing a dummy string which triggers
+            // authentication handler.
+            testChannel.write(ByteBufferUtil.bytes("dummy string"));
             Awaitility.await().atMost(10, TimeUnit.SECONDS).until(() -> rejectedConnections.get() > rejectedBefore);
 
             connectFuture.cancel(true);


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@cassandra.apache.org
For additional commands, e-mail: commits-help@cassandra.apache.org