You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by rs...@apache.org on 2021/02/10 09:35:22 UTC

[kafka] branch 2.5 updated: KAFKA-12193: Re-resolve IPs after a client disconnects (#9902) (#10064)

This is an automated email from the ASF dual-hosted git repository.

rsivaram pushed a commit to branch 2.5
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.5 by this push:
     new b86a6d3  KAFKA-12193: Re-resolve IPs after a client disconnects (#9902) (#10064)
b86a6d3 is described below

commit b86a6d369674bfaf79c640fb973b4afaca86374d
Author: Bob Barrett <bo...@confluent.io>
AuthorDate: Wed Feb 10 01:33:34 2021 -0800

    KAFKA-12193: Re-resolve IPs after a client disconnects (#9902) (#10064)
    
    This patch changes the NetworkClient behavior to resolve the target node's hostname after disconnecting from an established connection, rather than waiting until the previously-resolved addresses are exhausted. This handles the scenario where the node's IP addresses have changed during the lifetime of the connection, and means that the client no longer has to try each stale IP address before re-resolving the hostname.
    
    Reviewers: Mickael Maison <mi...@gmail.com>, Satish Duggana <sa...@apache.org>, David Jacot <dj...@confluent.io>
---
 .../java/org/apache/kafka/clients/ClientUtils.java |  12 +-
 .../kafka/clients/ClusterConnectionStates.java     |  27 +++-
 .../apache/kafka/clients/DefaultHostResolver.java  |  29 ++++
 .../org/apache/kafka/clients/HostResolver.java     |  26 ++++
 .../org/apache/kafka/clients/NetworkClient.java    |  73 +++++-----
 .../kafka/clients/AddressChangeHostResolver.java   |  49 +++++++
 .../org/apache/kafka/clients/ClientUtilsTest.java  |   7 +-
 .../kafka/clients/ClusterConnectionStatesTest.java |  70 ++++++---
 .../apache/kafka/clients/NetworkClientTest.java    | 157 +++++++++++++++++++++
 .../java/org/apache/kafka/test/MockSelector.java   |  11 +-
 10 files changed, 398 insertions(+), 63 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/ClientUtils.java b/clients/src/main/java/org/apache/kafka/clients/ClientUtils.java
index bcdac45..401920f 100644
--- a/clients/src/main/java/org/apache/kafka/clients/ClientUtils.java
+++ b/clients/src/main/java/org/apache/kafka/clients/ClientUtils.java
@@ -106,8 +106,9 @@ public final class ClientUtils {
                 clientSaslMechanism, time, true, logContext);
     }
 
-    static List<InetAddress> resolve(String host, ClientDnsLookup clientDnsLookup) throws UnknownHostException {
-        InetAddress[] addresses = InetAddress.getAllByName(host);
+    static List<InetAddress> resolve(String host, ClientDnsLookup clientDnsLookup,
+                                     HostResolver hostResolver) throws UnknownHostException {
+        InetAddress[] addresses = hostResolver.resolve(host);
         if (ClientDnsLookup.USE_ALL_DNS_IPS == clientDnsLookup) {
             return filterPreferredAddresses(addresses);
         } else {
@@ -115,6 +116,13 @@ public final class ClientUtils {
         }
     }
 
+    /**
+     * Return a list containing the first address in `allAddresses` and subsequent addresses
+     * that are a subtype of the first address.
+     *
+     * The outcome is that all returned addresses are either IPv4 or IPv6 (InetAddress has two
+     * subclasses: Inet4Address and Inet6Address).
+     */
     static List<InetAddress> filterPreferredAddresses(InetAddress[] allAddresses) {
         List<InetAddress> preferredAddresses = new ArrayList<>();
         Class<? extends InetAddress> clazz = null;
diff --git a/clients/src/main/java/org/apache/kafka/clients/ClusterConnectionStates.java b/clients/src/main/java/org/apache/kafka/clients/ClusterConnectionStates.java
index ec3465a..83d6f37 100644
--- a/clients/src/main/java/org/apache/kafka/clients/ClusterConnectionStates.java
+++ b/clients/src/main/java/org/apache/kafka/clients/ClusterConnectionStates.java
@@ -40,13 +40,16 @@ final class ClusterConnectionStates {
     private final double reconnectBackoffMaxExp;
     private final Map<String, NodeConnectionState> nodeState;
     private final Logger log;
+    private final HostResolver hostResolver;
 
-    public ClusterConnectionStates(long reconnectBackoffMs, long reconnectBackoffMaxMs, LogContext logContext) {
+    public ClusterConnectionStates(long reconnectBackoffMs, long reconnectBackoffMaxMs,
+                                   LogContext logContext, HostResolver hostResolver) {
         this.log = logContext.logger(ClusterConnectionStates.class);
         this.reconnectBackoffInitMs = reconnectBackoffMs;
         this.reconnectBackoffMaxMs = reconnectBackoffMaxMs;
         this.reconnectBackoffMaxExp = Math.log(this.reconnectBackoffMaxMs / (double) Math.max(reconnectBackoffMs, 1)) / Math.log(RECONNECT_BACKOFF_EXP_BASE);
         this.nodeState = new HashMap<>();
+        this.hostResolver = hostResolver;
     }
 
     /**
@@ -139,7 +142,7 @@ final class ClusterConnectionStates {
         // Create a new NodeConnectionState if nodeState does not already contain one
         // for the specified id or if the hostname associated with the node id changed.
         nodeState.put(id, new NodeConnectionState(ConnectionState.CONNECTING, now,
-            this.reconnectBackoffInitMs, host, clientDnsLookup));
+            this.reconnectBackoffInitMs, host, clientDnsLookup, hostResolver));
     }
 
     /**
@@ -158,9 +161,14 @@ final class ClusterConnectionStates {
      */
     public void disconnected(String id, long now) {
         NodeConnectionState nodeState = nodeState(id);
-        nodeState.state = ConnectionState.DISCONNECTED;
         nodeState.lastConnectAttemptMs = now;
         updateReconnectBackoff(nodeState);
+        if (nodeState.state.isConnected()) {
+            // If a connection had previously been established, clear the addresses to trigger a new DNS resolution
+            // because the node IPs may have changed
+            nodeState.clearAddresses();
+        }
+        nodeState.state = ConnectionState.DISCONNECTED;
     }
 
     /**
@@ -373,9 +381,10 @@ final class ClusterConnectionStates {
         private int addressIndex;
         private final String host;
         private final ClientDnsLookup clientDnsLookup;
+        private final HostResolver hostResolver;
 
         private NodeConnectionState(ConnectionState state, long lastConnectAttempt, long reconnectBackoffMs,
-                String host, ClientDnsLookup clientDnsLookup) {
+                String host, ClientDnsLookup clientDnsLookup, HostResolver hostResolver) {
             this.state = state;
             this.addresses = Collections.emptyList();
             this.addressIndex = -1;
@@ -386,6 +395,7 @@ final class ClusterConnectionStates {
             this.throttleUntilTimeMs = 0;
             this.host = host;
             this.clientDnsLookup = clientDnsLookup;
+            this.hostResolver = hostResolver;
         }
 
         public String host() {
@@ -400,7 +410,7 @@ final class ClusterConnectionStates {
         private InetAddress currentAddress() throws UnknownHostException {
             if (addresses.isEmpty()) {
                 // (Re-)initialize list
-                addresses = ClientUtils.resolve(host, clientDnsLookup);
+                addresses = ClientUtils.resolve(host, clientDnsLookup, hostResolver);
                 addressIndex = 0;
             }
 
@@ -420,6 +430,13 @@ final class ClusterConnectionStates {
                 addresses = Collections.emptyList(); // Exhausted list. Re-resolve on next currentAddress() call
         }
 
+        /**
+         * Clears the resolved addresses in order to trigger re-resolving on the next {@link #currentAddress()} call.
+         */
+        private void clearAddresses() {
+            addresses = Collections.emptyList();
+        }
+
         public String toString() {
             return "NodeState(" + state + ", " + lastConnectAttemptMs + ", " + failedAttempts + ", " + throttleUntilTimeMs + ")";
         }
diff --git a/clients/src/main/java/org/apache/kafka/clients/DefaultHostResolver.java b/clients/src/main/java/org/apache/kafka/clients/DefaultHostResolver.java
new file mode 100644
index 0000000..786173e
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/DefaultHostResolver.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.clients;
+
+import java.net.InetAddress;
+import java.net.UnknownHostException;
+
+public class DefaultHostResolver implements HostResolver {
+
+    @Override
+    public InetAddress[] resolve(String host) throws UnknownHostException {
+        return InetAddress.getAllByName(host);
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/HostResolver.java b/clients/src/main/java/org/apache/kafka/clients/HostResolver.java
new file mode 100644
index 0000000..80209ca
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/HostResolver.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.clients;
+
+import java.net.InetAddress;
+import java.net.UnknownHostException;
+
+public interface HostResolver {
+
+    InetAddress[] resolve(String host) throws UnknownHostException;
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/NetworkClient.java b/clients/src/main/java/org/apache/kafka/clients/NetworkClient.java
index cd3a4bf..1301639 100644
--- a/clients/src/main/java/org/apache/kafka/clients/NetworkClient.java
+++ b/clients/src/main/java/org/apache/kafka/clients/NetworkClient.java
@@ -145,9 +145,8 @@ public class NetworkClient implements KafkaClient {
                          boolean discoverBrokerVersions,
                          ApiVersions apiVersions,
                          LogContext logContext) {
-        this(null,
+        this(selector,
              metadata,
-             selector,
              clientId,
              maxInFlightRequestsPerConnection,
              reconnectBackoffMs,
@@ -164,20 +163,20 @@ public class NetworkClient implements KafkaClient {
     }
 
     public NetworkClient(Selectable selector,
-            Metadata metadata,
-            String clientId,
-            int maxInFlightRequestsPerConnection,
-            long reconnectBackoffMs,
-            long reconnectBackoffMax,
-            int socketSendBuffer,
-            int socketReceiveBuffer,
-            int defaultRequestTimeoutMs,
-            ClientDnsLookup clientDnsLookup,
-            Time time,
-            boolean discoverBrokerVersions,
-            ApiVersions apiVersions,
-            Sensor throttleTimeSensor,
-            LogContext logContext) {
+                         Metadata metadata,
+                         String clientId,
+                         int maxInFlightRequestsPerConnection,
+                         long reconnectBackoffMs,
+                         long reconnectBackoffMax,
+                         int socketSendBuffer,
+                         int socketReceiveBuffer,
+                         int defaultRequestTimeoutMs,
+                         ClientDnsLookup clientDnsLookup,
+                         Time time,
+                         boolean discoverBrokerVersions,
+                         ApiVersions apiVersions,
+                         Sensor throttleTimeSensor,
+                         LogContext logContext) {
         this(null,
              metadata,
              selector,
@@ -193,7 +192,8 @@ public class NetworkClient implements KafkaClient {
              discoverBrokerVersions,
              apiVersions,
              throttleTimeSensor,
-             logContext);
+             logContext,
+             new DefaultHostResolver());
     }
 
     public NetworkClient(Selectable selector,
@@ -225,25 +225,27 @@ public class NetworkClient implements KafkaClient {
              discoverBrokerVersions,
              apiVersions,
              null,
-             logContext);
+             logContext,
+             new DefaultHostResolver());
     }
 
-    private NetworkClient(MetadataUpdater metadataUpdater,
-                          Metadata metadata,
-                          Selectable selector,
-                          String clientId,
-                          int maxInFlightRequestsPerConnection,
-                          long reconnectBackoffMs,
-                          long reconnectBackoffMax,
-                          int socketSendBuffer,
-                          int socketReceiveBuffer,
-                          int defaultRequestTimeoutMs,
-                          ClientDnsLookup clientDnsLookup,
-                          Time time,
-                          boolean discoverBrokerVersions,
-                          ApiVersions apiVersions,
-                          Sensor throttleTimeSensor,
-                          LogContext logContext) {
+    public NetworkClient(MetadataUpdater metadataUpdater,
+                         Metadata metadata,
+                         Selectable selector,
+                         String clientId,
+                         int maxInFlightRequestsPerConnection,
+                         long reconnectBackoffMs,
+                         long reconnectBackoffMax,
+                         int socketSendBuffer,
+                         int socketReceiveBuffer,
+                         int defaultRequestTimeoutMs,
+                         ClientDnsLookup clientDnsLookup,
+                         Time time,
+                         boolean discoverBrokerVersions,
+                         ApiVersions apiVersions,
+                         Sensor throttleTimeSensor,
+                         LogContext logContext,
+                         HostResolver hostResolver) {
         /* It would be better if we could pass `DefaultMetadataUpdater` from the public constructor, but it's not
          * possible because `DefaultMetadataUpdater` is an inner class and it can only be instantiated after the
          * super constructor is invoked.
@@ -258,7 +260,8 @@ public class NetworkClient implements KafkaClient {
         this.selector = selector;
         this.clientId = clientId;
         this.inFlightRequests = new InFlightRequests(maxInFlightRequestsPerConnection);
-        this.connectionStates = new ClusterConnectionStates(reconnectBackoffMs, reconnectBackoffMax, logContext);
+        this.connectionStates = new ClusterConnectionStates(
+                reconnectBackoffMs, reconnectBackoffMax, logContext, hostResolver);
         this.socketSendBuffer = socketSendBuffer;
         this.socketReceiveBuffer = socketReceiveBuffer;
         this.correlation = 0;
diff --git a/clients/src/test/java/org/apache/kafka/clients/AddressChangeHostResolver.java b/clients/src/test/java/org/apache/kafka/clients/AddressChangeHostResolver.java
new file mode 100644
index 0000000..28f9c88
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/clients/AddressChangeHostResolver.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients;
+
+import java.net.InetAddress;
+
+class AddressChangeHostResolver implements HostResolver {
+    private boolean useNewAddresses;
+    private InetAddress[] initialAddresses;
+    private InetAddress[] newAddresses;
+    private int resolutionCount = 0;
+
+    public AddressChangeHostResolver(InetAddress[] initialAddresses, InetAddress[] newAddresses) {
+        this.initialAddresses = initialAddresses;
+        this.newAddresses = newAddresses;
+    }
+
+    @Override
+    public InetAddress[] resolve(String host) {
+        ++resolutionCount;
+        return useNewAddresses ? newAddresses : initialAddresses;
+    }
+
+    public void changeAddresses() {
+        useNewAddresses = true;
+    }
+
+    public boolean useNewAddresses() {
+        return useNewAddresses;
+    }
+
+    public int resolutionCount() {
+        return resolutionCount;
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/clients/ClientUtilsTest.java b/clients/src/test/java/org/apache/kafka/clients/ClientUtilsTest.java
index 572896f..8e6132e 100644
--- a/clients/src/test/java/org/apache/kafka/clients/ClientUtilsTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/ClientUtilsTest.java
@@ -30,6 +30,7 @@ import org.junit.Test;
 
 public class ClientUtilsTest {
 
+    private HostResolver hostResolver = new DefaultHostResolver();
 
     @Test
     public void testParseAndValidateAddresses() throws UnknownHostException {
@@ -97,17 +98,17 @@ public class ClientUtilsTest {
 
     @Test(expected = UnknownHostException.class)
     public void testResolveUnknownHostException() throws UnknownHostException {
-        ClientUtils.resolve("some.invalid.hostname.foo.bar.local", ClientDnsLookup.DEFAULT);
+        ClientUtils.resolve("some.invalid.hostname.foo.bar.local", ClientDnsLookup.DEFAULT, hostResolver);
     }
 
     @Test
     public void testResolveDnsLookup() throws UnknownHostException {
-        assertEquals(1, ClientUtils.resolve("localhost", ClientDnsLookup.DEFAULT).size());
+        assertEquals(1, ClientUtils.resolve("localhost", ClientDnsLookup.DEFAULT, hostResolver).size());
     }
 
     @Test
     public void testResolveDnsLookupAllIps() throws UnknownHostException {
-        assertTrue(ClientUtils.resolve("kafka.apache.org", ClientDnsLookup.USE_ALL_DNS_IPS).size() > 1);
+        assertTrue(ClientUtils.resolve("kafka.apache.org", ClientDnsLookup.USE_ALL_DNS_IPS, hostResolver).size() > 1);
     }
 
     private List<InetSocketAddress> checkWithoutLookup(String... url) {
diff --git a/clients/src/test/java/org/apache/kafka/clients/ClusterConnectionStatesTest.java b/clients/src/test/java/org/apache/kafka/clients/ClusterConnectionStatesTest.java
index fbb4497..7f5ef7f 100644
--- a/clients/src/test/java/org/apache/kafka/clients/ClusterConnectionStatesTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/ClusterConnectionStatesTest.java
@@ -23,12 +23,13 @@ import static org.junit.Assert.assertNotSame;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertSame;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
-import java.lang.reflect.Field;
-import java.lang.reflect.Method;
 import java.net.InetAddress;
 import java.net.UnknownHostException;
 
+import java.util.ArrayList;
+import java.util.Arrays;
 import org.apache.kafka.common.errors.AuthenticationException;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
@@ -37,19 +38,47 @@ import org.junit.Test;
 
 public class ClusterConnectionStatesTest {
 
+    private static ArrayList<InetAddress> initialAddresses;
+    private static ArrayList<InetAddress> newAddresses;
+
+    static {
+        try {
+            initialAddresses = new ArrayList<>(Arrays.asList(
+                    InetAddress.getByName("10.200.20.100"),
+                    InetAddress.getByName("10.200.20.101"),
+                    InetAddress.getByName("10.200.20.102")
+            ));
+            newAddresses = new ArrayList<>(Arrays.asList(
+                    InetAddress.getByName("10.200.20.103"),
+                    InetAddress.getByName("10.200.20.104"),
+                    InetAddress.getByName("10.200.20.105")
+            ));
+        } catch (UnknownHostException e) {
+            fail("Attempted to create an invalid InetAddress, this should not happen");
+        }
+    }
+
     private final MockTime time = new MockTime();
     private final long reconnectBackoffMs = 10 * 1000;
     private final long reconnectBackoffMax = 60 * 1000;
     private final double reconnectBackoffJitter = 0.2;
     private final String nodeId1 = "1001";
     private final String nodeId2 = "2002";
-    private final String hostTwoIps = "kafka.apache.org";
-
+    private final String nodeId3 = "3003";
+    private final String hostTwoIps = "multiple.ip.address";
     private ClusterConnectionStates connectionStates;
 
+    // For testing nodes with a single IP address, use localhost and default DNS resolution
+    private DefaultHostResolver singleIPHostResolver = new DefaultHostResolver();
+
+    // For testing nodes with multiple IP addresses, mock DNS resolution to get consistent results
+    private AddressChangeHostResolver multipleIPHostResolver = new AddressChangeHostResolver(
+            initialAddresses.toArray(new InetAddress[0]), newAddresses.toArray(new InetAddress[0]));;
+
     @Before
     public void setup() {
-        this.connectionStates = new ClusterConnectionStates(reconnectBackoffMs, reconnectBackoffMax, new LogContext());
+        this.connectionStates = new ClusterConnectionStates(reconnectBackoffMs, reconnectBackoffMax,
+                new LogContext(), this.singleIPHostResolver);
     }
 
     @Test
@@ -246,7 +275,7 @@ public class ClusterConnectionStatesTest {
 
     @Test
     public void testSingleIPWithUseAll() throws UnknownHostException {
-        assertEquals(1, ClientUtils.resolve("localhost", ClientDnsLookup.USE_ALL_DNS_IPS).size());
+        assertEquals(1, ClientUtils.resolve("localhost", ClientDnsLookup.USE_ALL_DNS_IPS, singleIPHostResolver).size());
 
         connectionStates.connecting(nodeId1, time.milliseconds(), "localhost", ClientDnsLookup.USE_ALL_DNS_IPS);
         InetAddress currAddress = connectionStates.currentAddress(nodeId1);
@@ -256,7 +285,9 @@ public class ClusterConnectionStatesTest {
 
     @Test
     public void testMultipleIPsWithDefault() throws UnknownHostException {
-        assertTrue(ClientUtils.resolve(hostTwoIps, ClientDnsLookup.USE_ALL_DNS_IPS).size() > 1);
+        setupMultipleIPs();
+
+        assertTrue(ClientUtils.resolve(hostTwoIps, ClientDnsLookup.USE_ALL_DNS_IPS, multipleIPHostResolver).size() > 1);
 
         connectionStates.connecting(nodeId1, time.milliseconds(), hostTwoIps, ClientDnsLookup.DEFAULT);
         InetAddress currAddress = connectionStates.currentAddress(nodeId1);
@@ -266,7 +297,9 @@ public class ClusterConnectionStatesTest {
 
     @Test
     public void testMultipleIPsWithUseAll() throws UnknownHostException {
-        assertTrue(ClientUtils.resolve(hostTwoIps, ClientDnsLookup.USE_ALL_DNS_IPS).size() > 1);
+        setupMultipleIPs();
+
+        assertTrue(ClientUtils.resolve(hostTwoIps, ClientDnsLookup.USE_ALL_DNS_IPS, multipleIPHostResolver).size() > 1);
 
         connectionStates.connecting(nodeId1, time.milliseconds(), hostTwoIps, ClientDnsLookup.USE_ALL_DNS_IPS);
         InetAddress addr1 = connectionStates.currentAddress(nodeId1);
@@ -280,19 +313,14 @@ public class ClusterConnectionStatesTest {
 
     @Test
     public void testHostResolveChange() throws UnknownHostException, ReflectiveOperationException {
-        assertTrue(ClientUtils.resolve(hostTwoIps, ClientDnsLookup.USE_ALL_DNS_IPS).size() > 1);
+        setupMultipleIPs();
+
+        assertTrue(ClientUtils.resolve(hostTwoIps, ClientDnsLookup.USE_ALL_DNS_IPS, multipleIPHostResolver).size() > 1);
 
         connectionStates.connecting(nodeId1, time.milliseconds(), hostTwoIps, ClientDnsLookup.DEFAULT);
         InetAddress addr1 = connectionStates.currentAddress(nodeId1);
 
-        // reflection to simulate host change in DNS lookup
-        Method nodeStateMethod = connectionStates.getClass().getDeclaredMethod("nodeState", String.class);
-        nodeStateMethod.setAccessible(true);
-        Object nodeState = nodeStateMethod.invoke(connectionStates, nodeId1);
-        Field hostField = nodeState.getClass().getDeclaredField("host");
-        hostField.setAccessible(true);
-        hostField.set(nodeState, "localhost");
-
+        multipleIPHostResolver.changeAddresses();
         connectionStates.connecting(nodeId1, time.milliseconds(), "localhost", ClientDnsLookup.DEFAULT);
         InetAddress addr2 = connectionStates.currentAddress(nodeId1);
 
@@ -301,9 +329,12 @@ public class ClusterConnectionStatesTest {
 
     @Test
     public void testNodeWithNewHostname() throws UnknownHostException {
+        setupMultipleIPs();
+
         connectionStates.connecting(nodeId1, time.milliseconds(), "localhost", ClientDnsLookup.DEFAULT);
         InetAddress addr1 = connectionStates.currentAddress(nodeId1);
 
+        this.multipleIPHostResolver.changeAddresses();
         connectionStates.connecting(nodeId1, time.milliseconds(), hostTwoIps, ClientDnsLookup.DEFAULT);
         InetAddress addr2 = connectionStates.currentAddress(nodeId1);
 
@@ -320,4 +351,9 @@ public class ClusterConnectionStatesTest {
         connectionStates.disconnected(nodeId1, time.milliseconds());
         assertFalse(connectionStates.isPreparingConnection(nodeId1));
     }
+
+    private void setupMultipleIPs() {
+        this.connectionStates = new ClusterConnectionStates(reconnectBackoffMs, reconnectBackoffMax,
+                new LogContext(), this.multipleIPHostResolver);
+    }
 }
diff --git a/clients/src/test/java/org/apache/kafka/clients/NetworkClientTest.java b/clients/src/test/java/org/apache/kafka/clients/NetworkClientTest.java
index c4bde51..d4cae6a 100644
--- a/clients/src/test/java/org/apache/kafka/clients/NetworkClientTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/NetworkClientTest.java
@@ -48,6 +48,8 @@ import org.apache.kafka.test.TestUtils;
 import org.junit.Before;
 import org.junit.Test;
 
+import java.net.InetAddress;
+import java.net.UnknownHostException;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -55,6 +57,7 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Optional;
 import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
@@ -65,6 +68,7 @@ import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 public class NetworkClientTest {
 
@@ -81,6 +85,26 @@ public class NetworkClientTest {
     private final NetworkClient clientWithStaticNodes = createNetworkClientWithStaticNodes();
     private final NetworkClient clientWithNoVersionDiscovery = createNetworkClientWithNoVersionDiscovery();
 
+    private static ArrayList<InetAddress> initialAddresses;
+    private static ArrayList<InetAddress> newAddresses;
+
+    static {
+        try {
+            initialAddresses = new ArrayList<>(Arrays.asList(
+                    InetAddress.getByName("10.200.20.100"),
+                    InetAddress.getByName("10.200.20.101"),
+                    InetAddress.getByName("10.200.20.102")
+            ));
+            newAddresses = new ArrayList<>(Arrays.asList(
+                    InetAddress.getByName("10.200.20.103"),
+                    InetAddress.getByName("10.200.20.104"),
+                    InetAddress.getByName("10.200.20.105")
+            ));
+        } catch (UnknownHostException e) {
+            fail("Attempted to create an invalid InetAddress, this should not happen");
+        }
+    }
+
     private NetworkClient createNetworkClient(long reconnectBackoffMaxMs) {
         return new NetworkClient(selector, metadataUpdater, "mock", Integer.MAX_VALUE,
                 reconnectBackoffMsTest, reconnectBackoffMaxMs, 64 * 1024, 64 * 1024,
@@ -876,6 +900,139 @@ public class NetworkClientTest {
         ids.forEach(id -> assertTrue(id < SaslClientAuthenticator.MIN_RESERVED_CORRELATION_ID));
     }
 
+    @Test
+    public void testReconnectAfterAddressChange() {
+        AddressChangeHostResolver mockHostResolver = new AddressChangeHostResolver(
+                initialAddresses.toArray(new InetAddress[0]), newAddresses.toArray(new InetAddress[0]));
+        AtomicInteger initialAddressConns = new AtomicInteger();
+        AtomicInteger newAddressConns = new AtomicInteger();
+        MockSelector selector = new MockSelector(this.time, inetSocketAddress -> {
+            InetAddress inetAddress = inetSocketAddress.getAddress();
+            if (initialAddresses.contains(inetAddress)) {
+                initialAddressConns.incrementAndGet();
+            } else if (newAddresses.contains(inetAddress)) {
+                newAddressConns.incrementAndGet();
+            }
+            return (mockHostResolver.useNewAddresses() && newAddresses.contains(inetAddress)) ||
+                   (!mockHostResolver.useNewAddresses() && initialAddresses.contains(inetAddress));
+        });
+        NetworkClient client = new NetworkClient(metadataUpdater, null, selector, "mock", Integer.MAX_VALUE,
+                reconnectBackoffMsTest, reconnectBackoffMaxMsTest, 64 * 1024, 64 * 1024,
+                defaultRequestTimeoutMs, ClientDnsLookup.USE_ALL_DNS_IPS, time, false, new ApiVersions(),
+                null, new LogContext(), mockHostResolver);
+
+        // Connect to one of the initial addresses, then change the addresses and disconnect
+        client.ready(node, time.milliseconds());
+        client.poll(0, time.milliseconds());
+        assertTrue(client.isReady(node, time.milliseconds()));
+
+        mockHostResolver.changeAddresses();
+        selector.serverDisconnect(node.idString());
+        client.poll(0, time.milliseconds());
+        assertFalse(client.isReady(node, time.milliseconds()));
+
+        time.sleep(reconnectBackoffMaxMsTest);
+        client.ready(node, time.milliseconds());
+        client.poll(0, time.milliseconds());
+        assertTrue(client.isReady(node, time.milliseconds()));
+
+        // We should have tried to connect to one initial address and one new address, and resolved DNS twice
+        assertEquals(1, initialAddressConns.get());
+        assertEquals(1, newAddressConns.get());
+        assertEquals(2, mockHostResolver.resolutionCount());
+    }
+
+    @Test
+    public void testFailedConnectionToFirstAddress() {
+        AddressChangeHostResolver mockHostResolver = new AddressChangeHostResolver(
+                initialAddresses.toArray(new InetAddress[0]), newAddresses.toArray(new InetAddress[0]));
+        AtomicInteger initialAddressConns = new AtomicInteger();
+        AtomicInteger newAddressConns = new AtomicInteger();
+        MockSelector selector = new MockSelector(this.time, inetSocketAddress -> {
+            InetAddress inetAddress = inetSocketAddress.getAddress();
+            if (initialAddresses.contains(inetAddress)) {
+                initialAddressConns.incrementAndGet();
+            } else if (newAddresses.contains(inetAddress)) {
+                newAddressConns.incrementAndGet();
+            }
+            // Refuse first connection attempt
+            return initialAddressConns.get() > 1;
+        });
+        NetworkClient client = new NetworkClient(metadataUpdater, null, selector, "mock", Integer.MAX_VALUE,
+                reconnectBackoffMsTest, reconnectBackoffMaxMsTest, 64 * 1024, 64 * 1024,
+                defaultRequestTimeoutMs, ClientDnsLookup.USE_ALL_DNS_IPS, time, false, new ApiVersions(),
+                null, new LogContext(), mockHostResolver);
+
+        // First connection attempt should fail
+        client.ready(node, time.milliseconds());
+        // Simulate a failed connection (this will process the state change without the connection having been established)
+        selector.serverDisconnect(node.idString());
+        client.poll(0, time.milliseconds());
+        assertFalse(client.isReady(node, time.milliseconds()));
+
+        // Second connection attempt should succeed
+        time.sleep(reconnectBackoffMaxMsTest);
+        client.ready(node, time.milliseconds());
+        client.poll(0, time.milliseconds());
+        assertTrue(client.isReady(node, time.milliseconds()));
+
+        // We should have tried to connect to two of the initial addresses, none of the new addresses, and should
+        // only have resolved DNS once
+        assertEquals(2, initialAddressConns.get());
+        assertEquals(0, newAddressConns.get());
+        assertEquals(1, mockHostResolver.resolutionCount());
+    }
+
+    @Test
+    public void testFailedConnectionToFirstAddressAfterReconnect() {
+        AddressChangeHostResolver mockHostResolver = new AddressChangeHostResolver(
+                initialAddresses.toArray(new InetAddress[0]), newAddresses.toArray(new InetAddress[0]));
+        AtomicInteger initialAddressConns = new AtomicInteger();
+        AtomicInteger newAddressConns = new AtomicInteger();
+        MockSelector selector = new MockSelector(this.time, inetSocketAddress -> {
+            InetAddress inetAddress = inetSocketAddress.getAddress();
+            if (initialAddresses.contains(inetAddress)) {
+                initialAddressConns.incrementAndGet();
+            } else if (newAddresses.contains(inetAddress)) {
+                newAddressConns.incrementAndGet();
+            }
+            // Refuse first connection attempt to the new addresses
+            return initialAddresses.contains(inetAddress) || newAddressConns.get() > 1;
+        });
+        NetworkClient client = new NetworkClient(metadataUpdater, null, selector, "mock", Integer.MAX_VALUE,
+                reconnectBackoffMsTest, reconnectBackoffMaxMsTest, 64 * 1024, 64 * 1024,
+                defaultRequestTimeoutMs, ClientDnsLookup.USE_ALL_DNS_IPS, time, false, new ApiVersions(),
+                null, new LogContext(), mockHostResolver);
+
+        // Connect to one of the initial addresses, then change the addresses and disconnect
+        client.ready(node, time.milliseconds());
+        client.poll(0, time.milliseconds());
+        assertTrue(client.isReady(node, time.milliseconds()));
+
+        mockHostResolver.changeAddresses();
+        selector.serverDisconnect(node.idString());
+        client.poll(0, time.milliseconds());
+        assertFalse(client.isReady(node, time.milliseconds()));
+
+        // First connection attempt to new addresses should fail
+        time.sleep(reconnectBackoffMaxMsTest);
+        client.ready(node, time.milliseconds());
+        client.poll(0, time.milliseconds());
+        assertFalse(client.isReady(node, time.milliseconds()));
+
+        // Second connection attempt to new addresses should succeed
+        time.sleep(reconnectBackoffMaxMsTest);
+        client.ready(node, time.milliseconds());
+        client.poll(0, time.milliseconds());
+        assertTrue(client.isReady(node, time.milliseconds()));
+
+        // We should have tried to connect to one of the initial addresses and two of the new addresses (the first one
+        // failed), and resolved DNS twice, once for each set of addresses
+        assertEquals(1, initialAddressConns.get());
+        assertEquals(2, newAddressConns.get());
+        assertEquals(2, mockHostResolver.resolutionCount());
+    }
+
     private RequestHeader parseHeader(ByteBuffer buffer) {
         buffer.getInt(); // skip size
         return RequestHeader.parse(buffer.slice());
diff --git a/clients/src/test/java/org/apache/kafka/test/MockSelector.java b/clients/src/test/java/org/apache/kafka/test/MockSelector.java
index 90a83b0..4ebbabf 100644
--- a/clients/src/test/java/org/apache/kafka/test/MockSelector.java
+++ b/clients/src/test/java/org/apache/kafka/test/MockSelector.java
@@ -32,6 +32,7 @@ import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.function.Predicate;
 
 /**
  * A fake selector to use for testing
@@ -46,14 +47,22 @@ public class MockSelector implements Selectable {
     private final Map<String, ChannelState> disconnected = new HashMap<>();
     private final List<String> connected = new ArrayList<>();
     private final List<DelayedReceive> delayedReceives = new ArrayList<>();
+    private final Predicate<InetSocketAddress> canConnect;
 
     public MockSelector(Time time) {
+        this(time, null);
+    }
+
+    public MockSelector(Time time, Predicate<InetSocketAddress> canConnect) {
         this.time = time;
+        this.canConnect = canConnect;
     }
 
     @Override
     public void connect(String id, InetSocketAddress address, int sendBufferSize, int receiveBufferSize) throws IOException {
-        this.connected.add(id);
+        if (canConnect == null || canConnect.test(address)) {
+            this.connected.add(id);
+        }
     }
 
     @Override