You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@cassandra.apache.org by yc...@apache.org on 2022/10/07 22:21:41 UTC

[cassandra] branch trunk updated: Mixed mode support for internode authentication during TLS upgrades

This is an automated email from the ASF dual-hosted git repository.

ycai pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra.git


The following commit(s) were added to refs/heads/trunk by this push:
     new ca75ffe4d0 Mixed mode support for internode authentication during TLS upgrades
ca75ffe4d0 is described below

commit ca75ffe4d09a3e7b26a56345c0bdacaa284eaab7
Author: Jyothsna Konisa <jk...@apple.com>
AuthorDate: Fri Oct 7 10:03:16 2022 -0700

    Mixed mode support for internode authentication during TLS upgrades
    
    patch by Jyothsna Konisa; reviewed by Jon Meredith, Yifan Cai for CASSANDRA-17923
---
 CHANGES.txt                                        |   1 +
 .../cassandra/net/InternodeConnectionUtils.java    |  11 +-
 .../apache/cassandra/net/OutboundConnection.java   |  23 ++-
 .../cassandra/net/OutboundConnectionInitiator.java |  43 ++++-
 .../async/NettyStreamingConnectionFactory.java     |  45 +++--
 test/conf/cassandra_ssl_test.truststore            | Bin 992 -> 3240 bytes
 .../test/InternodeEncryptionEnforcementTest.java   |   8 +-
 .../org/apache/cassandra/net/HandshakeTest.java    | 185 ++++++++++++++++++++-
 8 files changed, 282 insertions(+), 34 deletions(-)

diff --git a/CHANGES.txt b/CHANGES.txt
index c418c48ba9..85fdae62e8 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 4.2
+ * Mixed mode support for internode authentication during TLS upgrades (CASSANDRA-17923)
  * Revert Mockito downgrade from CASSANDRA-17750 (CASSANDRA-17496)
  * Add --older-than and --older-than-timestamp options for nodetool clearsnapshots (CASSANDRA-16860)
  * Fix "open RT bound as its last item" exception (CASSANDRA-17810)
diff --git a/src/java/org/apache/cassandra/net/InternodeConnectionUtils.java b/src/java/org/apache/cassandra/net/InternodeConnectionUtils.java
index 39a087960b..fd3d1bd69e 100644
--- a/src/java/org/apache/cassandra/net/InternodeConnectionUtils.java
+++ b/src/java/org/apache/cassandra/net/InternodeConnectionUtils.java
@@ -18,6 +18,7 @@
 
 package org.apache.cassandra.net;
 
+import java.nio.channels.ClosedChannelException;
 import java.security.cert.Certificate;
 import javax.net.ssl.SSLPeerUnverifiedException;
 
@@ -33,7 +34,7 @@ import io.netty.handler.ssl.SslHandler;
 /**
  * Class that contains certificate utility methods.
  */
-class InternodeConnectionUtils
+public class InternodeConnectionUtils
 {
     public static String SSL_HANDLER_NAME = "ssl";
     public static String DISCARD_HANDLER_NAME = "discard";
@@ -59,6 +60,14 @@ class InternodeConnectionUtils
         return certificates;
     }
 
+    /** True iff {@code cause} is a cause-less ClosedChannelException whose top frame matches SslHandler / channelInactive. */
+    public static boolean isSSLError(final Throwable cause)
+    {
+        // Throwable#getStackTrace may legally return a zero-length array; guard it instead of indexing [0] blindly
+        StackTraceElement[] frames = cause instanceof ClosedChannelException && cause.getCause() == null ? cause.getStackTrace() : null;
+        return frames != null && frames.length > 0 && frames[0].getClassName().contains("SslHandler") && frames[0].getMethodName().contains("channelInactive");
+    }
+
     /**
      * Discard handler releases the received data silently. when internode authentication fails, the channel is closed,
      * but the pending buffered data may still be fired through the pipeline. To avoid that, authentication handler is
diff --git a/src/java/org/apache/cassandra/net/OutboundConnection.java b/src/java/org/apache/cassandra/net/OutboundConnection.java
index 821521bfb9..2af6d3b01d 100644
--- a/src/java/org/apache/cassandra/net/OutboundConnection.java
+++ b/src/java/org/apache/cassandra/net/OutboundConnection.java
@@ -61,6 +61,7 @@ import org.apache.cassandra.utils.concurrent.UncheckedInterruptedException;
 import static java.lang.Math.max;
 import static java.lang.Math.min;
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static org.apache.cassandra.net.InternodeConnectionUtils.isSSLError;
 import static org.apache.cassandra.net.MessagingService.current_version;
 import static org.apache.cassandra.net.OutboundConnectionInitiator.*;
 import static org.apache.cassandra.net.OutboundConnections.LARGE_MESSAGE_THRESHOLD;
@@ -1100,8 +1101,9 @@ public class OutboundConnection
 
                 if (hasPending())
                 {
+                    boolean isSSLFailure = isSSLError(cause);
                     Promise<Result<MessagingSuccess>> result = AsyncPromise.withExecutor(eventLoop);
-                    state = new Connecting(state.disconnected(), result, eventLoop.schedule(() -> attempt(result), max(100, retryRateMillis), MILLISECONDS));
+                    state = new Connecting(state.disconnected(), result, eventLoop.schedule(() -> attempt(result, isSSLFailure), max(100, retryRateMillis), MILLISECONDS));
                     retryRateMillis = min(1000, retryRateMillis * 2);
                 }
                 else
@@ -1189,7 +1191,7 @@ public class OutboundConnection
              *
              * Note: this should only be invoked on the event loop.
              */
-            private void attempt(Promise<Result<MessagingSuccess>> result)
+            private void attempt(Promise<Result<MessagingSuccess>> result, boolean sslFallbackEnabled)
             {
                 ++connectionAttempts;
 
@@ -1216,7 +1218,20 @@ public class OutboundConnection
                 // ensure we connect to the correct SSL port
                 settings = settings.withLegacyPortIfNecessary(messagingVersion);
 
-                initiateMessaging(eventLoop, type, settings, messagingVersion, result)
+                // In mixed mode operation, some nodes might be configured to use SSL for internode connections and
+                // others might not. A node in optional SSL mode must handle both SSL and non-SSL internode connections.
+                // Inbound connections handle this with an optional SSL handler. For outbound connections, after an
+                // SSL-related failure we fall back to the other SSL strategies, so we can still talk to nodes that are
+                // configured differently (e.g. older nodes that only make non-SSL connections). floorMod keeps the index
+                // within [0, length) even if the attempt counter ever wraps, unlike narrowing to int before the modulo.
+                SslFallbackConnectionType[] fallbackTypes = SslFallbackConnectionType.values();
+                int index = sslFallbackEnabled && settings.withEncryption() && settings.encryption.getOptional() ?
+                            (int) Math.floorMod(connectionAttempts - 1, (long) fallbackTypes.length) : 0;
+                if (fallbackTypes[index] != SslFallbackConnectionType.SERVER_CONFIG)
+                {
+                    logger.info("ConnectionId {} is falling back to {} reconnect strategy for retry", id(), fallbackTypes[index]);
+                }
+                initiateMessaging(eventLoop, type, fallbackTypes[index], settings, messagingVersion, result)
                 .addListener(future -> {
                     if (future.isCancelled())
                         return;
@@ -1231,7 +1246,7 @@ public class OutboundConnection
             {
                 Promise<Result<MessagingSuccess>> result = AsyncPromise.withExecutor(eventLoop);
                 state = new Connecting(state.disconnected(), result);
-                attempt(result);
+                attempt(result, false);
                 return result;
             }
         }
diff --git a/src/java/org/apache/cassandra/net/OutboundConnectionInitiator.java b/src/java/org/apache/cassandra/net/OutboundConnectionInitiator.java
index 7e38dd8812..f8df49b598 100644
--- a/src/java/org/apache/cassandra/net/OutboundConnectionInitiator.java
+++ b/src/java/org/apache/cassandra/net/OutboundConnectionInitiator.java
@@ -94,15 +94,17 @@ public class OutboundConnectionInitiator<SuccessType extends OutboundConnectionI
     private static final Logger logger = LoggerFactory.getLogger(OutboundConnectionInitiator.class);
 
     private final ConnectionType type;
+    private final SslFallbackConnectionType sslConnectionType;
     private final OutboundConnectionSettings settings;
     private final int requestMessagingVersion; // for pre40 nodes
     private final Promise<Result<SuccessType>> resultPromise;
     private boolean isClosed;
 
-    private OutboundConnectionInitiator(ConnectionType type, OutboundConnectionSettings settings,
+    private OutboundConnectionInitiator(ConnectionType type, SslFallbackConnectionType sslConnectionType, OutboundConnectionSettings settings,
                                         int requestMessagingVersion, Promise<Result<SuccessType>> resultPromise)
     {
         this.type = type;
+        this.sslConnectionType = sslConnectionType;
         this.requestMessagingVersion = requestMessagingVersion;
         this.settings = settings;
         this.resultPromise = resultPromise;
@@ -115,9 +117,10 @@ public class OutboundConnectionInitiator<SuccessType extends OutboundConnectionI
      *
      * The returned {@code Future} is guaranteed to be completed on the supplied eventLoop.
      */
-    public static Future<Result<StreamingSuccess>> initiateStreaming(EventLoop eventLoop, OutboundConnectionSettings settings, int requestMessagingVersion)
+    public static Future<Result<StreamingSuccess>> initiateStreaming(EventLoop eventLoop, OutboundConnectionSettings settings,
+                                                                     SslFallbackConnectionType sslConnectionType, int requestMessagingVersion) // NOTE(review): unlike initiateMessaging, sslConnectionType comes after settings here
     {
-        return new OutboundConnectionInitiator<StreamingSuccess>(STREAMING, settings, requestMessagingVersion, AsyncPromise.withExecutor(eventLoop))
+        return new OutboundConnectionInitiator<StreamingSuccess>(STREAMING, sslConnectionType, settings, requestMessagingVersion, AsyncPromise.withExecutor(eventLoop))
                .initiate(eventLoop);
     }
 
@@ -128,9 +131,10 @@ public class OutboundConnectionInitiator<SuccessType extends OutboundConnectionI
      *
      * The returned {@code Future} is guaranteed to be completed on the supplied eventLoop.
      */
-    static Future<Result<MessagingSuccess>> initiateMessaging(EventLoop eventLoop, ConnectionType type, OutboundConnectionSettings settings, int requestMessagingVersion, Promise<Result<MessagingSuccess>> result)
+    static Future<Result<MessagingSuccess>> initiateMessaging(EventLoop eventLoop, ConnectionType type, SslFallbackConnectionType sslConnectionType, // SERVER_CONFIG keeps the pre-patch SSL decision (settings.withEncryption())
+                                                              OutboundConnectionSettings settings, int requestMessagingVersion, Promise<Result<MessagingSuccess>> result)
     {
-        return new OutboundConnectionInitiator<>(type, settings, requestMessagingVersion, result)
+        return new OutboundConnectionInitiator<>(type, sslConnectionType, settings, requestMessagingVersion, result)
                .initiate(eventLoop);
     }
 
@@ -202,6 +206,14 @@ public class OutboundConnectionInitiator<SuccessType extends OutboundConnectionI
         return bootstrap;
     }
 
+    public enum SslFallbackConnectionType
+    {
+        SERVER_CONFIG, // Original configuration of the server. Declaration order is the fallback order: callers cycle through values()
+        MTLS, // Encrypted; the outbound pipeline gets an SSL handler built exactly as for SSL (see getSslContext)
+        SSL, // Encrypted; same outbound SslContext as MTLS
+        NO_SSL // Plaintext: the outbound pipeline gets no SSL handler
+    }
+
     private class Initializer extends ChannelInitializer<SocketChannel>
     {
         public void initChannel(SocketChannel channel) throws Exception
@@ -209,11 +221,10 @@ public class OutboundConnectionInitiator<SuccessType extends OutboundConnectionI
             ChannelPipeline pipeline = channel.pipeline();
 
             // order of handlers: ssl -> server-authentication -> logger -> handshakeHandler
-            if (settings.withEncryption())
+            if ((sslConnectionType == SslFallbackConnectionType.SERVER_CONFIG && settings.withEncryption())
+                || sslConnectionType == SslFallbackConnectionType.SSL || sslConnectionType == SslFallbackConnectionType.MTLS)
             {
-                // check if we should actually encrypt this connection
-                SslContext sslContext = SSLFactory.getOrCreateSslContext(settings.encryption, true,
-                                                                         ISslContextFactory.SocketType.CLIENT);
+                SslContext sslContext = getSslContext(sslConnectionType);
                 // for some reason channel.remoteAddress() will return null
                 InetAddressAndPort address = settings.to;
                 InetSocketAddress peer = settings.encryption.require_endpoint_verification ? new InetSocketAddress(address.getAddress(), address.getPort()) : null;
@@ -229,6 +240,20 @@ public class OutboundConnectionInitiator<SuccessType extends OutboundConnectionI
             pipeline.addLast("handshake", new Handler());
         }
 
+        private SslContext getSslContext(SslFallbackConnectionType connectionType) throws IOException
+        {
+            boolean requireClientAuth = false; // NOTE(review): initChannel only calls this for SSL, MTLS or SERVER_CONFIG with encryption, so this always ends up true (the pre-patch call passed a literal true) — confirm what SSLFactory's boolean parameter means
+            if (connectionType == SslFallbackConnectionType.MTLS || connectionType == SslFallbackConnectionType.SSL)
+            {
+                requireClientAuth = true; // MTLS and SSL therefore build identical client contexts
+            }
+            else if (connectionType == SslFallbackConnectionType.SERVER_CONFIG)
+            {
+                requireClientAuth = settings.withEncryption();
+            }
+            return SSLFactory.getOrCreateSslContext(settings.encryption, requireClientAuth, ISslContextFactory.SocketType.CLIENT);
+        }
+
     }
 
     /**
diff --git a/src/java/org/apache/cassandra/streaming/async/NettyStreamingConnectionFactory.java b/src/java/org/apache/cassandra/streaming/async/NettyStreamingConnectionFactory.java
index 6a57e395e4..529b396367 100644
--- a/src/java/org/apache/cassandra/streaming/async/NettyStreamingConnectionFactory.java
+++ b/src/java/org/apache/cassandra/streaming/async/NettyStreamingConnectionFactory.java
@@ -20,6 +20,9 @@ package org.apache.cassandra.streaming.async;
 
 import java.io.IOException;
 import java.net.InetSocketAddress;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
 
 import com.google.common.annotations.VisibleForTesting;
 
@@ -35,7 +38,10 @@ import org.apache.cassandra.net.OutboundConnectionSettings;
 import org.apache.cassandra.streaming.StreamingChannel;
 
 import static org.apache.cassandra.locator.InetAddressAndPort.getByAddress;
+import static org.apache.cassandra.net.InternodeConnectionUtils.isSSLError;
 import static org.apache.cassandra.net.OutboundConnectionInitiator.initiateStreaming;
+import static org.apache.cassandra.net.OutboundConnectionInitiator.SslFallbackConnectionType;
+import static org.apache.cassandra.net.OutboundConnectionInitiator.SslFallbackConnectionType.SERVER_CONFIG;
 
 public class NettyStreamingConnectionFactory implements StreamingChannel.Factory
 {
@@ -45,27 +51,38 @@ public class NettyStreamingConnectionFactory implements StreamingChannel.Factory
     public static NettyStreamingChannel connect(OutboundConnectionSettings template, int messagingVersion, StreamingChannel.Kind kind) throws IOException
     {
         EventLoop eventLoop = MessagingService.instance().socketFactory.outboundStreamingGroup().next();
+        OutboundConnectionSettings settings = template.withDefaults(ConnectionCategory.STREAMING);
+        List<SslFallbackConnectionType> sslFallbacks = settings.withEncryption() && settings.encryption.getOptional()
+                                                       ? Arrays.asList(SslFallbackConnectionType.values())
+                                                       : Collections.singletonList(SERVER_CONFIG);

-        int attempts = 0;
-        while (true)
+        Throwable cause = null;
+        for (final SslFallbackConnectionType sslFallbackConnectionType : sslFallbacks)
         {
-            Future<Result<StreamingSuccess>> result = initiateStreaming(eventLoop, template.withDefaults(ConnectionCategory.STREAMING), messagingVersion);
-            result.awaitUninterruptibly(); // initiate has its own timeout, so this is "guaranteed" to return relatively promptly
-            if (result.isSuccess())
+            for (int i = 0; i < MAX_CONNECT_ATTEMPTS; i++)
             {
-                Channel channel = result.getNow().success().channel;
-                NettyStreamingChannel streamingChannel = new NettyStreamingChannel(messagingVersion, channel, kind);
-                if (kind == StreamingChannel.Kind.CONTROL)
+                Future<Result<StreamingSuccess>> result = initiateStreaming(eventLoop, settings, sslFallbackConnectionType, messagingVersion);
+                result.awaitUninterruptibly(); // initiate has its own timeout, so this is "guaranteed" to return relatively promptly
+                if (result.isSuccess())
                 {
-                    ChannelPipeline pipeline = channel.pipeline();
-                    pipeline.addLast("stream", streamingChannel);
+                    Channel channel = result.getNow().success().channel;
+                    NettyStreamingChannel streamingChannel = new NettyStreamingChannel(messagingVersion, channel, kind);
+                    if (kind == StreamingChannel.Kind.CONTROL)
+                    {
+                        ChannelPipeline pipeline = channel.pipeline();
+                        pipeline.addLast("stream", streamingChannel);
+                    }
+                    return streamingChannel;
                 }
-                return streamingChannel;
+                cause = result.cause();
+            }
+            if (!isSSLError(cause))
+            {
+                // Move on to the next SSL strategy only if the last failure was SSL related; otherwise retries are exhausted, so fail
+                break;
             }
-
-            if (++attempts == MAX_CONNECT_ATTEMPTS)
-                throw new IOException("failed to connect to " + template.to + " for streaming data", result.cause());
         }
+        throw new IOException("failed to connect to " + template.to + " for streaming data", cause);
     }
 
     @Override
diff --git a/test/conf/cassandra_ssl_test.truststore b/test/conf/cassandra_ssl_test.truststore
index 49cf3323e5..5ba9a9977c 100644
Binary files a/test/conf/cassandra_ssl_test.truststore and b/test/conf/cassandra_ssl_test.truststore differ
diff --git a/test/distributed/org/apache/cassandra/distributed/test/InternodeEncryptionEnforcementTest.java b/test/distributed/org/apache/cassandra/distributed/test/InternodeEncryptionEnforcementTest.java
index 156a6b4b64..d13e2e4a0c 100644
--- a/test/distributed/org/apache/cassandra/distributed/test/InternodeEncryptionEnforcementTest.java
+++ b/test/distributed/org/apache/cassandra/distributed/test/InternodeEncryptionEnforcementTest.java
@@ -189,16 +189,18 @@ public final class InternodeEncryptionEnforcementTest extends TestBaseImpl
                 c.with(Feature.NETWORK);
                 c.with(Feature.NATIVE_PROTOCOL);
 
+                HashMap<String, Object> encryption = new HashMap<>();
+                encryption.put("optional", "false");
+                encryption.put("internode_encryption", "none");
                 if (c.num() == 1)
                 {
-                    HashMap<String, Object> encryption = new HashMap<>();
                     encryption.put("keystore", "test/conf/cassandra_ssl_test.keystore");
                     encryption.put("keystore_password", "cassandra");
                     encryption.put("truststore", "test/conf/cassandra_ssl_test.truststore");
                     encryption.put("truststore_password", "cassandra");
-                    encryption.put("internode_encryption", "dc");
-                    c.set("server_encryption_options", encryption);
+                    encryption.put("internode_encryption", "all");
                 }
+                c.set("server_encryption_options", encryption);
             })
             .withNodeIdTopology(ImmutableMap.of(1, NetworkTopology.dcAndRack("dc1", "r1a"),
                                                 2, NetworkTopology.dcAndRack("dc2", "r2a")));
diff --git a/test/unit/org/apache/cassandra/net/HandshakeTest.java b/test/unit/org/apache/cassandra/net/HandshakeTest.java
index 75ae1034c5..6a0f7d379a 100644
--- a/test/unit/org/apache/cassandra/net/HandshakeTest.java
+++ b/test/unit/org/apache/cassandra/net/HandshakeTest.java
@@ -19,10 +19,20 @@
 package org.apache.cassandra.net;
 
 import java.nio.channels.ClosedChannelException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
 import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 
+import com.google.common.net.InetAddresses;
+
+import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions;
+import org.apache.cassandra.config.ParameterizedClass;
+import org.apache.cassandra.gms.GossipDigestSyn;
+import org.apache.cassandra.security.DefaultSslContextFactory;
 import org.apache.cassandra.utils.concurrent.AsyncPromise;
 import org.junit.AfterClass;
 import org.junit.Assert;
@@ -42,11 +52,15 @@ import static org.apache.cassandra.net.MessagingService.current_version;
 import static org.apache.cassandra.net.MessagingService.minimum_version;
 import static org.apache.cassandra.net.ConnectionType.SMALL_MESSAGES;
 import static org.apache.cassandra.net.OutboundConnectionInitiator.*;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
 
 // TODO: test failure due to exception, timeout, etc
 public class HandshakeTest
 {
     private static final SocketFactory factory = new SocketFactory();
+    static final InetAddressAndPort TO_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.2"), 7012);
+    static final InetAddressAndPort FROM_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.1"), 7012);
 
     @BeforeClass
     public static void startup()
@@ -80,6 +94,7 @@ public class HandshakeTest
             Future<Result<MessagingSuccess>> future =
             initiateMessaging(eventLoop,
                               SMALL_MESSAGES,
+                              SslFallbackConnectionType.SERVER_CONFIG,
                               new OutboundConnectionSettings(endpoint)
                                                     .withAcceptVersions(acceptOutbound)
                                                     .withDefaults(ConnectionCategory.MESSAGING),
@@ -92,6 +107,7 @@ public class HandshakeTest
         }
     }
 
+
     @Test
     public void testBothCurrentVersion() throws InterruptedException, ExecutionException
     {
@@ -172,7 +188,7 @@ public class HandshakeTest
         }
         catch (ExecutionException e)
         {
-            Assert.assertTrue(e.getCause() instanceof ClosedChannelException);
+            assertTrue(e.getCause() instanceof ClosedChannelException);
         }
     }
 
@@ -186,7 +202,7 @@ public class HandshakeTest
         }
         catch (ExecutionException e)
         {
-            Assert.assertTrue(e.getCause() instanceof ClosedChannelException);
+            assertTrue(e.getCause() instanceof ClosedChannelException);
         }
     }
 
@@ -207,7 +223,7 @@ public class HandshakeTest
         }
         catch (ExecutionException e)
         {
-            Assert.assertTrue(e.getCause() instanceof ClosedChannelException);
+            assertTrue(e.getCause() instanceof ClosedChannelException);
         }
     }
 
@@ -218,4 +234,167 @@ public class HandshakeTest
         Assert.assertEquals(Result.Outcome.SUCCESS, result.outcome);
         Assert.assertEquals(VERSION_30, result.success().messagingVersion);
     }
+
+    @Test
+    public void testOutboundConnectionFallbackDuringUpgrades() throws ClosedChannelException, InterruptedException
+    {
+        // Upgrade from Non-SSL -> Optional SSL
+        // Outbound connection from Optional SSL(new node) -> Non-SSL (old node)
+        testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType.SSL, true, SslFallbackConnectionType.NO_SSL, false);
+
+        // Upgrade from Optional SSL -> Strict SSL
+        // Outbound connection from Strict SSL(new node) -> Optional SSL (old node)
+        testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType.SSL, false, SslFallbackConnectionType.SSL, true);
+
+        // Upgrade from Optional SSL -> Strict MTLS
+        // Outbound connection from Strict MTLS(new node) -> Optional SSL (old node)
+        testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType.MTLS, false, SslFallbackConnectionType.SSL, true);
+
+        // Upgrade from Strict SSL -> Optional MTLS
+        // Outbound connection from Optional MTLS(new node) -> Strict SSL (old node)
+        testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType.MTLS, true, SslFallbackConnectionType.SSL, false);
+
+        // Upgrade from Optional MTLS -> Strict MTLS
+        // Outbound connection from Strict MTLS(new node) -> Optional MTLS (old node)
+        testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType.MTLS, false, SslFallbackConnectionType.MTLS, true);
+    }
+
+    @Test
+    public void testOutboundConnectionFallbackDuringDowngrades() throws ClosedChannelException, InterruptedException
+    {
+        // Downgrade from Strict MTLS -> Optional MTLS
+        // Outbound connection from Optional MTLS(new node) -> Strict MTLS (old node)
+        testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType.MTLS, true, SslFallbackConnectionType.MTLS, false);
+
+        // Downgrade from Optional MTLS -> Strict SSL
+        // Outbound connection from Strict SSL(new node) -> Optional MTLS (old node)
+        testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType.SSL, false, SslFallbackConnectionType.MTLS, true);
+
+        // Downgrade from Strict MTLS -> Optional SSL
+        // Outbound connection from Optional SSL(new node) -> Strict MTLS (old node)
+        testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType.SSL, true, SslFallbackConnectionType.MTLS, false);
+
+        // Downgrade from Strict SSL -> Optional SSL
+        // Outbound connection from Optional SSL(new node) -> Strict SSL (old node)
+        testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType.SSL, true, SslFallbackConnectionType.SSL, false);
+
+        // Downgrade from Optional SSL -> Non-SSL
+        // Outbound connection from Non-SSL(new node) -> Optional SSL (old node)
+        testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType.NO_SSL, false, SslFallbackConnectionType.SSL, true);
+    }
+
+    @Test
+    public void testOutboundConnectionDoesntFallbackWhenErrorIsNotSSLRelated() throws ClosedChannelException, InterruptedException
+    {
+        // Configure nodes in optional SSL mode. When optional mode is enabled, an SSL-related connection error should
+        // trigger a fallback to another SSL strategy; any other error should be retried with the same SSL strategy.
+        // Here the peer is not listening yet, so attempts fail before any TLS handshake takes place.
+        ServerEncryptionOptions serverEncryptionOptions = getServerEncryptionOptions(SslFallbackConnectionType.SSL, true);
+        InboundSockets inbound = getInboundSocket(serverEncryptionOptions);
+        try
+        {
+            InetAddressAndPort endpoint = inbound.sockets().stream().map(s -> s.settings.bindAddress).findFirst().get();
+
+            // Open outbound connections before the server starts listening
+            // The connection should be accepted after opening inbound connections, with the same SSL context without fallback
+            OutboundConnection outboundConnection = initiateOutbound(endpoint, SslFallbackConnectionType.SSL, true);
+
+            // Let the outbound connection be attempted at least once per fallback strategy, but fail instead of hanging forever
+            long deadline = System.currentTimeMillis() + 60000;
+            while (outboundConnection.connectionAttempts() < SslFallbackConnectionType.values().length && System.currentTimeMillis() < deadline)
+                Thread.sleep(1000);
+            assertTrue(outboundConnection.connectionAttempts() >= SslFallbackConnectionType.values().length);
+            assertFalse(outboundConnection.isConnected());
+            inbound.open();
+            // As soon as the node accepts inbound connections, the connection must be established with the right SSL context
+            waitForConnection(outboundConnection);
+            assertTrue(outboundConnection.isConnected());
+            assertFalse(outboundConnection.hasPending());
+        }
+        finally
+        {
+            inbound.close().await(10L, TimeUnit.SECONDS);
+        }
+    }
+
+    private ServerEncryptionOptions getServerEncryptionOptions(SslFallbackConnectionType sslConnectionType, boolean optional)
+    {
+        ServerEncryptionOptions serverEncryptionOptions = new ServerEncryptionOptions().withOptional(optional)
+                                                                                       .withKeyStore("test/conf/cassandra_ssl_test.keystore")
+                                                                                       .withKeyStorePassword("cassandra")
+                                                                                       .withOutboundKeystore("test/conf/cassandra_ssl_test_outbound.keystore")
+                                                                                       .withOutboundKeystorePassword("cassandra")
+                                                                                       .withTrustStore("test/conf/cassandra_ssl_test.truststore")
+                                                                                       .withTrustStorePassword("cassandra")
+                                                                                       .withSslContextFactory((new ParameterizedClass(DefaultSslContextFactory.class.getName(),
+                                                                                                                                      new HashMap<>())));
+        if (sslConnectionType == SslFallbackConnectionType.MTLS)
+        {
+            serverEncryptionOptions = serverEncryptionOptions.withInternodeEncryption(ServerEncryptionOptions.InternodeEncryption.all)
+                                                             .withRequireClientAuth(true);
+        }
+        else if (sslConnectionType == SslFallbackConnectionType.SSL)
+        {
+            serverEncryptionOptions = serverEncryptionOptions.withInternodeEncryption(ServerEncryptionOptions.InternodeEncryption.all)
+                                                             .withRequireClientAuth(false);
+        }
+        return serverEncryptionOptions; // NO_SSL / SERVER_CONFIG: internode encryption is not set here, only keystores/truststore/optional
+    }
+
+    private InboundSockets getInboundSocket(ServerEncryptionOptions serverEncryptionOptions)
+    {
+        InboundConnectionSettings settings = new InboundConnectionSettings().withAcceptMessaging(new AcceptVersions(minimum_version, current_version))
+                                                                            .withEncryption(serverEncryptionOptions)
+                                                                            .withBindAddress(TO_ADDR); // a single listener, always bound to TO_ADDR
+        List<InboundConnectionSettings> settingsList =  new ArrayList<>();
+        settingsList.add(settings);
+        return new InboundSockets(settingsList); // not opened yet: callers decide when to call open()
+    }
+
+    private OutboundConnection initiateOutbound(InetAddressAndPort endpoint, SslFallbackConnectionType connectionType, boolean optional) throws ClosedChannelException
+    {
+        final OutboundConnectionSettings settings = new OutboundConnectionSettings(endpoint)
+        .withAcceptVersions(new AcceptVersions(minimum_version, current_version))
+        .withDefaults(ConnectionCategory.MESSAGING)
+        .withEncryption(getServerEncryptionOptions(connectionType, optional))
+        .withFrom(FROM_ADDR);
+        OutboundConnections outboundConnections = OutboundConnections.tryRegister(new ConcurrentHashMap<>(), TO_ADDR, settings);
+        GossipDigestSyn syn = new GossipDigestSyn("cluster", "partitioner", new ArrayList<>(0));
+        Message<GossipDigestSyn> message = Message.out(Verb.GOSSIP_DIGEST_SYN, syn);
+        OutboundConnection outboundConnection = outboundConnections.connectionFor(message);
+        outboundConnection.enqueue(message); // leaves a pending message, so failed attempts are rescheduled (OutboundConnection retries only while hasPending())
+        outboundConnection.initiate();
+        return outboundConnection; // NOTE(review): never closed by the callers in this test class
+    }
+
+    private void testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType fromConnectionType, boolean fromOptional,
+                                                           SslFallbackConnectionType toConnectionType, boolean toOptional) throws ClosedChannelException, InterruptedException
+    {
+        // Configure the inbound (peer) side with toConnectionType/toOptional; the outbound side uses fromConnectionType/fromOptional
+        InboundSockets inbound = getInboundSocket(getServerEncryptionOptions(toConnectionType, toOptional));
+        try
+        {
+            InetAddressAndPort endpoint = inbound.sockets().stream().map(s -> s.settings.bindAddress).findFirst().get();
+            inbound.open();
+
+            // Open outbound connections, and wait until connection is established
+            OutboundConnection outboundConnection = initiateOutbound(endpoint, fromConnectionType, fromOptional);
+            waitForConnection(outboundConnection);
+            assertTrue(outboundConnection.isConnected());
+            assertFalse(outboundConnection.hasPending());
+        }
+        finally
+        {
+            inbound.close().await(10L, TimeUnit.SECONDS);
+        }
+    }
+
+    private void waitForConnection(OutboundConnection outboundConnection) throws InterruptedException
+    {
+        long startTime = System.currentTimeMillis(); // poll once a second for up to 60s; returns silently on timeout, so callers must assert isConnected()
+        while (!outboundConnection.isConnected() && System.currentTimeMillis() - startTime < 60000)
+        {
+            Thread.sleep(1000);
+        }
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@cassandra.apache.org
For additional commands, e-mail: commits-help@cassandra.apache.org