You are viewing a plain-text version of this content. (The link to the canonical version is not preserved in this text.)
Posted to commits@kafka.apache.org by da...@apache.org on 2021/02/04 14:48:45 UTC

[kafka] branch 2.7 updated: KAFKA-12193: Re-resolve IPs after a client disconnects (#9902)

This is an automated email from the ASF dual-hosted git repository.

dajac pushed a commit to branch 2.7
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.7 by this push:
     new 8ab22e4  KAFKA-12193: Re-resolve IPs after a client disconnects (#9902)
8ab22e4 is described below

commit 8ab22e4a047624c1336f4cf7364d1e71a4304cd5
Author: Bob Barrett <bo...@confluent.io>
AuthorDate: Thu Feb 4 03:42:43 2021 -0800

    KAFKA-12193: Re-resolve IPs after a client disconnects (#9902)
    
    This patch changes the NetworkClient behavior to re-resolve the target node's hostname after disconnecting from an established connection, rather than waiting until the previously resolved addresses are exhausted. This handles the scenario where the node's IP addresses have changed during the lifetime of the connection: the client no longer has to try every previously resolved (and possibly invalid) IP address before it resolves the hostname again.
    
    Reviewers: Mickael Maison <mi...@gmail.com>, Satish Duggana <sa...@apache.org>, David Jacot <dj...@confluent.io>
    
    Conflicts (related to the JUnit upgrade):
    	clients/src/test/java/org/apache/kafka/clients/ClientUtilsTest.java
    	clients/src/test/java/org/apache/kafka/clients/ClusterConnectionStatesTest.java
    	clients/src/test/java/org/apache/kafka/clients/NetworkClientTest.java
---
 .../java/org/apache/kafka/clients/ClientUtils.java |   5 +-
 .../kafka/clients/ClusterConnectionStates.java     |  26 +++-
 .../apache/kafka/clients/DefaultHostResolver.java  |  29 ++++
 .../org/apache/kafka/clients/HostResolver.java     |  26 ++++
 .../org/apache/kafka/clients/NetworkClient.java    |  80 +++++-----
 .../kafka/clients/AddressChangeHostResolver.java   |  49 +++++++
 .../org/apache/kafka/clients/ClientUtilsTest.java  |   9 +-
 .../kafka/clients/ClusterConnectionStatesTest.java |  71 ++++++---
 .../apache/kafka/clients/NetworkClientTest.java    | 162 +++++++++++++++++++++
 .../java/org/apache/kafka/test/MockSelector.java   |  11 +-
 10 files changed, 399 insertions(+), 69 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/ClientUtils.java b/clients/src/main/java/org/apache/kafka/clients/ClientUtils.java
index 5e5286e..5adb7e3 100644
--- a/clients/src/main/java/org/apache/kafka/clients/ClientUtils.java
+++ b/clients/src/main/java/org/apache/kafka/clients/ClientUtils.java
@@ -106,8 +106,9 @@ public final class ClientUtils {
                 clientSaslMechanism, time, true, logContext);
     }
 
-    static List<InetAddress> resolve(String host, ClientDnsLookup clientDnsLookup) throws UnknownHostException {
-        InetAddress[] addresses = InetAddress.getAllByName(host);
+    static List<InetAddress> resolve(String host, ClientDnsLookup clientDnsLookup,
+                                     HostResolver hostResolver) throws UnknownHostException {
+        InetAddress[] addresses = hostResolver.resolve(host);
 
         switch (clientDnsLookup) {
             case DEFAULT:
diff --git a/clients/src/main/java/org/apache/kafka/clients/ClusterConnectionStates.java b/clients/src/main/java/org/apache/kafka/clients/ClusterConnectionStates.java
index 39d8b16..16bf59ac 100644
--- a/clients/src/main/java/org/apache/kafka/clients/ClusterConnectionStates.java
+++ b/clients/src/main/java/org/apache/kafka/clients/ClusterConnectionStates.java
@@ -43,13 +43,14 @@ final class ClusterConnectionStates {
     final static double CONNECTION_SETUP_TIMEOUT_JITTER = 0.2;
     private final Map<String, NodeConnectionState> nodeState;
     private final Logger log;
+    private final HostResolver hostResolver;
     private Set<String> connectingNodes;
     private ExponentialBackoff reconnectBackoff;
     private ExponentialBackoff connectionSetupTimeout;
 
     public ClusterConnectionStates(long reconnectBackoffMs, long reconnectBackoffMaxMs,
                                    long connectionSetupTimeoutMs, long connectionSetupTimeoutMaxMs,
-                                   LogContext logContext) {
+                                   LogContext logContext, HostResolver hostResolver) {
         this.log = logContext.logger(ClusterConnectionStates.class);
         this.reconnectBackoff = new ExponentialBackoff(
                 reconnectBackoffMs,
@@ -63,6 +64,7 @@ final class ClusterConnectionStates {
                 CONNECTION_SETUP_TIMEOUT_JITTER);
         this.nodeState = new HashMap<>();
         this.connectingNodes = new HashSet<>();
+        this.hostResolver = hostResolver;
     }
 
     /**
@@ -156,7 +158,8 @@ final class ClusterConnectionStates {
         // Create a new NodeConnectionState if nodeState does not already contain one
         // for the specified id or if the hostname associated with the node id changed.
         nodeState.put(id, new NodeConnectionState(ConnectionState.CONNECTING, now,
-            reconnectBackoff.backoff(0), connectionSetupTimeout.backoff(0), host, clientDnsLookup));
+                reconnectBackoff.backoff(0), connectionSetupTimeout.backoff(0), host,
+                clientDnsLookup, hostResolver));
         connectingNodes.add(id);
     }
 
@@ -183,6 +186,11 @@ final class ClusterConnectionStates {
             connectingNodes.remove(id);
         } else {
             resetConnectionSetupTimeout(nodeState);
+            if (nodeState.state.isConnected()) {
+                // If a connection had previously been established, clear the addresses to trigger a new DNS resolution
+                // because the node IPs may have changed
+                nodeState.clearAddresses();
+            }
         }
         nodeState.state = ConnectionState.DISCONNECTED;
     }
@@ -470,9 +478,11 @@ final class ClusterConnectionStates {
         private int addressIndex;
         private final String host;
         private final ClientDnsLookup clientDnsLookup;
+        private final HostResolver hostResolver;
 
         private NodeConnectionState(ConnectionState state, long lastConnectAttempt, long reconnectBackoffMs,
-                long connectionSetupTimeoutMs, String host, ClientDnsLookup clientDnsLookup) {
+                long connectionSetupTimeoutMs, String host, ClientDnsLookup clientDnsLookup,
+                HostResolver hostResolver) {
             this.state = state;
             this.addresses = Collections.emptyList();
             this.addressIndex = -1;
@@ -484,6 +494,7 @@ final class ClusterConnectionStates {
             this.throttleUntilTimeMs = 0;
             this.host = host;
             this.clientDnsLookup = clientDnsLookup;
+            this.hostResolver = hostResolver;
         }
 
         public String host() {
@@ -498,7 +509,7 @@ final class ClusterConnectionStates {
         private InetAddress currentAddress() throws UnknownHostException {
             if (addresses.isEmpty()) {
                 // (Re-)initialize list
-                addresses = ClientUtils.resolve(host, clientDnsLookup);
+                addresses = ClientUtils.resolve(host, clientDnsLookup, hostResolver);
                 addressIndex = 0;
             }
 
@@ -518,6 +529,13 @@ final class ClusterConnectionStates {
                 addresses = Collections.emptyList(); // Exhausted list. Re-resolve on next currentAddress() call
         }
 
+        /**
+         * Clears the resolved addresses in order to trigger re-resolving on the next {@link #currentAddress()} call.
+         */
+        private void clearAddresses() {
+            addresses = Collections.emptyList();
+        }
+
         public String toString() {
             return "NodeState(" + state + ", " + lastConnectAttemptMs + ", " + failedAttempts + ", " + throttleUntilTimeMs + ")";
         }
diff --git a/clients/src/main/java/org/apache/kafka/clients/DefaultHostResolver.java b/clients/src/main/java/org/apache/kafka/clients/DefaultHostResolver.java
new file mode 100644
index 0000000..786173e
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/DefaultHostResolver.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.clients;
+
+import java.net.InetAddress;
+import java.net.UnknownHostException;
+
+public class DefaultHostResolver implements HostResolver {
+
+    @Override
+    public InetAddress[] resolve(String host) throws UnknownHostException {
+        return InetAddress.getAllByName(host);
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/HostResolver.java b/clients/src/main/java/org/apache/kafka/clients/HostResolver.java
new file mode 100644
index 0000000..80209ca
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/HostResolver.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.clients;
+
+import java.net.InetAddress;
+import java.net.UnknownHostException;
+
+public interface HostResolver {
+
+    InetAddress[] resolve(String host) throws UnknownHostException;
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/NetworkClient.java b/clients/src/main/java/org/apache/kafka/clients/NetworkClient.java
index 5287124..f316be1 100644
--- a/clients/src/main/java/org/apache/kafka/clients/NetworkClient.java
+++ b/clients/src/main/java/org/apache/kafka/clients/NetworkClient.java
@@ -148,9 +148,8 @@ public class NetworkClient implements KafkaClient {
                          boolean discoverBrokerVersions,
                          ApiVersions apiVersions,
                          LogContext logContext) {
-        this(null,
+        this(selector,
              metadata,
-             selector,
              clientId,
              maxInFlightRequestsPerConnection,
              reconnectBackoffMs,
@@ -169,22 +168,22 @@ public class NetworkClient implements KafkaClient {
     }
 
     public NetworkClient(Selectable selector,
-            Metadata metadata,
-            String clientId,
-            int maxInFlightRequestsPerConnection,
-            long reconnectBackoffMs,
-            long reconnectBackoffMax,
-            int socketSendBuffer,
-            int socketReceiveBuffer,
-            int defaultRequestTimeoutMs,
-            long connectionSetupTimeoutMs,
-            long connectionSetupTimeoutMaxMs,
-            ClientDnsLookup clientDnsLookup,
-            Time time,
-            boolean discoverBrokerVersions,
-            ApiVersions apiVersions,
-            Sensor throttleTimeSensor,
-            LogContext logContext) {
+                         Metadata metadata,
+                         String clientId,
+                         int maxInFlightRequestsPerConnection,
+                         long reconnectBackoffMs,
+                         long reconnectBackoffMax,
+                         int socketSendBuffer,
+                         int socketReceiveBuffer,
+                         int defaultRequestTimeoutMs,
+                         long connectionSetupTimeoutMs,
+                         long connectionSetupTimeoutMaxMs,
+                         ClientDnsLookup clientDnsLookup,
+                         Time time,
+                         boolean discoverBrokerVersions,
+                         ApiVersions apiVersions,
+                         Sensor throttleTimeSensor,
+                         LogContext logContext) {
         this(null,
              metadata,
              selector,
@@ -202,7 +201,8 @@ public class NetworkClient implements KafkaClient {
              discoverBrokerVersions,
              apiVersions,
              throttleTimeSensor,
-             logContext);
+             logContext,
+             new DefaultHostResolver());
     }
 
     public NetworkClient(Selectable selector,
@@ -238,27 +238,29 @@ public class NetworkClient implements KafkaClient {
              discoverBrokerVersions,
              apiVersions,
              null,
-             logContext);
+             logContext,
+             new DefaultHostResolver());
     }
 
-    private NetworkClient(MetadataUpdater metadataUpdater,
-                          Metadata metadata,
-                          Selectable selector,
-                          String clientId,
-                          int maxInFlightRequestsPerConnection,
-                          long reconnectBackoffMs,
-                          long reconnectBackoffMax,
-                          int socketSendBuffer,
-                          int socketReceiveBuffer,
-                          int defaultRequestTimeoutMs,
-                          long connectionSetupTimeoutMs,
-                          long connectionSetupTimeoutMaxMs,
-                          ClientDnsLookup clientDnsLookup,
-                          Time time,
-                          boolean discoverBrokerVersions,
-                          ApiVersions apiVersions,
-                          Sensor throttleTimeSensor,
-                          LogContext logContext) {
+    public NetworkClient(MetadataUpdater metadataUpdater,
+                         Metadata metadata,
+                         Selectable selector,
+                         String clientId,
+                         int maxInFlightRequestsPerConnection,
+                         long reconnectBackoffMs,
+                         long reconnectBackoffMax,
+                         int socketSendBuffer,
+                         int socketReceiveBuffer,
+                         int defaultRequestTimeoutMs,
+                         long connectionSetupTimeoutMs,
+                         long connectionSetupTimeoutMaxMs,
+                         ClientDnsLookup clientDnsLookup,
+                         Time time,
+                         boolean discoverBrokerVersions,
+                         ApiVersions apiVersions,
+                         Sensor throttleTimeSensor,
+                         LogContext logContext,
+                         HostResolver hostResolver) {
         /* It would be better if we could pass `DefaultMetadataUpdater` from the public constructor, but it's not
          * possible because `DefaultMetadataUpdater` is an inner class and it can only be instantiated after the
          * super constructor is invoked.
@@ -275,7 +277,7 @@ public class NetworkClient implements KafkaClient {
         this.inFlightRequests = new InFlightRequests(maxInFlightRequestsPerConnection);
         this.connectionStates = new ClusterConnectionStates(
                 reconnectBackoffMs, reconnectBackoffMax,
-                connectionSetupTimeoutMs, connectionSetupTimeoutMaxMs, logContext);
+                connectionSetupTimeoutMs, connectionSetupTimeoutMaxMs, logContext, hostResolver);
         this.socketSendBuffer = socketSendBuffer;
         this.socketReceiveBuffer = socketReceiveBuffer;
         this.correlation = 0;
diff --git a/clients/src/test/java/org/apache/kafka/clients/AddressChangeHostResolver.java b/clients/src/test/java/org/apache/kafka/clients/AddressChangeHostResolver.java
new file mode 100644
index 0000000..28f9c88
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/clients/AddressChangeHostResolver.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients;
+
+import java.net.InetAddress;
+
+class AddressChangeHostResolver implements HostResolver {
+    private boolean useNewAddresses;
+    private InetAddress[] initialAddresses;
+    private InetAddress[] newAddresses;
+    private int resolutionCount = 0;
+
+    public AddressChangeHostResolver(InetAddress[] initialAddresses, InetAddress[] newAddresses) {
+        this.initialAddresses = initialAddresses;
+        this.newAddresses = newAddresses;
+    }
+
+    @Override
+    public InetAddress[] resolve(String host) {
+        ++resolutionCount;
+        return useNewAddresses ? newAddresses : initialAddresses;
+    }
+
+    public void changeAddresses() {
+        useNewAddresses = true;
+    }
+
+    public boolean useNewAddresses() {
+        return useNewAddresses;
+    }
+
+    public int resolutionCount() {
+        return resolutionCount;
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/clients/ClientUtilsTest.java b/clients/src/test/java/org/apache/kafka/clients/ClientUtilsTest.java
index 3a281f8..9dfe83c 100644
--- a/clients/src/test/java/org/apache/kafka/clients/ClientUtilsTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/ClientUtilsTest.java
@@ -30,6 +30,7 @@ import org.junit.Test;
 
 public class ClientUtilsTest {
 
+    private HostResolver hostResolver = new DefaultHostResolver();
 
     @Test
     public void testParseAndValidateAddresses() throws UnknownHostException {
@@ -97,25 +98,25 @@ public class ClientUtilsTest {
 
     @Test(expected = UnknownHostException.class)
     public void testResolveUnknownHostException() throws UnknownHostException {
-        ClientUtils.resolve("some.invalid.hostname.foo.bar.local", ClientDnsLookup.USE_ALL_DNS_IPS);
+        ClientUtils.resolve("some.invalid.hostname.foo.bar.local", ClientDnsLookup.USE_ALL_DNS_IPS, hostResolver);
     }
 
     @Test
     public void testResolveDnsLookup() throws UnknownHostException {
         // Note that kafka.apache.org resolves to at least 2 IP addresses
-        assertEquals(1, ClientUtils.resolve("kafka.apache.org", ClientDnsLookup.DEFAULT).size());
+        assertEquals(1, ClientUtils.resolve("kafka.apache.org", ClientDnsLookup.DEFAULT, hostResolver).size());
     }
 
     @Test
     public void testResolveDnsLookupAllIps() throws UnknownHostException {
         // Note that kafka.apache.org resolves to at least 2 IP addresses
-        assertTrue(ClientUtils.resolve("kafka.apache.org", ClientDnsLookup.USE_ALL_DNS_IPS).size() > 1);
+        assertTrue(ClientUtils.resolve("kafka.apache.org", ClientDnsLookup.USE_ALL_DNS_IPS, hostResolver).size() > 1);
     }
 
     @Test
     public void testResolveDnsLookupResolveCanonicalBootstrapServers() throws UnknownHostException {
         // Note that kafka.apache.org resolves to at least 2 IP addresses
-        assertTrue(ClientUtils.resolve("kafka.apache.org", ClientDnsLookup.RESOLVE_CANONICAL_BOOTSTRAP_SERVERS_ONLY).size() > 1);
+        assertTrue(ClientUtils.resolve("kafka.apache.org", ClientDnsLookup.RESOLVE_CANONICAL_BOOTSTRAP_SERVERS_ONLY, hostResolver).size() > 1);
     }
 
     private List<InetSocketAddress> checkWithoutLookup(String... url) {
diff --git a/clients/src/test/java/org/apache/kafka/clients/ClusterConnectionStatesTest.java b/clients/src/test/java/org/apache/kafka/clients/ClusterConnectionStatesTest.java
index 608ef89..c0109c7 100644
--- a/clients/src/test/java/org/apache/kafka/clients/ClusterConnectionStatesTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/ClusterConnectionStatesTest.java
@@ -23,12 +23,13 @@ import static org.junit.Assert.assertNotSame;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertSame;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
-import java.lang.reflect.Field;
-import java.lang.reflect.Method;
 import java.net.InetAddress;
 import java.net.UnknownHostException;
 
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Set;
 import org.apache.kafka.common.errors.AuthenticationException;
 import org.apache.kafka.common.utils.LogContext;
@@ -38,6 +39,26 @@ import org.junit.Test;
 
 public class ClusterConnectionStatesTest {
 
+    private static ArrayList<InetAddress> initialAddresses;
+    private static ArrayList<InetAddress> newAddresses;
+
+    static {
+        try {
+            initialAddresses = new ArrayList<>(Arrays.asList(
+                    InetAddress.getByName("10.200.20.100"),
+                    InetAddress.getByName("10.200.20.101"),
+                    InetAddress.getByName("10.200.20.102")
+            ));
+            newAddresses = new ArrayList<>(Arrays.asList(
+                    InetAddress.getByName("10.200.20.103"),
+                    InetAddress.getByName("10.200.20.104"),
+                    InetAddress.getByName("10.200.20.105")
+            ));
+        } catch (UnknownHostException e) {
+            fail("Attempted to create an invalid InetAddress, this should not happen");
+        }
+    }
+
     private final MockTime time = new MockTime();
     private final long reconnectBackoffMs = 10 * 1000;
     private final long reconnectBackoffMax = 60 * 1000;
@@ -50,15 +71,20 @@ public class ClusterConnectionStatesTest {
     private final String nodeId1 = "1001";
     private final String nodeId2 = "2002";
     private final String nodeId3 = "3003";
-    private final String hostTwoIps = "kafka.apache.org";
-
+    private final String hostTwoIps = "multiple.ip.address";
     private ClusterConnectionStates connectionStates;
 
+    // For testing nodes with a single IP address, use localhost and default DNS resolution
+    private DefaultHostResolver singleIPHostResolver = new DefaultHostResolver();
+
+    // For testing nodes with multiple IP addresses, mock DNS resolution to get consistent results
+    private AddressChangeHostResolver multipleIPHostResolver = new AddressChangeHostResolver(
+            initialAddresses.toArray(new InetAddress[0]), newAddresses.toArray(new InetAddress[0]));;
+
     @Before
     public void setup() {
-        this.connectionStates = new ClusterConnectionStates(
-                reconnectBackoffMs, reconnectBackoffMax,
-                connectionSetupTimeoutMs, connectionSetupTimeoutMaxMs, new LogContext());
+        this.connectionStates = new ClusterConnectionStates(reconnectBackoffMs, reconnectBackoffMax,
+                connectionSetupTimeoutMs, connectionSetupTimeoutMaxMs, new LogContext(), this.singleIPHostResolver);
     }
 
     @Test
@@ -253,7 +279,7 @@ public class ClusterConnectionStatesTest {
 
     @Test
     public void testSingleIPWithUseAll() throws UnknownHostException {
-        assertEquals(1, ClientUtils.resolve("localhost", ClientDnsLookup.USE_ALL_DNS_IPS).size());
+        assertEquals(1, ClientUtils.resolve("localhost", ClientDnsLookup.USE_ALL_DNS_IPS, singleIPHostResolver).size());
 
         connectionStates.connecting(nodeId1, time.milliseconds(), "localhost", ClientDnsLookup.USE_ALL_DNS_IPS);
         InetAddress currAddress = connectionStates.currentAddress(nodeId1);
@@ -263,7 +289,9 @@ public class ClusterConnectionStatesTest {
 
     @Test
     public void testMultipleIPsWithDefault() throws UnknownHostException {
-        assertTrue(ClientUtils.resolve(hostTwoIps, ClientDnsLookup.USE_ALL_DNS_IPS).size() > 1);
+        setupMultipleIPs();
+
+        assertTrue(ClientUtils.resolve(hostTwoIps, ClientDnsLookup.USE_ALL_DNS_IPS, multipleIPHostResolver).size() > 1);
 
         connectionStates.connecting(nodeId1, time.milliseconds(), hostTwoIps, ClientDnsLookup.DEFAULT);
         InetAddress currAddress = connectionStates.currentAddress(nodeId1);
@@ -273,7 +301,9 @@ public class ClusterConnectionStatesTest {
 
     @Test
     public void testMultipleIPsWithUseAll() throws UnknownHostException {
-        assertTrue(ClientUtils.resolve(hostTwoIps, ClientDnsLookup.USE_ALL_DNS_IPS).size() > 1);
+        setupMultipleIPs();
+
+        assertTrue(ClientUtils.resolve(hostTwoIps, ClientDnsLookup.USE_ALL_DNS_IPS, multipleIPHostResolver).size() > 1);
 
         connectionStates.connecting(nodeId1, time.milliseconds(), hostTwoIps, ClientDnsLookup.USE_ALL_DNS_IPS);
         InetAddress addr1 = connectionStates.currentAddress(nodeId1);
@@ -287,19 +317,14 @@ public class ClusterConnectionStatesTest {
 
     @Test
     public void testHostResolveChange() throws UnknownHostException, ReflectiveOperationException {
-        assertTrue(ClientUtils.resolve(hostTwoIps, ClientDnsLookup.USE_ALL_DNS_IPS).size() > 1);
+        setupMultipleIPs();
+
+        assertTrue(ClientUtils.resolve(hostTwoIps, ClientDnsLookup.USE_ALL_DNS_IPS, multipleIPHostResolver).size() > 1);
 
         connectionStates.connecting(nodeId1, time.milliseconds(), hostTwoIps, ClientDnsLookup.DEFAULT);
         InetAddress addr1 = connectionStates.currentAddress(nodeId1);
 
-        // reflection to simulate host change in DNS lookup
-        Method nodeStateMethod = connectionStates.getClass().getDeclaredMethod("nodeState", String.class);
-        nodeStateMethod.setAccessible(true);
-        Object nodeState = nodeStateMethod.invoke(connectionStates, nodeId1);
-        Field hostField = nodeState.getClass().getDeclaredField("host");
-        hostField.setAccessible(true);
-        hostField.set(nodeState, "localhost");
-
+        multipleIPHostResolver.changeAddresses();
         connectionStates.connecting(nodeId1, time.milliseconds(), "localhost", ClientDnsLookup.DEFAULT);
         InetAddress addr2 = connectionStates.currentAddress(nodeId1);
 
@@ -308,9 +333,12 @@ public class ClusterConnectionStatesTest {
 
     @Test
     public void testNodeWithNewHostname() throws UnknownHostException {
+        setupMultipleIPs();
+
         connectionStates.connecting(nodeId1, time.milliseconds(), "localhost", ClientDnsLookup.DEFAULT);
         InetAddress addr1 = connectionStates.currentAddress(nodeId1);
 
+        this.multipleIPHostResolver.changeAddresses();
         connectionStates.connecting(nodeId1, time.milliseconds(), hostTwoIps, ClientDnsLookup.DEFAULT);
         InetAddress addr2 = connectionStates.currentAddress(nodeId1);
 
@@ -409,4 +437,9 @@ public class ClusterConnectionStatesTest {
         // Expect no timed out connections
         assertEquals(0, connectionStates.nodesWithConnectionSetupTimeout(time.milliseconds()).size());
     }
+
+    private void setupMultipleIPs() {
+        this.connectionStates = new ClusterConnectionStates(reconnectBackoffMs, reconnectBackoffMax,
+                connectionSetupTimeoutMs, connectionSetupTimeoutMaxMs, new LogContext(), this.multipleIPHostResolver);
+    }
 }
diff --git a/clients/src/test/java/org/apache/kafka/clients/NetworkClientTest.java b/clients/src/test/java/org/apache/kafka/clients/NetworkClientTest.java
index 896c237..de0e524 100644
--- a/clients/src/test/java/org/apache/kafka/clients/NetworkClientTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/NetworkClientTest.java
@@ -48,6 +48,8 @@ import org.apache.kafka.test.TestUtils;
 import org.junit.Before;
 import org.junit.Test;
 
+import java.net.InetAddress;
+import java.net.UnknownHostException;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -56,6 +58,7 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Optional;
 import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
@@ -66,6 +69,7 @@ import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 public class NetworkClientTest {
 
@@ -83,6 +87,26 @@ public class NetworkClientTest {
     private final NetworkClient clientWithStaticNodes = createNetworkClientWithStaticNodes();
     private final NetworkClient clientWithNoVersionDiscovery = createNetworkClientWithNoVersionDiscovery();
 
+    private static ArrayList<InetAddress> initialAddresses;
+    private static ArrayList<InetAddress> newAddresses;
+
+    static {
+        try {
+            initialAddresses = new ArrayList<>(Arrays.asList(
+                    InetAddress.getByName("10.200.20.100"),
+                    InetAddress.getByName("10.200.20.101"),
+                    InetAddress.getByName("10.200.20.102")
+            ));
+            newAddresses = new ArrayList<>(Arrays.asList(
+                    InetAddress.getByName("10.200.20.103"),
+                    InetAddress.getByName("10.200.20.104"),
+                    InetAddress.getByName("10.200.20.105")
+            ));
+        } catch (UnknownHostException e) {
+            fail("Attempted to create an invalid InetAddress, this should not happen");
+        }
+    }
+
     private NetworkClient createNetworkClient(long reconnectBackoffMaxMs) {
         return new NetworkClient(selector, metadataUpdater, "mock", Integer.MAX_VALUE,
                 reconnectBackoffMsTest, reconnectBackoffMaxMs, 64 * 1024, 64 * 1024,
@@ -937,6 +961,144 @@ public class NetworkClientTest {
         ids.forEach(id -> assertTrue(id < SaslClientAuthenticator.MIN_RESERVED_CORRELATION_ID));
     }
 
+    @Test
+    public void testReconnectAfterAddressChange() {
+        AddressChangeHostResolver mockHostResolver = new AddressChangeHostResolver(
+                initialAddresses.toArray(new InetAddress[0]), newAddresses.toArray(new InetAddress[0]));
+        AtomicInteger initialAddressConns = new AtomicInteger();
+        AtomicInteger newAddressConns = new AtomicInteger();
+        MockSelector selector = new MockSelector(this.time, inetSocketAddress -> {
+            InetAddress inetAddress = inetSocketAddress.getAddress();
+            if (initialAddresses.contains(inetAddress)) {
+                initialAddressConns.incrementAndGet();
+            } else if (newAddresses.contains(inetAddress)) {
+                newAddressConns.incrementAndGet();
+            }
+            // Only addresses from the set the resolver is currently using accept connections
+            return (mockHostResolver.useNewAddresses() ? newAddresses : initialAddresses).contains(inetAddress);
+        });
+        NetworkClient client = new NetworkClient(metadataUpdater, null, selector, "mock", Integer.MAX_VALUE,
+                reconnectBackoffMsTest, reconnectBackoffMaxMsTest, 64 * 1024, 64 * 1024,
+                defaultRequestTimeoutMs, connectionSetupTimeoutMsTest, connectionSetupTimeoutMaxMsTest,
+                ClientDnsLookup.USE_ALL_DNS_IPS, time, false, new ApiVersions(), null, new LogContext(), mockHostResolver);
+
+        // Connect to one of the initial addresses, then change the addresses and disconnect
+        client.ready(node, time.milliseconds());
+        time.sleep(connectionSetupTimeoutMaxMsTest);
+        client.poll(0, time.milliseconds());
+        assertTrue(client.isReady(node, time.milliseconds()));
+
+        mockHostResolver.changeAddresses();
+        selector.serverDisconnect(node.idString());
+        client.poll(0, time.milliseconds());
+        assertFalse(client.isReady(node, time.milliseconds()));
+
+        time.sleep(reconnectBackoffMaxMsTest);
+        client.ready(node, time.milliseconds());
+        time.sleep(connectionSetupTimeoutMaxMsTest);
+        client.poll(0, time.milliseconds());
+        assertTrue(client.isReady(node, time.milliseconds()));
+
+        // We should have tried to connect to one initial address and one new address, and resolved DNS twice
+        assertEquals(1, initialAddressConns.get());
+        assertEquals(1, newAddressConns.get());
+        assertEquals(2, mockHostResolver.resolutionCount());
+    }
+
+    @Test
+    public void testFailedConnectionToFirstAddress() {
+        AddressChangeHostResolver mockHostResolver = new AddressChangeHostResolver(
+                initialAddresses.toArray(new InetAddress[0]), newAddresses.toArray(new InetAddress[0]));
+        AtomicInteger initialAddressConns = new AtomicInteger();
+        AtomicInteger newAddressConns = new AtomicInteger();
+        MockSelector selector = new MockSelector(this.time, inetSocketAddress -> {
+            InetAddress inetAddress = inetSocketAddress.getAddress();
+            if (initialAddresses.contains(inetAddress)) {
+                initialAddressConns.incrementAndGet();
+            } else if (newAddresses.contains(inetAddress)) {
+                newAddressConns.incrementAndGet();
+            }
+            // Refuse first connection attempt
+            return initialAddressConns.get() > 1;
+        });
+        NetworkClient client = new NetworkClient(metadataUpdater, null, selector, "mock", Integer.MAX_VALUE,
+                reconnectBackoffMsTest, reconnectBackoffMaxMsTest, 64 * 1024, 64 * 1024,
+                defaultRequestTimeoutMs, connectionSetupTimeoutMsTest, connectionSetupTimeoutMaxMsTest,
+                ClientDnsLookup.USE_ALL_DNS_IPS, time, false, new ApiVersions(), null, new LogContext(), mockHostResolver);
+
+        // First connection attempt is refused by the selector, so the client must not become ready
+        client.ready(node, time.milliseconds());
+        time.sleep(connectionSetupTimeoutMaxMsTest);
+        client.poll(0, time.milliseconds());
+        assertFalse(client.isReady(node, time.milliseconds()));
+
+        // Second connection attempt should succeed
+        time.sleep(reconnectBackoffMaxMsTest);
+        client.ready(node, time.milliseconds());
+        time.sleep(connectionSetupTimeoutMaxMsTest);
+        client.poll(0, time.milliseconds());
+        assertTrue(client.isReady(node, time.milliseconds()));
+
+        // We should have tried to connect to two of the initial addresses, none of the new addresses, and should
+        // only have resolved DNS once
+        assertEquals(2, initialAddressConns.get());
+        assertEquals(0, newAddressConns.get());
+        assertEquals(1, mockHostResolver.resolutionCount());
+    }
+
+    @Test
+    public void testFailedConnectionToFirstAddressAfterReconnect() {
+        AddressChangeHostResolver mockHostResolver = new AddressChangeHostResolver(
+                initialAddresses.toArray(new InetAddress[0]), newAddresses.toArray(new InetAddress[0]));
+        AtomicInteger initialAddressConns = new AtomicInteger();
+        AtomicInteger newAddressConns = new AtomicInteger();
+        MockSelector selector = new MockSelector(this.time, inetSocketAddress -> {
+            InetAddress inetAddress = inetSocketAddress.getAddress();
+            if (initialAddresses.contains(inetAddress)) {
+                initialAddressConns.incrementAndGet();
+            } else if (newAddresses.contains(inetAddress)) {
+                newAddressConns.incrementAndGet();
+            }
+            // Refuse first connection attempt to the new addresses
+            return initialAddresses.contains(inetAddress) || newAddressConns.get() > 1;
+        });
+        NetworkClient client = new NetworkClient(metadataUpdater, null, selector, "mock", Integer.MAX_VALUE,
+                reconnectBackoffMsTest, reconnectBackoffMaxMsTest, 64 * 1024, 64 * 1024,
+                defaultRequestTimeoutMs, connectionSetupTimeoutMsTest, connectionSetupTimeoutMaxMsTest,
+                ClientDnsLookup.USE_ALL_DNS_IPS, time, false, new ApiVersions(), null, new LogContext(), mockHostResolver);
+
+        // Connect to one of the initial addresses, then change the addresses and disconnect
+        client.ready(node, time.milliseconds());
+        time.sleep(connectionSetupTimeoutMaxMsTest);
+        client.poll(0, time.milliseconds());
+        assertTrue(client.isReady(node, time.milliseconds()));
+
+        mockHostResolver.changeAddresses();
+        selector.serverDisconnect(node.idString());
+        client.poll(0, time.milliseconds());
+        assertFalse(client.isReady(node, time.milliseconds()));
+
+        // First connection attempt to new addresses should fail
+        time.sleep(reconnectBackoffMaxMsTest);
+        client.ready(node, time.milliseconds());
+        time.sleep(connectionSetupTimeoutMaxMsTest);
+        client.poll(0, time.milliseconds());
+        assertFalse(client.isReady(node, time.milliseconds()));
+
+        // Second connection attempt to new addresses should succeed
+        time.sleep(reconnectBackoffMaxMsTest);
+        client.ready(node, time.milliseconds());
+        time.sleep(connectionSetupTimeoutMaxMsTest);
+        client.poll(0, time.milliseconds());
+        assertTrue(client.isReady(node, time.milliseconds()));
+
+        // We should have tried to connect to one of the initial addresses and two of the new addresses (the first one
+        // failed), and resolved DNS twice, once for each set of addresses
+        assertEquals(1, initialAddressConns.get());
+        assertEquals(2, newAddressConns.get());
+        assertEquals(2, mockHostResolver.resolutionCount());
+    }
+
     private RequestHeader parseHeader(ByteBuffer buffer) {
         buffer.getInt(); // skip size
         return RequestHeader.parse(buffer.slice());
diff --git a/clients/src/test/java/org/apache/kafka/test/MockSelector.java b/clients/src/test/java/org/apache/kafka/test/MockSelector.java
index 20f75be..c086e6b 100644
--- a/clients/src/test/java/org/apache/kafka/test/MockSelector.java
+++ b/clients/src/test/java/org/apache/kafka/test/MockSelector.java
@@ -32,6 +32,7 @@ import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.function.Predicate;
 
 /**
  * A fake selector to use for testing
@@ -46,14 +47,22 @@ public class MockSelector implements Selectable {
     private final Map<String, ChannelState> disconnected = new HashMap<>();
     private final List<String> connected = new ArrayList<>();
     private final List<DelayedReceive> delayedReceives = new ArrayList<>();
+    private final Predicate<InetSocketAddress> canConnect; // decides which addresses connect() marks as connected
 
     public MockSelector(Time time) {
+        this(time, address -> true); // no restriction: every connection attempt is accepted
+    }
+
+    public MockSelector(Time time, Predicate<InetSocketAddress> canConnect) {
         this.time = time;
+        this.canConnect = canConnect;
     }
 
     @Override
     public void connect(String id, InetSocketAddress address, int sendBufferSize, int receiveBufferSize) throws IOException {
-        this.connected.add(id);
+        if (canConnect == null || canConnect.test(address)) { // a null predicate accepts every address
+            this.connected.add(id);
+        } // a refused address is recorded nowhere: neither connected nor disconnected
     }
 
     @Override