Posted to commits@kafka.apache.org by da...@apache.org on 2023/02/09 13:53:25 UTC

[kafka] branch trunk updated: KAFKA-7109: Close fetch sessions on close of consumer (#12590)

This is an automated email from the ASF dual-hosted git repository.

dajac pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new e903f2cd964 KAFKA-7109: Close fetch sessions on close of consumer (#12590)
e903f2cd964 is described below

commit e903f2cd9646639cbada753a705b49fb903e1add
Author: Divij Vaidya <di...@amazon.com>
AuthorDate: Thu Feb 9 14:53:10 2023 +0100

    KAFKA-7109: Close fetch sessions on close of consumer (#12590)
    
    ## Problem
    When a consumer is closed, the fetch sessions associated with it should notify the server of the intention to close by sending a Fetch request with epoch = -1 (identified by `FINAL_EPOCH` in `FetchMetadata.java`). However, the current flow never sends this final fetch request, which leaves unnecessary fetch sessions on the server that are only cleaned up after a timeout.
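
    For illustration, a minimal sketch (grounded in the `FetchMetadata` changes in the diff below) of how that final epoch is encoded:

    ```java
    import org.apache.kafka.common.requests.FetchMetadata;

    public class FinalEpochSketch {
        // A session close is signalled by reusing the existing session id
        // together with FINAL_EPOCH (-1) as the epoch of the next fetch request.
        static FetchMetadata closeMetadata(int sessionId) {
            return new FetchMetadata(sessionId, FetchMetadata.FINAL_EPOCH);
        }
    }
    ```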
    
    ## Changes
    1. Change `close()` in `Fetcher` to send a final Fetch request that notifies the server of the session close.
    2. Change `close()` in `Consumer` to respect the timeout duration passed to it. Prior to this change, the timeout parameter was ignored (see the usage sketch after this list).
    3. Change tests to close with `Duration.ZERO` to reduce their execution time. Otherwise, the tests would wait for the default timeout before exiting (`close()` in the tests is expected to be unsuccessful because there is no server to send the request to).
    4. Distinguish between "close the existing session and attempt to create a new one" and "close the existing session" by renaming the `nextCloseExisting` function to `nextCloseExistingAttemptNew` (the `nextCloseExisting` name is now used for the plain close).
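
    A hedged usage sketch for change 2, using only public API (after this patch, the timeout bounds both the coordinator close and the final fetch-session-close request, each capped at `request.timeout.ms`):

    ```java
    import java.time.Duration;
    import java.util.Properties;

    import org.apache.kafka.clients.consumer.KafkaConsumer;
    import org.apache.kafka.common.serialization.StringDeserializer;

    public class CloseTimeoutSketch {
        public static void main(String[] args) {
            Properties props = new Properties();
            props.setProperty("bootstrap.servers", "localhost:9092");
            KafkaConsumer<String, String> consumer =
                    new KafkaConsumer<>(props, new StringDeserializer(), new StringDeserializer());
            // The Duration passed here is now honoured: coordinator shutdown and the
            // final fetch-session-close request together stay within this budget.
            consumer.close(Duration.ofSeconds(5));
        }
    }
    ```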
    
    ## Testing
    Added a unit test which validates that the correct close request is sent to the server.
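
    A hypothetical flavour of that assertion, using the `Builder#metadata()` accessor added "for testing" in `FetchRequest.java` below (the actual test code may differ; `capturedBuilder` is an illustrative name for the request builder captured from the mock network client while closing):

    ```java
    import static org.junit.jupiter.api.Assertions.assertEquals;

    import org.apache.kafka.common.requests.FetchMetadata;
    import org.apache.kafka.common.requests.FetchRequest;

    public class FinalFetchAssertionSketch {
        // Verify the fetch request sent on close carries the final epoch.
        static void assertFinalFetch(FetchRequest.Builder capturedBuilder) {
            assertEquals(FetchMetadata.FINAL_EPOCH, capturedBuilder.metadata().epoch());
        }
    }
    ```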
    
    Reviewers: Ismael Juma <is...@juma.me.uk>, Kirk True <ki...@mustardgrain.com>, Philip Nee <ph...@gmail.com>, Luke Chen <sh...@gmail.com>, David Jacot <dj...@confluent.io>
---
 .../apache/kafka/clients/FetchSessionHandler.java  |  19 +-
 .../kafka/clients/consumer/KafkaConsumer.java      |  42 ++-
 .../kafka/clients/consumer/internals/Fetcher.java  | 122 +++++--
 .../kafka/common/requests/FetchMetadata.java       |  11 +-
 .../apache/kafka/common/requests/FetchRequest.java |   5 +
 .../java/org/apache/kafka/common/utils/Utils.java  |  42 ++-
 .../kafka/clients/consumer/KafkaConsumerTest.java  | 366 +++++++++++----------
 .../clients/consumer/internals/FetcherTest.java    |  59 +++-
 .../java/kafka/testkit/KafkaClusterTestKit.java    |   5 +-
 9 files changed, 439 insertions(+), 232 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java b/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java
index 4d94b84291a..e7556d2c8c3 100644
--- a/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java
+++ b/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java
@@ -71,6 +71,11 @@ public class FetchSessionHandler {
         this.node = node;
     }
 
+    // visible for testing
+    public int sessionId() {
+        return nextMetadata.sessionId();
+    }
+
     /**
      * All of the partitions which exist in the fetch request session.
      */
@@ -525,7 +530,7 @@ public class FetchSessionHandler {
             if (response.error() == Errors.FETCH_SESSION_ID_NOT_FOUND) {
                 nextMetadata = FetchMetadata.INITIAL;
             } else {
-                nextMetadata = nextMetadata.nextCloseExisting();
+                nextMetadata = nextMetadata.nextCloseExistingAttemptNew();
             }
             return false;
         }
@@ -567,7 +572,7 @@ public class FetchSessionHandler {
             String problem = verifyIncrementalFetchResponsePartitions(topicPartitions, response.topicIds(), version);
             if (problem != null) {
                 log.info("Node {} sent an invalid incremental fetch response with {}", node, problem);
-                nextMetadata = nextMetadata.nextCloseExisting();
+                nextMetadata = nextMetadata.nextCloseExistingAttemptNew();
                 return false;
             } else if (response.sessionId() == INVALID_SESSION_ID) {
                 // The incremental fetch session was closed by the server.
@@ -590,6 +595,14 @@ public class FetchSessionHandler {
         }
     }
 
+    /**
+     * The client will initiate the session close on the next fetch request.
+     */
+    public void notifyClose() {
+        log.debug("Set the metadata for next fetch request to close the existing session ID={}", nextMetadata.sessionId());
+        nextMetadata = nextMetadata.nextCloseExisting();
+    }
+
     /**
      * Handle an error sending the prepared request.
      *
@@ -600,7 +613,7 @@ public class FetchSessionHandler {
      */
     public void handleError(Throwable t) {
         log.info("Error sending fetch request {} to node {}:", nextMetadata, node, t);
-        nextMetadata = nextMetadata.nextCloseExisting();
+        nextMetadata = nextMetadata.nextCloseExistingAttemptNew();
     }
 
     /**
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
index 1d756d1e64c..f7f64de2945 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
@@ -59,6 +59,7 @@ import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.common.utils.Timer;
 import org.apache.kafka.common.utils.Utils;
 import org.slf4j.Logger;
+import org.slf4j.event.Level;
 
 import java.net.InetSocketAddress;
 import java.time.Duration;
@@ -824,7 +825,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
             // call close methods if internal objects are already constructed; this is to prevent resource leak. see KAFKA-2121
             // we do not need to call `close` at all when `log` is null, which means no internal objects were initialized.
             if (this.log != null) {
-                close(0, true);
+                close(Duration.ZERO, true);
             }
             // now propagate the exception
             throw new KafkaException("Failed to construct kafka consumer", t);
@@ -2397,7 +2398,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
             if (!closed) {
                 // need to close before setting the flag since the close function
                 // itself may trigger rebalance callback that needs the consumer to be open still
-                close(timeout.toMillis(), false);
+                close(timeout, false);
             }
         } finally {
             closed = true;
@@ -2425,17 +2426,38 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
         return clusterResourceListeners;
     }
 
-    private void close(long timeoutMs, boolean swallowException) {
+    private Timer createTimerForRequest(final Duration timeout) {
+        // this.time could be null if an exception occurs in the constructor before the this.time field is set
+        final Time localTime = (time == null) ? Time.SYSTEM : time;
+        return localTime.timer(Math.min(timeout.toMillis(), requestTimeoutMs));
+    }
+
+    private void close(Duration timeout, boolean swallowException) {
         log.trace("Closing the Kafka consumer");
         AtomicReference<Throwable> firstException = new AtomicReference<>();
-        try {
-            if (coordinator != null)
-                coordinator.close(time.timer(Math.min(timeoutMs, requestTimeoutMs)));
-        } catch (Throwable t) {
-            firstException.compareAndSet(null, t);
-            log.error("Failed to close coordinator", t);
+
+        final Timer closeTimer = createTimerForRequest(timeout);
+        // Close objects with a timeout. The timeout is required because the coordinator and the fetcher send requests
+        // to the server in the process of closing, and those requests may otherwise exceed the overall timeout defined
+        // for closing the consumer.
+        if (coordinator != null) {
+            // This is a blocking call bound by the time remaining in closeTimer
+            Utils.swallow(log, Level.ERROR, "Failed to close coordinator with a timeout(ms)=" + closeTimer.timeoutMs(), () -> coordinator.close(closeTimer), firstException);
         }
-        Utils.closeQuietly(fetcher, "fetcher", firstException);
+
+        if (fetcher != null) {
+            // the timeout for the session close is at most requestTimeoutMs
+            long remainingDurationInTimeout = Math.max(0, timeout.toMillis() - closeTimer.elapsedMs());
+            if (remainingDurationInTimeout > 0) {
+                remainingDurationInTimeout = Math.min(requestTimeoutMs, remainingDurationInTimeout);
+            }
+
+            closeTimer.reset(remainingDurationInTimeout);
+
+            // This is a blocking call bound by the time remaining in closeTimer
+            Utils.swallow(log, Level.ERROR, "Failed to close fetcher with a timeout(ms)=" + closeTimer.timeoutMs(), () -> fetcher.close(closeTimer), firstException);
+        }
+
         Utils.closeQuietly(interceptors, "consumer interceptors", firstException);
         Utils.closeQuietly(kafkaConsumerMetrics, "kafka consumer metrics", firstException);
         Utils.closeQuietly(metrics, "consumer metrics", firstException);
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
index c93b675f755..f81c4352c81 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
@@ -104,6 +104,7 @@ import java.util.PriorityQueue;
 import java.util.Queue;
 import java.util.Set;
 import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Function;
@@ -159,7 +160,7 @@ public class Fetcher<K, V> implements Closeable {
     private final Set<Integer> nodesWithPendingFetchRequests;
     private final ApiVersions apiVersions;
     private final AtomicInteger metadataUpdateVersion = new AtomicInteger(-1);
-
+    private final AtomicBoolean isClosed = new AtomicBoolean(false);
     private CompletedFetch nextInLineFetch = null;
 
     public Fetcher(LogContext logContext,
@@ -253,25 +254,7 @@ public class Fetcher<K, V> implements Closeable {
         for (Map.Entry<Node, FetchSessionHandler.FetchRequestData> entry : fetchRequestMap.entrySet()) {
             final Node fetchTarget = entry.getKey();
             final FetchSessionHandler.FetchRequestData data = entry.getValue();
-            final short maxVersion;
-            if (!data.canUseTopicIds()) {
-                maxVersion = (short) 12;
-            } else {
-                maxVersion = ApiKeys.FETCH.latestVersion();
-            }
-            final FetchRequest.Builder request = FetchRequest.Builder
-                    .forConsumer(maxVersion, this.maxWaitMs, this.minBytes, data.toSend())
-                    .isolationLevel(isolationLevel)
-                    .setMaxBytes(this.maxBytes)
-                    .metadata(data.metadata())
-                    .removed(data.toForget())
-                    .replaced(data.toReplace())
-                    .rackId(clientRackId);
-
-            if (log.isDebugEnabled()) {
-                log.debug("Sending {} {} to broker {}", isolationLevel, data.toString(), fetchTarget);
-            }
-            RequestFuture<ClientResponse> future = client.send(fetchTarget, request);
+            final RequestFuture<ClientResponse> future = sendFetchRequestToNode(data, fetchTarget);
             // We add the node to the set of nodes with pending fetch requests before adding the
             // listener because the future may have been fulfilled on another thread (e.g. during a
             // disconnection being handled by the heartbeat thread) which will mean the listener
@@ -447,6 +430,33 @@ public class Fetcher<K, V> implements Closeable {
             return client.send(node, request);
     }
 
+    /**
+     * Send a fetch request to the Kafka cluster asynchronously.
+     *
+     * (Extracted from sendFetches so that it can also be used when closing fetch sessions.)
+     *
+     * @return A future that indicates the result of the sent fetch request
+     */
+    private RequestFuture<ClientResponse> sendFetchRequestToNode(final FetchSessionHandler.FetchRequestData requestData,
+                                                                 final Node fetchTarget) {
+        // Version 12 is the maximum version that could be used without topic IDs. See FetchRequest.json for schema
+        // changelog.
+        final short maxVersion = requestData.canUseTopicIds() ? ApiKeys.FETCH.latestVersion() : (short) 12;
+
+        final FetchRequest.Builder request = FetchRequest.Builder
+                .forConsumer(maxVersion, this.maxWaitMs, this.minBytes, requestData.toSend())
+                .isolationLevel(isolationLevel)
+                .setMaxBytes(this.maxBytes)
+                .metadata(requestData.metadata())
+                .removed(requestData.toForget())
+                .replaced(requestData.toReplace())
+                .rackId(clientRackId);
+
+        log.debug("Sending {} {} to broker {}", isolationLevel, requestData, fetchTarget);
+
+        return client.send(fetchTarget, request);
+    }
+
     private Long offsetResetStrategyTimestamp(final TopicPartition partition) {
         OffsetResetStrategy strategy = subscriptions.resetStrategy(partition);
         if (strategy == OffsetResetStrategy.EARLIEST)
@@ -1936,11 +1946,77 @@ public class Fetcher<K, V> implements Closeable {
         }
     }
 
+    // Visible for testing
+    void maybeCloseFetchSessions(final Timer timer) {
+        final Cluster cluster = metadata.fetch();
+        final List<RequestFuture<ClientResponse>> requestFutures = new ArrayList<>();
+        sessionHandlers.forEach((fetchTargetNodeId, sessionHandler) -> {
+            // Set the session handler to notify close. This sets the metadata of the next fetch request to signal a close of the existing session.
+            sessionHandler.notifyClose();
+
+            final int sessionId = sessionHandler.sessionId();
+            // The fetch target node may no longer be available, e.g. if the connection has been closed. In such
+            // cases, we skip sending the close request.
+            final Node fetchTarget = cluster.nodeById(fetchTargetNodeId);
+            if (fetchTarget == null || client.isUnavailable(fetchTarget)) {
+                log.debug("Skip sending close session request to broker {} since it is not reachable", fetchTarget);
+                return;
+            }
+
+            final RequestFuture<ClientResponse> responseFuture = sendFetchRequestToNode(sessionHandler.newBuilder().build(), fetchTarget);
+            responseFuture.addListener(new RequestFutureListener<ClientResponse>() {
+                @Override
+                public void onSuccess(ClientResponse value) {
+                    log.debug("Successfully sent a close message for fetch session: {} to node: {}", sessionId, fetchTarget);
+                }
+
+                @Override
+                public void onFailure(RuntimeException e) {
+                    log.debug("Unable to a close message for fetch session: {} to node: {}. " +
+                        "This may result in unnecessary fetch sessions at the broker.", sessionId, fetchTarget, e);
+                }
+            });
+
+            requestFutures.add(responseFuture);
+        });
+
+        // Poll to ensure that the requests have been written to the socket. Wait until either the timer has expired
+        // or all requests have received a response.
+        while (timer.notExpired() && !requestFutures.stream().allMatch(RequestFuture::isDone)) {
+            client.poll(timer, null, true);
+        }
+
+        if (!requestFutures.stream().allMatch(RequestFuture::isDone)) {
+            // We ran out of time before all futures completed. That is OK, since we do not want to block shutdown
+            // here.
+            log.debug("Not all fetch session close requests completed within the timeout of {}ms. " +
+                "This may result in unnecessary fetch sessions at the broker. Consider increasing the timeout passed to " +
+                "KafkaConsumer.close(Duration timeout)", timer.timeoutMs());
+        }
+    }
+
+    public void close(final Timer timer) {
+        if (!isClosed.compareAndSet(false, true)) {
+            log.info("Fetcher {} is already closed.", this);
+            return;
+        }
+
+        // Shared state (e.g. sessionHandlers) can be accessed by multiple threads (such as the heartbeat thread),
+        // hence, it is necessary to acquire a lock on the fetcher instance before modifying it.
+        synchronized (Fetcher.this) {
+            // we do not need to re-enable wakeups since we are closing already
+            client.disableWakeups();
+            if (nextInLineFetch != null)
+                nextInLineFetch.drain();
+            maybeCloseFetchSessions(timer);
+            Utils.closeQuietly(decompressionBufferSupplier, "decompressionBufferSupplier");
+            sessionHandlers.clear();
+        }
+    }
+
     @Override
     public void close() {
-        if (nextInLineFetch != null)
-            nextInLineFetch.drain();
-        decompressionBufferSupplier.close();
+        close(time.timer(0));
     }
 
     private Set<String> topicsForPartitions(Collection<TopicPartition> partitions) {
diff --git a/clients/src/main/java/org/apache/kafka/common/requests/FetchMetadata.java b/clients/src/main/java/org/apache/kafka/common/requests/FetchMetadata.java
index feb6953f9da..f483296132c 100644
--- a/clients/src/main/java/org/apache/kafka/common/requests/FetchMetadata.java
+++ b/clients/src/main/java/org/apache/kafka/common/requests/FetchMetadata.java
@@ -114,9 +114,18 @@ public class FetchMetadata {
     }
 
     /**
-     * Return the metadata for the next error response.
+     * Return the metadata for the next request. The metadata is set to indicate that the client wants to close the
+     * existing session.
      */
     public FetchMetadata nextCloseExisting() {
+        return new FetchMetadata(sessionId, FINAL_EPOCH);
+    }
+
+    /**
+     * Return the metadata for the next request. The metadata is set to indicate that the client wants to close the
+     * existing session and create a new one if possible.
+     */
+    public FetchMetadata nextCloseExistingAttemptNew() {
         return new FetchMetadata(sessionId, INITIAL_EPOCH);
     }
 
diff --git a/clients/src/main/java/org/apache/kafka/common/requests/FetchRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/FetchRequest.java
index 2510f1e607c..aaad3b89104 100644
--- a/clients/src/main/java/org/apache/kafka/common/requests/FetchRequest.java
+++ b/clients/src/main/java/org/apache/kafka/common/requests/FetchRequest.java
@@ -166,6 +166,11 @@ public class FetchRequest extends AbstractRequest {
             return this;
         }
 
+        // Visible for testing
+        public FetchMetadata metadata() {
+            return this.metadata;
+        }
+
         public Builder metadata(FetchMetadata metadata) {
             this.metadata = metadata;
             return this;
diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
old mode 100755
new mode 100644
index 9249d7f96aa..4ac8d7d2fe1
--- a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
+++ b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
@@ -28,6 +28,7 @@ import org.apache.kafka.common.config.ConfigException;
 import org.apache.kafka.common.network.TransferableChannel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import org.slf4j.event.Level;
 
 import java.io.Closeable;
 import java.io.DataOutput;
@@ -997,16 +998,39 @@ public final class Utils {
         if (exception != null)
             throw exception;
     }
+    public static void swallow(final Logger log, final Level level, final String what, final Runnable code) {
+        swallow(log, level, what, code, null);
+    }
 
-    public static void swallow(
-        Logger log,
-        String what,
-        Runnable runnable
-    ) {
-        try {
-            runnable.run();
-        } catch (Throwable e) {
-            log.warn("{} error", what, e);
+    /**
+     * Run the supplied code. If an exception is thrown, it is swallowed and recorded in the firstException parameter.
+     */
+    public static void swallow(final Logger log, final Level level, final String what, final Runnable code,
+                               final AtomicReference<Throwable> firstException) {
+        if (code != null) {
+            try {
+                code.run();
+            } catch (Throwable t) {
+                switch (level) {
+                    case INFO:
+                        log.info(what, t);
+                        break;
+                    case DEBUG:
+                        log.debug(what, t);
+                        break;
+                    case ERROR:
+                        log.error(what, t);
+                        break;
+                    case TRACE:
+                        log.trace(what, t);
+                        break;
+                    case WARN:
+                    default:
+                        log.warn(what, t);
+                }
+                if (firstException != null)
+                    firstException.compareAndSet(null, t);
+            }
         }
     }
 
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
index f08ac45ddaf..713c8906ad2 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
@@ -98,6 +98,7 @@ import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.test.MockConsumerInterceptor;
 import org.apache.kafka.test.MockMetricsReporter;
 import org.apache.kafka.test.TestUtils;
+import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.Test;
 
 import javax.management.MBeanServer;
@@ -149,7 +150,13 @@ import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.junit.jupiter.api.Assertions.fail;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
 
+/**
+ * Note to future authors of tests in this class: if you close the consumer, close it with Duration.ZERO to reduce
+ * the duration of the test.
+ */
 public class KafkaConsumerTest {
     private final String topic = "test";
     private final Uuid topicId = Uuid.randomUuid();
@@ -198,19 +205,26 @@ public class KafkaConsumerTest {
     private final SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
     private final ConsumerPartitionAssignor assignor = new RoundRobinAssignor();
 
+    private KafkaConsumer<?, ?> consumer;
+
+    @AfterEach
+    public void cleanup() {
+        if (consumer != null) {
+            consumer.close(Duration.ZERO);
+        }
+    }
+
     @Test
     public void testMetricsReporterAutoGeneratedClientId() {
         Properties props = new Properties();
         props.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999");
         props.setProperty(ConsumerConfig.METRIC_REPORTER_CLASSES_CONFIG, MockMetricsReporter.class.getName());
-        KafkaConsumer<String, String> consumer = new KafkaConsumer<>(
-                props, new StringDeserializer(), new StringDeserializer());
+        consumer = new KafkaConsumer<>(props, new StringDeserializer(), new StringDeserializer());
 
         MockMetricsReporter mockMetricsReporter = (MockMetricsReporter) consumer.metrics.reporters().get(0);
 
         assertEquals(consumer.getClientId(), mockMetricsReporter.clientId);
         assertEquals(2, consumer.metrics.reporters().size());
-        consumer.close();
     }
 
     @Test
@@ -219,9 +233,8 @@ public class KafkaConsumerTest {
         Properties props = new Properties();
         props.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999");
         props.setProperty(ConsumerConfig.AUTO_INCLUDE_JMX_REPORTER_CONFIG, "false");
-        KafkaConsumer<String, String> consumer = new KafkaConsumer<>(props, new StringDeserializer(), new StringDeserializer());
+        consumer = new KafkaConsumer<>(props, new StringDeserializer(), new StringDeserializer());
         assertTrue(consumer.metrics.reporters().isEmpty());
-        consumer.close();
     }
 
     @Test
@@ -229,32 +242,31 @@ public class KafkaConsumerTest {
         Properties props = new Properties();
         props.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999");
         props.setProperty(ConsumerConfig.METRIC_REPORTER_CLASSES_CONFIG, "org.apache.kafka.common.metrics.JmxReporter");
-        KafkaConsumer<String, String> consumer = new KafkaConsumer<>(props, new StringDeserializer(), new StringDeserializer());
+        consumer = new KafkaConsumer<>(props, new StringDeserializer(), new StringDeserializer());
         assertEquals(1, consumer.metrics.reporters().size());
-        consumer.close();
     }
 
     @Test
+    @SuppressWarnings("unchecked")
     public void testPollReturnsRecords() {
-        KafkaConsumer<String, String> consumer = setUpConsumerWithRecordsToPoll(tp0, 5);
+        consumer = setUpConsumerWithRecordsToPoll(tp0, 5);
 
-        ConsumerRecords<String, String> records = consumer.poll(Duration.ZERO);
+        ConsumerRecords<String, String> records = (ConsumerRecords<String, String>) consumer.poll(Duration.ZERO);
 
         assertEquals(records.count(), 5);
         assertEquals(records.partitions(), Collections.singleton(tp0));
         assertEquals(records.records(tp0).size(), 5);
-
-        consumer.close(Duration.ofMillis(0));
     }
 
     @Test
+    @SuppressWarnings("unchecked")
     public void testSecondPollWithDeserializationErrorThrowsRecordDeserializationException() {
         int invalidRecordNumber = 4;
         int invalidRecordOffset = 3;
         StringDeserializer deserializer = mockErrorDeserializer(invalidRecordNumber);
 
-        KafkaConsumer<String, String> consumer = setUpConsumerWithRecordsToPoll(tp0, 5, deserializer);
-        ConsumerRecords<String, String> records = consumer.poll(Duration.ZERO);
+        consumer = setUpConsumerWithRecordsToPoll(tp0, 5, deserializer);
+        ConsumerRecords<String, String> records = (ConsumerRecords<String, String>) consumer.poll(Duration.ZERO);
 
         assertEquals(invalidRecordNumber - 1, records.count());
         assertEquals(Collections.singleton(tp0), records.partitions());
@@ -266,7 +278,6 @@ public class KafkaConsumerTest {
         assertEquals(invalidRecordOffset, rde.offset());
         assertEquals(tp0, rde.topicPartition());
         assertEquals(rde.offset(), consumer.position(tp0));
-        consumer.close(Duration.ofMillis(0));
     }
 
     /*
@@ -288,18 +299,18 @@ public class KafkaConsumerTest {
         };
     }
 
-    private KafkaConsumer<String, String> setUpConsumerWithRecordsToPoll(TopicPartition tp, int recordCount) {
+    private KafkaConsumer<?, ?> setUpConsumerWithRecordsToPoll(TopicPartition tp, int recordCount) {
         return setUpConsumerWithRecordsToPoll(tp, recordCount, new StringDeserializer());
     }
 
-    private KafkaConsumer<String, String> setUpConsumerWithRecordsToPoll(TopicPartition tp, int recordCount, Deserializer<String> deserializer) {
+    private KafkaConsumer<?, ?> setUpConsumerWithRecordsToPoll(TopicPartition tp, int recordCount, Deserializer<String> deserializer) {
         Cluster cluster = TestUtils.singletonCluster(tp.topic(), 1);
         Node node = cluster.nodes().get(0);
 
         ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
         initMetadata(client, Collections.singletonMap(topic, 1));
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor,
+        consumer = newConsumer(time, client, subscription, metadata, assignor,
                 true, groupId, groupInstanceId, Optional.of(deserializer), false);
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         prepareRebalance(client, node, assignor, singletonList(tp), null);
@@ -333,9 +344,7 @@ public class KafkaConsumerTest {
         config.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999");
         config.put(ConsumerConfig.SEND_BUFFER_CONFIG, Selectable.USE_DEFAULT_BUFFER_SIZE);
         config.put(ConsumerConfig.RECEIVE_BUFFER_CONFIG, Selectable.USE_DEFAULT_BUFFER_SIZE);
-        KafkaConsumer<byte[], byte[]> consumer = new KafkaConsumer<>(
-                config, new ByteArrayDeserializer(), new ByteArrayDeserializer());
-        consumer.close();
+        consumer = new KafkaConsumer<>(config, new ByteArrayDeserializer(), new ByteArrayDeserializer());
     }
 
     @Test
@@ -361,14 +370,12 @@ public class KafkaConsumerTest {
         Map<String, Object> config = new HashMap<>();
         config.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999");
         config.put(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG, "instance_id");
-        KafkaConsumer<byte[], byte[]> consumer = new KafkaConsumer<>(
-                config, new ByteArrayDeserializer(), new ByteArrayDeserializer());
-        consumer.close();
+        consumer = new KafkaConsumer<>(config, new ByteArrayDeserializer(), new ByteArrayDeserializer());
     }
 
     @Test
     public void testSubscription() {
-        KafkaConsumer<byte[], byte[]> consumer = newConsumer(groupId);
+        consumer = newConsumer(groupId);
 
         consumer.subscribe(singletonList(topic));
         assertEquals(singleton(topic), consumer.subscription());
@@ -385,46 +392,39 @@ public class KafkaConsumerTest {
         consumer.unsubscribe();
         assertTrue(consumer.subscription().isEmpty());
         assertTrue(consumer.assignment().isEmpty());
-
-        consumer.close();
     }
 
     @Test
     public void testSubscriptionOnNullTopicCollection() {
-        try (KafkaConsumer<byte[], byte[]> consumer = newConsumer(groupId)) {
-            assertThrows(IllegalArgumentException.class, () -> consumer.subscribe((List<String>) null));
-        }
+        consumer = newConsumer(groupId);
+        assertThrows(IllegalArgumentException.class, () -> consumer.subscribe((List<String>) null));
     }
 
     @Test
     public void testSubscriptionOnNullTopic() {
-        try (KafkaConsumer<byte[], byte[]> consumer = newConsumer(groupId)) {
-            assertThrows(IllegalArgumentException.class, () -> consumer.subscribe(singletonList(null)));
-        }
+        consumer = newConsumer(groupId);
+        assertThrows(IllegalArgumentException.class, () -> consumer.subscribe(singletonList(null)));
     }
 
     @Test
     public void testSubscriptionOnEmptyTopic() {
-        try (KafkaConsumer<byte[], byte[]> consumer = newConsumer(groupId)) {
-            String emptyTopic = "  ";
-            assertThrows(IllegalArgumentException.class, () -> consumer.subscribe(singletonList(emptyTopic)));
-        }
+        consumer = newConsumer(groupId);
+        String emptyTopic = "  ";
+        assertThrows(IllegalArgumentException.class, () -> consumer.subscribe(singletonList(emptyTopic)));
     }
 
     @Test
     public void testSubscriptionOnNullPattern() {
-        try (KafkaConsumer<byte[], byte[]> consumer = newConsumer(groupId)) {
-            assertThrows(IllegalArgumentException.class,
-                () -> consumer.subscribe((Pattern) null));
-        }
+        consumer = newConsumer(groupId);
+        assertThrows(IllegalArgumentException.class,
+            () -> consumer.subscribe((Pattern) null));
     }
 
     @Test
     public void testSubscriptionOnEmptyPattern() {
-        try (KafkaConsumer<byte[], byte[]> consumer = newConsumer(groupId)) {
-            assertThrows(IllegalArgumentException.class,
-                () -> consumer.subscribe(Pattern.compile("")));
-        }
+        consumer = newConsumer(groupId);
+        assertThrows(IllegalArgumentException.class,
+            () -> consumer.subscribe(Pattern.compile("")));
     }
 
     @Test
@@ -433,49 +433,43 @@ public class KafkaConsumerTest {
         props.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999");
         props.setProperty(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, "");
         props.setProperty(ConsumerConfig.GROUP_ID_CONFIG, groupId);
-        try (KafkaConsumer<byte[], byte[]> consumer = newConsumer(props)) {
-            assertThrows(IllegalStateException.class,
-                () -> consumer.subscribe(singletonList(topic)));
-        }
+        consumer = newConsumer(props);
+        assertThrows(IllegalStateException.class,
+            () -> consumer.subscribe(singletonList(topic)));
     }
 
     @Test
     public void testSeekNegative() {
-        try (KafkaConsumer<byte[], byte[]> consumer = newConsumer((String) null)) {
-            consumer.assign(singleton(new TopicPartition("nonExistTopic", 0)));
-            assertThrows(IllegalArgumentException.class,
-                () -> consumer.seek(new TopicPartition("nonExistTopic", 0), -1));
-        }
+        consumer = newConsumer((String) null);
+        consumer.assign(singleton(new TopicPartition("nonExistTopic", 0)));
+        assertThrows(IllegalArgumentException.class,
+            () -> consumer.seek(new TopicPartition("nonExistTopic", 0), -1));
     }
 
     @Test
     public void testAssignOnNullTopicPartition() {
-        try (KafkaConsumer<byte[], byte[]> consumer = newConsumer((String) null)) {
-            assertThrows(IllegalArgumentException.class, () -> consumer.assign(null));
-        }
+        consumer = newConsumer((String) null);
+        assertThrows(IllegalArgumentException.class, () -> consumer.assign(null));
     }
 
     @Test
     public void testAssignOnEmptyTopicPartition() {
-        try (KafkaConsumer<byte[], byte[]> consumer = newConsumer(groupId)) {
-            consumer.assign(Collections.emptyList());
-            assertTrue(consumer.subscription().isEmpty());
-            assertTrue(consumer.assignment().isEmpty());
-        }
+        consumer = newConsumer(groupId);
+        consumer.assign(Collections.emptyList());
+        assertTrue(consumer.subscription().isEmpty());
+        assertTrue(consumer.assignment().isEmpty());
     }
 
     @Test
     public void testAssignOnNullTopicInPartition() {
-        try (KafkaConsumer<byte[], byte[]> consumer = newConsumer((String) null)) {
-            assertThrows(IllegalArgumentException.class, () -> consumer.assign(singleton(new TopicPartition(null, 0))));
-        }
+        consumer = newConsumer((String) null);
+        assertThrows(IllegalArgumentException.class, () -> consumer.assign(singleton(new TopicPartition(null, 0))));
     }
 
     @Test
     public void testAssignOnEmptyTopicInPartition() {
-        try (KafkaConsumer<byte[], byte[]> consumer = newConsumer((String) null)) {
-            assertThrows(IllegalArgumentException.class, () -> consumer.assign(singleton(new TopicPartition("  ", 0))));
-        }
+        consumer = newConsumer((String) null);
+        assertThrows(IllegalArgumentException.class, () -> consumer.assign(singleton(new TopicPartition("  ", 0))));
     }
 
     @Test
@@ -486,12 +480,12 @@ public class KafkaConsumerTest {
             props.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999");
             props.setProperty(ConsumerConfig.INTERCEPTOR_CLASSES_CONFIG, MockConsumerInterceptor.class.getName());
 
-            KafkaConsumer<String, String> consumer = new KafkaConsumer<>(
+            consumer = new KafkaConsumer<>(
                     props, new StringDeserializer(), new StringDeserializer());
             assertEquals(1, MockConsumerInterceptor.INIT_COUNT.get());
             assertEquals(0, MockConsumerInterceptor.CLOSE_COUNT.get());
 
-            consumer.close();
+            consumer.close(Duration.ZERO);
             assertEquals(1, MockConsumerInterceptor.INIT_COUNT.get());
             assertEquals(1, MockConsumerInterceptor.CLOSE_COUNT.get());
             // Cluster metadata will only be updated on calling poll.
@@ -505,7 +499,7 @@ public class KafkaConsumerTest {
 
     @Test
     public void testPause() {
-        KafkaConsumer<byte[], byte[]> consumer = newConsumer(groupId);
+        consumer = newConsumer(groupId);
 
         consumer.assign(singletonList(tp0));
         assertEquals(singleton(tp0), consumer.assignment());
@@ -519,8 +513,6 @@ public class KafkaConsumerTest {
 
         consumer.unsubscribe();
         assertTrue(consumer.paused().isEmpty());
-
-        consumer.close();
     }
 
     @Test
@@ -530,14 +522,12 @@ public class KafkaConsumerTest {
         config.put(ConsumerConfig.SEND_BUFFER_CONFIG, Selectable.USE_DEFAULT_BUFFER_SIZE);
         config.put(ConsumerConfig.RECEIVE_BUFFER_CONFIG, Selectable.USE_DEFAULT_BUFFER_SIZE);
         config.put("client.id", "client-1");
-        KafkaConsumer<byte[], byte[]> consumer = new KafkaConsumer<>(
-                config, new ByteArrayDeserializer(), new ByteArrayDeserializer());
+        consumer = new KafkaConsumer<>(config, new ByteArrayDeserializer(), new ByteArrayDeserializer());
         MBeanServer server = ManagementFactory.getPlatformMBeanServer();
         MetricName testMetricName = consumer.metrics.metricName("test-metric",
                 "grp1", "test metric");
         consumer.metrics.addMetric(testMetricName, new Avg());
         assertNotNull(server.getObjectInstance(new ObjectName("kafka.consumer:type=grp1,client-id=client-1")));
-        consumer.close();
     }
 
     private KafkaConsumer<byte[], byte[]> newConsumer(String groupId) {
@@ -568,7 +558,7 @@ public class KafkaConsumerTest {
         initMetadata(client, Collections.singletonMap(topic, 1));
         Node node = metadata.fetch().nodes().get(0);
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
 
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         Node coordinator = prepareRebalance(client, node, assignor, singletonList(tp0), null);
@@ -588,7 +578,6 @@ public class KafkaConsumerTest {
         consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE));
 
         assertTrue(heartbeatReceived.get());
-        consumer.close(Duration.ofMillis(0));
     }
 
     @Test
@@ -599,7 +588,7 @@ public class KafkaConsumerTest {
         initMetadata(client, Collections.singletonMap(topic, 1));
         Node node = metadata.fetch().nodes().get(0);
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         Node coordinator = prepareRebalance(client, node, assignor, singletonList(tp0), null);
 
@@ -619,7 +608,6 @@ public class KafkaConsumerTest {
         consumer.poll(Duration.ZERO);
 
         assertTrue(heartbeatReceived.get());
-        consumer.close(Duration.ofMillis(0));
     }
 
     @Test
@@ -630,7 +618,7 @@ public class KafkaConsumerTest {
         initMetadata(client, Collections.singletonMap(topic, 1));
         Node node = metadata.fetch().nodes().get(0);
 
-        final KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         // Since we would enable the heartbeat thread after received join-response which could
         // send the sync-group on behalf of the consumer if it is enqueued, we may still complete
@@ -654,7 +642,7 @@ public class KafkaConsumerTest {
         initMetadata(client, Collections.singletonMap(topic, 1));
         Node node = metadata.fetch().nodes().get(0);
 
-        final KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         prepareRebalance(client, node, assignor, singletonList(tp0), null);
 
@@ -668,13 +656,14 @@ public class KafkaConsumerTest {
     }
 
     @Test
+    @SuppressWarnings("unchecked")
     public void verifyNoCoordinatorLookupForManualAssignmentWithSeek() {
         ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
 
         initMetadata(client, Collections.singletonMap(topic, 1));
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true, null, groupInstanceId, false);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, true, null, groupInstanceId, false);
         consumer.assign(singleton(tp0));
         consumer.seekToBeginning(singleton(tp0));
 
@@ -683,10 +672,9 @@ public class KafkaConsumerTest {
         client.prepareResponse(listOffsetsResponse(Collections.singletonMap(tp0, 50L)));
         client.prepareResponse(fetchResponse(tp0, 50L, 5));
 
-        ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(1));
+        ConsumerRecords<String, String> records = (ConsumerRecords<String, String>) consumer.poll(Duration.ofMillis(1));
         assertEquals(5, records.count());
         assertEquals(55L, consumer.position(tp0));
-        consumer.close(Duration.ofMillis(0));
     }
 
     @Test
@@ -698,7 +686,7 @@ public class KafkaConsumerTest {
         Node node = metadata.fetch().nodes().get(0);
 
         // create a consumer with groupID with manual assignment
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
         consumer.assign(singleton(tp0));
 
         // 1st coordinator error should cause coordinator unknown
@@ -711,7 +699,8 @@ public class KafkaConsumerTest {
         client.prepareResponse(offsetResponse(Collections.singletonMap(tp0, 50L), Errors.NONE));
         client.prepareResponse(fetchResponse(tp0, 50L, 5));
 
-        ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(0));
+        @SuppressWarnings("unchecked")
+        ConsumerRecords<String, String> records = (ConsumerRecords<String, String>) consumer.poll(Duration.ofMillis(0));
         assertEquals(5, records.count());
         assertEquals(55L, consumer.position(tp0));
 
@@ -722,7 +711,6 @@ public class KafkaConsumerTest {
         // verify the offset is committed
         client.prepareResponse(offsetResponse(Collections.singletonMap(tp0, 55L), Errors.NONE));
         assertEquals(55, consumer.committed(Collections.singleton(tp0), Duration.ZERO).get(tp0).offset());
-        consumer.close(Duration.ofMillis(0));
     }
 
     @Test
@@ -735,7 +723,7 @@ public class KafkaConsumerTest {
         initMetadata(client, Collections.singletonMap(topic, 2));
         Node node = metadata.fetch().nodes().get(0);
 
-        KafkaConsumer<String, String> consumer = newConsumerNoAutoCommit(time, client, subscription, metadata);
+        consumer = newConsumerNoAutoCommit(time, client, subscription, metadata);
         consumer.assign(Arrays.asList(tp0, tp1));
         consumer.seekToEnd(singleton(tp0));
         consumer.seekToBeginning(singleton(tp1));
@@ -766,7 +754,8 @@ public class KafkaConsumerTest {
 
             }, fetchResponse(tp0, 50L, 5));
 
-        ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(1));
+        @SuppressWarnings("unchecked")
+        ConsumerRecords<String, String> records = (ConsumerRecords<String, String>) consumer.poll(Duration.ofMillis(1));
         assertEquals(5, records.count());
         assertEquals(singleton(tp0), records.partitions());
     }
@@ -790,7 +779,7 @@ public class KafkaConsumerTest {
         initMetadata(client, Collections.singletonMap(topic, 1));
         Node node = metadata.fetch().nodes().get(0);
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor,
+        consumer = newConsumer(time, client, subscription, metadata, assignor,
                 true, groupId, groupInstanceId, false);
         consumer.assign(singletonList(tp0));
 
@@ -833,7 +822,7 @@ public class KafkaConsumerTest {
         initMetadata(client, Collections.singletonMap(topic, 1));
         Node node = metadata.fetch().nodes().get(0);
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor,
+        consumer = newConsumer(time, client, subscription, metadata, assignor,
                 true, groupId, groupInstanceId, false);
         consumer.assign(singletonList(tp0));
 
@@ -856,7 +845,7 @@ public class KafkaConsumerTest {
 
         initMetadata(client, Collections.singletonMap(topic, 1));
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor,
+        consumer = newConsumer(time, client, subscription, metadata, assignor,
                 true, groupId, Optional.empty(), false);
         consumer.assign(singletonList(tp0));
         consumer.seek(tp0, 20L);
@@ -875,7 +864,7 @@ public class KafkaConsumerTest {
         initMetadata(client, Collections.singletonMap(topic, 2));
         Node node = metadata.fetch().nodes().get(0);
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
         consumer.assign(singletonList(tp0));
 
         // lookup coordinator
@@ -898,7 +887,6 @@ public class KafkaConsumerTest {
         offsets.put(tp1, offset2);
         client.prepareResponseFrom(offsetResponse(offsets, Errors.NONE), coordinator);
         assertEquals(offset2, consumer.committed(Collections.singleton(tp1)).get(tp1).offset());
-        consumer.close(Duration.ofMillis(0));
     }
 
     @Test
@@ -916,7 +904,7 @@ public class KafkaConsumerTest {
         assertThrows(UnsupportedVersionException.class, () -> setupThrowableConsumer().position(tp0));
     }
 
-    private KafkaConsumer<String, String> setupThrowableConsumer() {
+    private KafkaConsumer<?, ?> setupThrowableConsumer() {
         long offset1 = 10000;
 
         ConsumerMetadata metadata = createMetadata(subscription);
@@ -927,7 +915,7 @@ public class KafkaConsumerTest {
 
         Node node = metadata.fetch().nodes().get(0);
 
-        KafkaConsumer<String, String> consumer = newConsumer(
+        consumer = newConsumer(
             time, client, subscription, metadata, assignor, true, groupId, groupInstanceId, true);
         consumer.assign(singletonList(tp0));
 
@@ -949,7 +937,7 @@ public class KafkaConsumerTest {
         initMetadata(client, Collections.singletonMap(topic, 2));
         Node node = metadata.fetch().nodes().get(0);
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
         consumer.assign(Arrays.asList(tp0, tp1));
 
         // lookup coordinator
@@ -962,8 +950,6 @@ public class KafkaConsumerTest {
         assertEquals(2, committed.size());
         assertEquals(offset1, committed.get(tp0).offset());
         assertNull(committed.get(tp1));
-
-        consumer.close(Duration.ofMillis(0));
     }
 
     @Test
@@ -974,7 +960,7 @@ public class KafkaConsumerTest {
         initMetadata(client, Collections.singletonMap(topic, 1));
         Node node = metadata.fetch().nodes().get(0);
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         Node coordinator = prepareRebalance(client, node, assignor, singletonList(tp0), null);
 
@@ -995,7 +981,6 @@ public class KafkaConsumerTest {
         consumer.poll(Duration.ZERO);
 
         assertTrue(commitReceived.get());
-        consumer.close(Duration.ofMillis(0));
     }
 
     @Test
@@ -1011,7 +996,7 @@ public class KafkaConsumerTest {
         initMetadata(client, partitionCounts);
         Node node = metadata.fetch().nodes().get(0);
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
         prepareRebalance(client, node, singleton(topic), assignor, singletonList(tp0), null);
 
         consumer.subscribe(Pattern.compile(topic), getConsumerRebalanceListener(consumer));
@@ -1022,7 +1007,6 @@ public class KafkaConsumerTest {
 
         assertEquals(singleton(topic), consumer.subscription());
         assertEquals(singleton(tp0), consumer.assignment());
-        consumer.close(Duration.ofMillis(0));
     }
 
     @Test
@@ -1040,7 +1024,7 @@ public class KafkaConsumerTest {
         initMetadata(client, partitionCounts);
         Node node = metadata.fetch().nodes().get(0);
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId);
 
         Node coordinator = prepareRebalance(client, node, singleton(topic), assignor, singletonList(tp0), null);
         consumer.subscribe(Pattern.compile(topic), getConsumerRebalanceListener(consumer));
@@ -1057,7 +1041,6 @@ public class KafkaConsumerTest {
         consumer.poll(Duration.ZERO);
 
         assertEquals(singleton(otherTopic), consumer.subscription());
-        consumer.close(Duration.ofMillis(0));
     }
 
     @Test
@@ -1068,7 +1051,7 @@ public class KafkaConsumerTest {
         initMetadata(client, Collections.singletonMap(topic, 1));
         Node node = metadata.fetch().nodes().get(0);
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         prepareRebalance(client, node, assignor, singletonList(tp0), null);
 
@@ -1087,7 +1070,8 @@ public class KafkaConsumerTest {
         assertEquals(0, consumer.position(tp0));
 
         // the next poll should return the completed fetch
-        ConsumerRecords<String, String> records = consumer.poll(Duration.ZERO);
+        @SuppressWarnings("unchecked")
+        ConsumerRecords<String, String> records = (ConsumerRecords<String, String>) consumer.poll(Duration.ZERO);
         assertEquals(5, records.count());
         // Increment time asynchronously to clear timeouts in closing the consumer
         final ScheduledExecutorService exec = Executors.newSingleThreadScheduledExecutor();
@@ -1105,7 +1089,7 @@ public class KafkaConsumerTest {
         initMetadata(client, Collections.singletonMap(topic, 1));
         Node node = metadata.fetch().nodes().get(0);
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId);
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         prepareRebalance(client, node, assignor, singletonList(tp0), null);
 
@@ -1119,7 +1103,6 @@ public class KafkaConsumerTest {
         } finally {
             // clear interrupted state again since this thread may be reused by JUnit
             Thread.interrupted();
-            consumer.close(Duration.ofMillis(0));
         }
     }
 
@@ -1131,7 +1114,7 @@ public class KafkaConsumerTest {
         initMetadata(client, Collections.singletonMap(topic, 1));
         Node node = metadata.fetch().nodes().get(0);
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
         consumer.subscribe(singletonList(topic), getConsumerRebalanceListener(consumer));
 
         prepareRebalance(client, node, assignor, singletonList(tp0), null);
@@ -1143,9 +1126,9 @@ public class KafkaConsumerTest {
 
         consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE));
 
-        ConsumerRecords<String, String> records = consumer.poll(Duration.ZERO);
+        @SuppressWarnings("unchecked")
+        ConsumerRecords<String, String> records = (ConsumerRecords<String, String>) consumer.poll(Duration.ZERO);
         assertEquals(0, records.count());
-        consumer.close(Duration.ofMillis(0));
     }
 
     /**
@@ -1156,6 +1139,7 @@ public class KafkaConsumerTest {
      * are both updated right away but its consumed offsets are not auto committed.
      */
     @Test
+    @SuppressWarnings("unchecked")
     public void testSubscriptionChangesWithAutoCommitEnabled() {
         ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
@@ -1169,7 +1153,7 @@ public class KafkaConsumerTest {
 
         ConsumerPartitionAssignor assignor = new RangeAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
 
         // initial subscription
         consumer.subscribe(Arrays.asList(topic, topic2), getConsumerRebalanceListener(consumer));
@@ -1198,7 +1182,7 @@ public class KafkaConsumerTest {
         client.respondFrom(fetchResponse(fetches1), node);
         client.poll(0, time.milliseconds());
 
-        ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(1));
+        ConsumerRecords<String, String> records = (ConsumerRecords<String, String>) consumer.poll(Duration.ofMillis(1));
 
         // clear out the prefetch so it doesn't interfere with the rest of the test
         fetches1.put(tp0, new FetchInfo(1, 0));
@@ -1235,7 +1219,7 @@ public class KafkaConsumerTest {
         fetches2.put(t3p0, new FetchInfo(0, 100));
         client.prepareResponse(fetchResponse(fetches2));
 
-        records = consumer.poll(Duration.ofMillis(1));
+        records = (ConsumerRecords<String, String>) consumer.poll(Duration.ofMillis(1));
 
         // verify that the fetch occurred as expected
         assertEquals(101, records.count());
@@ -1258,7 +1242,6 @@ public class KafkaConsumerTest {
         assertTrue(consumer.assignment().isEmpty());
 
         client.requests().clear();
-        consumer.close();
     }
 
     /**
@@ -1281,7 +1264,7 @@ public class KafkaConsumerTest {
 
         ConsumerPartitionAssignor assignor = new RangeAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId);
 
         initializeSubscriptionWithSingleTopic(consumer, getConsumerRebalanceListener(consumer));
 
@@ -1320,7 +1303,6 @@ public class KafkaConsumerTest {
             assertNotSame(ApiKeys.OFFSET_COMMIT, req.requestBuilder().apiKey());
 
         client.requests().clear();
-        consumer.close();
     }
 
     @Test
@@ -1332,7 +1314,7 @@ public class KafkaConsumerTest {
         Node node = metadata.fetch().nodes().get(0);
 
         CooperativeStickyAssignor assignor = new CooperativeStickyAssignor();
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId);
 
         initializeSubscriptionWithSingleTopic(consumer, getExceptionConsumerRebalanceListener());
 
@@ -1355,7 +1337,7 @@ public class KafkaConsumerTest {
         Node node = metadata.fetch().nodes().get(0);
 
         CooperativeStickyAssignor assignor = new CooperativeStickyAssignor();
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId);
 
         initializeSubscriptionWithSingleTopic(consumer, getExceptionConsumerRebalanceListener());
         Node coordinator = prepareRebalance(client, node, assignor, singletonList(tp0), null);
@@ -1373,7 +1355,7 @@ public class KafkaConsumerTest {
         assertEquals(partitionLost + singleTopicPartition, unsubscribeException.getCause().getMessage());
     }
 
-    private void initializeSubscriptionWithSingleTopic(KafkaConsumer<String, String> consumer,
+    private void initializeSubscriptionWithSingleTopic(KafkaConsumer<?, ?> consumer,
                                                        ConsumerRebalanceListener consumerRebalanceListener) {
         consumer.subscribe(singleton(topic), consumerRebalanceListener);
         // verify that subscription has changed but assignment is still unchanged
@@ -1382,6 +1364,7 @@ public class KafkaConsumerTest {
     }
 
     @Test
+    @SuppressWarnings("unchecked")
     public void testManualAssignmentChangeWithAutoCommitEnabled() {
         ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
@@ -1394,7 +1377,7 @@ public class KafkaConsumerTest {
 
         ConsumerPartitionAssignor assignor = new RangeAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
 
         // lookup coordinator
         client.prepareResponseFrom(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node);
@@ -1416,7 +1399,7 @@ public class KafkaConsumerTest {
         client.prepareResponse(listOffsetsResponse(Collections.singletonMap(tp0, 10L)));
         client.prepareResponse(fetchResponse(tp0, 10L, 1));
 
-        ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(1));
+        ConsumerRecords<String, String> records = (ConsumerRecords<String, String>) consumer.poll(Duration.ofMillis(100));
 
         assertEquals(1, records.count());
         assertEquals(11L, consumer.position(tp0));
@@ -1433,7 +1416,6 @@ public class KafkaConsumerTest {
         assertTrue(commitReceived.get());
 
         client.requests().clear();
-        consumer.close();
     }
 
     @Test
@@ -1449,7 +1431,7 @@ public class KafkaConsumerTest {
 
         ConsumerPartitionAssignor assignor = new RangeAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId);
 
         // lookup coordinator
         client.prepareResponseFrom(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node);
@@ -1473,7 +1455,8 @@ public class KafkaConsumerTest {
         client.prepareResponse(listOffsetsResponse(Collections.singletonMap(tp0, 10L)));
         client.prepareResponse(fetchResponse(tp0, 10L, 1));
 
-        ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(1));
+        @SuppressWarnings("unchecked")
+        ConsumerRecords<String, String> records = (ConsumerRecords<String, String>) consumer.poll(Duration.ofMillis(1));
         assertEquals(1, records.count());
         assertEquals(11L, consumer.position(tp0));
 
@@ -1488,7 +1471,6 @@ public class KafkaConsumerTest {
             assertNotSame(req.requestBuilder().apiKey(), ApiKeys.OFFSET_COMMIT);
 
         client.requests().clear();
-        consumer.close();
     }
 
     @Test
@@ -1501,7 +1483,7 @@ public class KafkaConsumerTest {
 
         ConsumerPartitionAssignor assignor = new RangeAssignor();
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
 
         // lookup coordinator
         client.prepareResponseFrom(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node);
@@ -1539,30 +1521,26 @@ public class KafkaConsumerTest {
 
         client.requests().clear();
         consumer.unsubscribe();
-        consumer.close();
     }
 
     @Test
     public void testPollWithNoSubscription() {
-        try (KafkaConsumer<byte[], byte[]> consumer = newConsumer((String) null)) {
-            assertThrows(IllegalStateException.class, () -> consumer.poll(Duration.ZERO));
-        }
+        consumer = newConsumer((String) null);
+        assertThrows(IllegalStateException.class, () -> consumer.poll(Duration.ZERO));
     }
 
     @Test
     public void testPollWithEmptySubscription() {
-        try (KafkaConsumer<byte[], byte[]> consumer = newConsumer(groupId)) {
-            consumer.subscribe(Collections.emptyList());
-            assertThrows(IllegalStateException.class, () -> consumer.poll(Duration.ZERO));
-        }
+        consumer = newConsumer(groupId);
+        consumer.subscribe(Collections.emptyList());
+        assertThrows(IllegalStateException.class, () -> consumer.poll(Duration.ZERO));
     }
 
     @Test
     public void testPollWithEmptyUserAssignment() {
-        try (KafkaConsumer<byte[], byte[]> consumer = newConsumer(groupId)) {
-            consumer.assign(Collections.emptySet());
-            assertThrows(IllegalStateException.class, () -> consumer.poll(Duration.ZERO));
-        }
+        consumer = newConsumer(groupId);
+        consumer.assign(Collections.emptySet());
+        assertThrows(IllegalStateException.class, () -> consumer.poll(Duration.ZERO));
     }
 
     @Test
@@ -1571,7 +1549,24 @@ public class KafkaConsumerTest {
         response.put(tp0, Errors.NONE);
         OffsetCommitResponse commitResponse = offsetCommitResponse(response);
         LeaveGroupResponse leaveGroupResponse = new LeaveGroupResponse(new LeaveGroupResponseData().setErrorCode(Errors.NONE.code()));
-        consumerCloseTest(5000, Arrays.asList(commitResponse, leaveGroupResponse), 0, false);
+        FetchResponse closeResponse = FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, new LinkedHashMap<>());
+        consumerCloseTest(5000, Arrays.asList(commitResponse, leaveGroupResponse, closeResponse), 0, false);
+    }
+
+    @Test
+    public void testCloseTimeoutDueToNoResponseForCloseFetchRequest() throws Exception {
+        Map<TopicPartition, Errors> response = new HashMap<>();
+        response.put(tp0, Errors.NONE);
+        OffsetCommitResponse commitResponse = offsetCommitResponse(response);
+        LeaveGroupResponse leaveGroupResponse = new LeaveGroupResponse(new LeaveGroupResponseData().setErrorCode(Errors.NONE.code()));
+        final List<AbstractResponse> serverResponsesWithoutCloseResponse = Arrays.asList(commitResponse, leaveGroupResponse);
+
+        // To force a timeout due to no response to the fetcher's close request, we provide successful responses
+        // from the server for the first two requests and configure the test to wait for a duration greater than
+        // the configured close timeout.
+        final int closeTimeoutMs = 5000;
+        final int waitForCloseCompletionMs = closeTimeoutMs + 1000;
+        consumerCloseTest(closeTimeoutMs, serverResponsesWithoutCloseResponse, waitForCloseCompletionMs, false);
     }
 
     @Test
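
Note on the timing above: close(Duration) amounts to a single deadline shared across the
shutdown steps, so a missing response to the final fetch exhausts whatever time remains
rather than blocking forever. A minimal illustrative sketch of that deadline bookkeeping
(hypothetical names, not the consumer's actual internals):

    import java.time.Duration;

    // Illustrative only: a single deadline shared across sequential shutdown steps.
    final class DeadlineBoundedClose {
        static void close(Duration timeout) {
            long deadlineMs = System.currentTimeMillis() + timeout.toMillis();
            awaitStep("commit + LEAVE_GROUP", remaining(deadlineMs));
            // may consume all remaining time if no response ever arrives
            awaitStep("final FETCH (epoch = FINAL_EPOCH)", remaining(deadlineMs));
        }

        static long remaining(long deadlineMs) {
            return Math.max(0, deadlineMs - System.currentTimeMillis());
        }

        static void awaitStep(String name, long remainingMs) {
            // a real implementation polls the network client until a response
            // arrives or the remaining time runs out
            System.out.printf("waiting up to %d ms for %s%n", remainingMs, name);
        }

        public static void main(String[] args) {
            close(Duration.ofMillis(5000));
        }
    }
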
@@ -1599,10 +1594,17 @@ public class KafkaConsumerTest {
 
     @Test
     public void testCloseShouldBeIdempotent() {
-        KafkaConsumer<byte[], byte[]> consumer = newConsumer((String) null);
-        consumer.close();
-        consumer.close();
-        consumer.close();
+        ConsumerMetadata metadata = createMetadata(subscription);
+        MockClient client = spy(new MockClient(time, metadata));
+        initMetadata(client, singletonMap(topic, 1));
+
+        consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId);
+
+        consumer.close(Duration.ZERO);
+        consumer.close(Duration.ZERO);
+
+        // verify that the call is idempotent by checking that the network client is only closed once.
+        verify(client).close();
     }
 
     @Test
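
The idempotency asserted here (the underlying network client is closed exactly once, no
matter how many times close() is called) is typically implemented with a compare-and-set
guard. A minimal sketch of that pattern, with hypothetical names, not the actual
KafkaConsumer code:

    import java.util.concurrent.atomic.AtomicBoolean;

    // Illustrative idempotent-close guard.
    final class IdempotentCloseExample implements AutoCloseable {
        private final AtomicBoolean closed = new AtomicBoolean(false);

        @Override
        public void close() {
            // only the first caller wins the CAS; later calls are no-ops
            if (closed.compareAndSet(false, true)) {
                System.out.println("closing network client once");
            }
        }

        public static void main(String[] args) {
            IdempotentCloseExample c = new IdempotentCloseExample();
            c.close();
            c.close(); // no effect: already closed
        }
    }
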
@@ -1671,20 +1673,21 @@ public class KafkaConsumerTest {
     }
 
     @Test
-    public void testMetricConfigRecordingLevel() {
+    public void testMetricConfigRecordingLevelInfo() {
         Properties props = new Properties();
         props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000");
-        try (KafkaConsumer<byte[], byte[]> consumer = new KafkaConsumer<>(props, new ByteArrayDeserializer(), new ByteArrayDeserializer())) {
-            assertEquals(Sensor.RecordingLevel.INFO, consumer.metrics.config().recordLevel());
-        }
+        KafkaConsumer<byte[], byte[]> consumer = new KafkaConsumer<>(props, new ByteArrayDeserializer(), new ByteArrayDeserializer());
+        assertEquals(Sensor.RecordingLevel.INFO, consumer.metrics.config().recordLevel());
+        consumer.close(Duration.ZERO);
 
         props.put(ConsumerConfig.METRICS_RECORDING_LEVEL_CONFIG, "DEBUG");
-        try (KafkaConsumer<byte[], byte[]> consumer = new KafkaConsumer<>(props, new ByteArrayDeserializer(), new ByteArrayDeserializer())) {
-            assertEquals(Sensor.RecordingLevel.DEBUG, consumer.metrics.config().recordLevel());
-        }
+        KafkaConsumer<byte[], byte[]> consumer2 = new KafkaConsumer<>(props, new ByteArrayDeserializer(), new ByteArrayDeserializer());
+        assertEquals(Sensor.RecordingLevel.DEBUG, consumer2.metrics.config().recordLevel());
+        consumer2.close(Duration.ZERO);
     }
 
     @Test
+    @SuppressWarnings("unchecked")
     public void testShouldAttemptToRejoinGroupAfterSyncGroupFailed() throws Exception {
         ConsumerMetadata metadata = createMetadata(subscription);
         MockClient client = new MockClient(time, metadata);
@@ -1692,7 +1695,7 @@ public class KafkaConsumerTest {
         initMetadata(client, Collections.singletonMap(topic, 1));
         Node node = metadata.fetch().nodes().get(0);
 
-        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId);
         consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer));
         client.prepareResponseFrom(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node);
         Node coordinator = new Node(Integer.MAX_VALUE - node.id(), node.host(), node.port());
@@ -1746,9 +1749,8 @@ public class KafkaConsumerTest {
         time.sleep(heartbeatIntervalMs);
         Thread.sleep(heartbeatIntervalMs);
         consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE));
-        final ConsumerRecords<String, String> records = consumer.poll(Duration.ZERO);
+        final ConsumerRecords<String, String> records = (ConsumerRecords<String, String>) consumer.poll(Duration.ZERO);
         assertFalse(records.isEmpty());
-        consumer.close(Duration.ofMillis(0));
     }
 
     private void consumerCloseTest(final long closeTimeoutMs,
@@ -1799,15 +1801,27 @@ public class KafkaConsumerTest {
                 // Expected exception
             }
 
-            // Ensure close has started and queued at least one more request after commitAsync
+            // Ensure close has started and queued at least one more request after commitAsync.
+            //
+            // Close enqueues two requests, but the second is enqueued only after the first has succeeded. The first
+            // is LEAVE_GROUP (as part of coordinator close) and the second is FETCH with epoch=FINAL_EPOCH. At this
+            // stage we expect only the first one to have been requested. Hence, we wait for a total of 2 requests:
+            // one for the commit and another for LEAVE_GROUP.
             client.waitForRequests(2, 1000);
 
             // In graceful mode, commit response results in close() completing immediately without a timeout
             // In non-graceful mode, close() times out without an exception even though commit response is pending
+            int nonCloseRequests = 1;
             for (int i = 0; i < responses.size(); i++) {
                 client.waitForRequests(1, 1000);
-                client.respondFrom(responses.get(i), coordinator);
-                if (i != responses.size() - 1) {
+                if (i == responses.size() - 1 && responses.get(i) instanceof FetchResponse) {
+                    // The last request is the close-session request, which is sent to the partition leader.
+                    client.respondFrom(responses.get(i), node);
+                } else {
+                    client.respondFrom(responses.get(i), coordinator);
+                }
+                if (i < nonCloseRequests) {
+                    // Close should not complete until the non-close requests (i.e. the commit request) have completed.
                     try {
                         future.get(100, TimeUnit.MILLISECONDS);
                         fail("Close completed without waiting for response");
@@ -1827,7 +1841,7 @@ public class KafkaConsumerTest {
 
                 assertTrue(closeException.get() instanceof InterruptException, "Expected exception not thrown " + closeException);
             } else {
-                future.get(500, TimeUnit.MILLISECONDS); // Should succeed without TimeoutException or ExecutionException
+                future.get(closeTimeoutMs, TimeUnit.MILLISECONDS); // Should succeed without TimeoutException or ExecutionException
                 assertNull(closeException.get(), "Unexpected exception during close");
             }
         } finally {
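
The epoch sequencing that the comments above rely on can be summarised with a tiny model:
epoch 0 asks the broker to create a session, positive epochs continue it incrementally,
and FINAL_EPOCH (-1) asks the broker to close it. This is a simplified illustration, not
the real FetchMetadata/FetchSessionHandler logic:

    // Simplified model of fetch-session epochs (illustrative only).
    final class FetchSessionEpochModel {
        static final int INITIAL_EPOCH = 0; // first request: create a session
        static final int FINAL_EPOCH = -1;  // last request: close the session

        static int nextEpoch(int epoch) {
            // incremental fetches bump the epoch: 0 -> 1 -> 2 -> ...
            return epoch < 0 ? INITIAL_EPOCH : epoch + 1;
        }

        public static void main(String[] args) {
            int epoch = INITIAL_EPOCH;
            epoch = nextEpoch(epoch); // incremental fetch on the session
            epoch = nextEpoch(epoch);
            // on consumer close: one more fetch with FINAL_EPOCH and empty
            // partition data, sent to the node that owns the session
            System.out.println("close request uses epoch " + FINAL_EPOCH);
        }
    }
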
@@ -2150,7 +2164,7 @@ public class KafkaConsumerTest {
 
         client.requests().clear();
         consumer.unsubscribe();
-        consumer.close();
+        consumer.close(Duration.ZERO);
     }
 
     @Test
@@ -2211,14 +2225,14 @@ public class KafkaConsumerTest {
     }
 
     @Test
+    @SuppressWarnings("unchecked")
     public void testCurrentLag() {
         final ConsumerMetadata metadata = createMetadata(subscription);
         final MockClient client = new MockClient(time, metadata);
 
         initMetadata(client, singletonMap(topic, 1));
 
-        final KafkaConsumer<String, String> consumer =
-            newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
 
         // throws for unassigned partition
         assertThrows(IllegalStateException.class, () -> consumer.currentLag(tp0));
@@ -2255,14 +2269,12 @@ public class KafkaConsumerTest {
         final FetchInfo fetchInfo = new FetchInfo(1L, 99L, 50L, 5);
         client.respond(fetchResponse(singletonMap(tp0, fetchInfo)));
 
-        final ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(1));
+        final ConsumerRecords<String, String> records = (ConsumerRecords<String, String>) consumer.poll(Duration.ofMillis(1));
         assertEquals(5, records.count());
         assertEquals(55L, consumer.position(tp0));
 
         // correct lag result
         assertEquals(OptionalLong.of(45L), consumer.currentLag(tp0));
-
-        consumer.close(Duration.ZERO);
     }
 
     @Test
@@ -2272,8 +2284,7 @@ public class KafkaConsumerTest {
 
         initMetadata(client, singletonMap(topic, 1));
 
-        final KafkaConsumer<String, String> consumer =
-                newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
+        consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId);
 
         consumer.assign(singleton(tp0));
 
@@ -2287,8 +2298,6 @@ public class KafkaConsumerTest {
         assertEquals(singletonMap(tp0, 90L), consumer.endOffsets(Collections.singleton(tp0)));
         // correct lag result should be returned as well
         assertEquals(OptionalLong.of(40L), consumer.currentLag(tp0));
-
-        consumer.close(Duration.ZERO);
     }
 
     private KafkaConsumer<String, String> consumerWithPendingAuthenticationError(final Time time) {
@@ -2312,7 +2321,7 @@ public class KafkaConsumerTest {
         return consumerWithPendingAuthenticationError(time);
     }
 
-    private ConsumerRebalanceListener getConsumerRebalanceListener(final KafkaConsumer<String, String> consumer) {
+    private ConsumerRebalanceListener getConsumerRebalanceListener(final KafkaConsumer<?, ?> consumer) {
         return new ConsumerRebalanceListener() {
             @Override
             public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
@@ -2813,10 +2822,9 @@ public class KafkaConsumerTest {
 
     @Test
     public void testEnforceRebalanceWithManualAssignment() {
-        try (KafkaConsumer<byte[], byte[]> consumer = newConsumer((String) null)) {
-            consumer.assign(singleton(new TopicPartition("topic", 0)));
-            assertThrows(IllegalStateException.class, consumer::enforceRebalance);
-        }
+        consumer = newConsumer((String) null);
+        consumer.assign(singleton(new TopicPartition("topic", 0)));
+        assertThrows(IllegalStateException.class, consumer::enforceRebalance);
     }
 
     @Test
@@ -2857,7 +2865,7 @@ public class KafkaConsumerTest {
         initMetadata(client, Utils.mkMap(Utils.mkEntry(topic, 1)));
         Node node = metadata.fetch().nodes().get(0);
 
-        KafkaConsumer<String, String> consumer = newConsumer(
+        consumer = newConsumer(
             time,
             client,
             subscription,
@@ -2916,12 +2924,11 @@ public class KafkaConsumerTest {
         props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, DeserializerForClientId.class.getName());
         props.put(ConsumerConfig.INTERCEPTOR_CLASSES_CONFIG, ConsumerInterceptorForClientId.class.getName());
 
-        KafkaConsumer<byte[], byte[]> consumer = new KafkaConsumer<>(props);
+        consumer = new KafkaConsumer<>(props);
         assertNotNull(consumer.getClientId());
         assertNotEquals(0, consumer.getClientId().length());
         assertEquals(3, CLIENT_IDS.size());
         CLIENT_IDS.forEach(id -> assertEquals(id, consumer.getClientId()));
-        consumer.close();
     }
 
     @Test
@@ -2933,9 +2940,8 @@ public class KafkaConsumerTest {
 
         assertTrue(config.unused().contains(SslConfigs.SSL_PROTOCOL_CONFIG));
 
-        try (KafkaConsumer<byte[], byte[]> consumer = new KafkaConsumer<>(config, null, null)) {
-            assertTrue(config.unused().contains(SslConfigs.SSL_PROTOCOL_CONFIG));
-        }
+        consumer = new KafkaConsumer<>(config, null, null);
+        assertTrue(config.unused().contains(SslConfigs.SSL_PROTOCOL_CONFIG));
     }
 
     @Test
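
A recurring pattern in the hunks above: per-test local consumers (and their inline
close() calls) become assignments to a shared consumer field, presumably closed once in
test teardown. A minimal JUnit 5 sketch of that pattern, with a hypothetical test class:

    import java.time.Duration;
    import org.apache.kafka.clients.consumer.KafkaConsumer;
    import org.junit.jupiter.api.AfterEach;

    // Hypothetical teardown sketch: close the shared consumer once per test.
    class SharedConsumerTeardownExample {
        private KafkaConsumer<?, ?> consumer;

        @AfterEach
        void tearDown() {
            if (consumer != null) {
                // a zero timeout keeps tests fast: no broker will answer the
                // close-session requests anyway
                consumer.close(Duration.ZERO);
            }
        }
    }
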
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
index 595f6404d63..eecb9fc190d 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
@@ -69,6 +69,7 @@ import org.apache.kafka.common.metrics.Sensor;
 import org.apache.kafka.common.network.NetworkReceive;
 import org.apache.kafka.common.protocol.ApiKeys;
 import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.requests.FetchMetadata;
 import org.apache.kafka.common.requests.FetchRequest.PartitionData;
 import org.apache.kafka.common.utils.BufferSupplier;
 import org.apache.kafka.common.record.CompressionType;
@@ -101,6 +102,7 @@ import org.apache.kafka.common.serialization.StringDeserializer;
 import org.apache.kafka.common.utils.ByteBufferOutputStream;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.common.utils.Timer;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.test.DelayedReceive;
 import org.apache.kafka.test.MockSelector;
@@ -108,6 +110,7 @@ import org.apache.kafka.test.TestUtils;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
 
 import java.io.DataOutputStream;
 import java.lang.reflect.Field;
@@ -155,6 +158,10 @@ import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.junit.jupiter.api.Assertions.fail;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
 
 public class FetcherTest {
     private static final double EPSILON = 0.0001;
@@ -288,6 +295,50 @@ public class FetcherTest {
         assertNull(fetchedRecords().get(tp0));
     }
 
+    @Test
+    public void testCloseShouldBeIdempotent() {
+        buildFetcher();
+
+        fetcher.close();
+        fetcher.close();
+        fetcher.close();
+
+        verify(fetcher, times(1)).maybeCloseFetchSessions(any(Timer.class));
+    }
+
+    @Test
+    public void testFetcherCloseClosesFetchSessionsInBroker() {
+        buildFetcher();
+
+        assignFromUser(singleton(tp0));
+        subscriptions.seek(tp0, 0);
+
+        // normal fetch
+        assertEquals(1, fetcher.sendFetches());
+        assertFalse(fetcher.hasCompletedFetches());
+
+        final FetchResponse fetchResponse = fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0);
+        client.prepareResponse(fetchResponse);
+        consumerClient.poll(time.timer(0));
+        assertTrue(fetcher.hasCompletedFetches());
+        assertEquals(0, consumerClient.pendingRequestCount());
+
+        final ArgumentCaptor<FetchRequest.Builder> argument = ArgumentCaptor.forClass(FetchRequest.Builder.class);
+
+        // send request to close the fetcher
+        this.fetcher.close(time.timer(Duration.ofSeconds(10)));
+
+        // Validate that Fetcher.close() has sent a request with the final epoch. Two requests are sent: one for the
+        // normal fetch earlier and another for the session-closing fetch here.
+        verify(consumerClient, times(2)).send(any(Node.class), argument.capture());
+        FetchRequest.Builder builder = argument.getValue();
+        // session Id is the same
+        assertEquals(fetchResponse.sessionId(), builder.metadata().sessionId());
+        // contains the final epoch, which indicates that we want to close the session
+        assertEquals(FetchMetadata.FINAL_EPOCH, builder.metadata().epoch());
+        assertTrue(builder.fetchData().isEmpty()); // partition data should be empty
+    }
+
     @Test
     public void testFetchingPendingPartitions() {
         buildFetcher();
@@ -5270,7 +5321,7 @@ public class FetcherTest {
                                      SubscriptionState subscriptionState,
                                      LogContext logContext) {
         buildDependencies(metricConfig, metadataExpireMs, subscriptionState, logContext);
-        fetcher = new Fetcher<>(
+        fetcher = spy(new Fetcher<>(
                 new LogContext(),
                 consumerClient,
                 minBytes,
@@ -5290,7 +5341,7 @@ public class FetcherTest {
                 retryBackoffMs,
                 requestTimeoutMs,
                 isolationLevel,
-                apiVersions);
+                apiVersions));
     }
 
     private void buildDependencies(MetricConfig metricConfig,
@@ -5303,8 +5354,8 @@ public class FetcherTest {
                 subscriptions, logContext, new ClusterResourceListeners());
         client = new MockClient(time, metadata);
         metrics = new Metrics(metricConfig, time);
-        consumerClient = new ConsumerNetworkClient(logContext, client, metadata, time,
-                100, 1000, Integer.MAX_VALUE);
+        consumerClient = spy(new ConsumerNetworkClient(logContext, client, metadata, time,
+                100, 1000, Integer.MAX_VALUE));
         metricsRegistry = new FetcherMetricsRegistry(metricConfig.tags().keySet(), "consumer" + groupId);
     }
 
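
The verification style used above (spying on real collaborators, then capturing the most
recent request builder) is plain Mockito. A self-contained sketch of the same
capture-and-inspect pattern, against a hypothetical sender interface rather than the real
ConsumerNetworkClient:

    import static org.mockito.ArgumentMatchers.anyString;
    import static org.mockito.Mockito.spy;
    import static org.mockito.Mockito.times;
    import static org.mockito.Mockito.verify;

    import org.mockito.ArgumentCaptor;

    // Illustrative spy + ArgumentCaptor usage (hypothetical Sender interface).
    public class CaptorSketch {
        interface Sender {
            void send(String destination, int epoch);
        }

        static class NoopSender implements Sender {
            @Override
            public void send(String destination, int epoch) { /* no-op */ }
        }

        public static void main(String[] args) {
            Sender sender = spy(new NoopSender());

            sender.send("node-0", 1);  // the normal fetch
            sender.send("node-0", -1); // the session-closing fetch

            ArgumentCaptor<Integer> epochCaptor = ArgumentCaptor.forClass(Integer.class);
            verify(sender, times(2)).send(anyString(), epochCaptor.capture());
            // getValue() returns the argument of the most recent invocation
            if (epochCaptor.getValue() != -1)
                throw new AssertionError("expected the final fetch to be captured last");
        }
    }
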
diff --git a/core/src/test/java/kafka/testkit/KafkaClusterTestKit.java b/core/src/test/java/kafka/testkit/KafkaClusterTestKit.java
index ecc0e3f3430..38287e40053 100644
--- a/core/src/test/java/kafka/testkit/KafkaClusterTestKit.java
+++ b/core/src/test/java/kafka/testkit/KafkaClusterTestKit.java
@@ -45,6 +45,7 @@ import org.apache.kafka.server.fault.MockFaultHandler;
 import org.apache.kafka.test.TestUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import org.slf4j.event.Level;
 import scala.Option;
 import scala.collection.JavaConverters;
 
@@ -240,7 +241,7 @@ public class KafkaClusterTestKit implements AutoCloseable {
                                 bootstrapMetadata);
                     } catch (Throwable e) {
                         log.error("Error creating controller {}", node.id(), e);
-                        Utils.swallow(log, "sharedServer.stopForController", () -> sharedServer.stopForController());
+                        Utils.swallow(log, Level.WARN, "sharedServer.stopForController error", () -> sharedServer.stopForController());
                         if (controller != null) controller.shutdown();
                         throw e;
                     }
@@ -270,7 +271,7 @@ public class KafkaClusterTestKit implements AutoCloseable {
                                 JavaConverters.asScalaBuffer(Collections.<String>emptyList()).toSeq());
                     } catch (Throwable e) {
                         log.error("Error creating broker {}", node.id(), e);
-                        Utils.swallow(log, "sharedServer.stopForBroker", () -> sharedServer.stopForBroker());
+                        Utils.swallow(log, Level.WARN, "sharedServer.stopForBroker error", () -> sharedServer.stopForBroker());
                         if (broker != null) broker.shutdown();
                         throw e;
                     }
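
For reference, the Utils.swallow overload used in the two hunks above takes an explicit
org.slf4j.event.Level for the failure log line. A rough illustrative sketch of what such
a helper does (not the actual Utils implementation, which also accepts actions that may
throw checked exceptions):

    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    import org.slf4j.event.Level;

    // Illustrative swallow-style helper: run a step, log rather than propagate failures.
    final class SwallowSketch {
        private static final Logger LOG = LoggerFactory.getLogger(SwallowSketch.class);

        static void swallow(Logger log, Level level, String what, Runnable action) {
            try {
                action.run();
            } catch (Throwable t) {
                switch (level) {
                    case ERROR: log.error("{} failed", what, t); break;
                    case WARN:  log.warn("{} failed", what, t);  break;
                    default:    log.info("{} failed", what, t);  break;
                }
            }
        }

        public static void main(String[] args) {
            swallow(LOG, Level.WARN, "sharedServer.stopForController",
                    () -> { throw new RuntimeException("boom"); });
            System.out.println("still running: the failure was logged, not rethrown");
        }
    }
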