You are viewing a plain-text version of this content. The canonical link to it (the "here" hyperlink in the original) is not preserved in this version.
Posted to commits@kafka.apache.org by jg...@apache.org on 2019/04/15 22:56:52 UTC

[kafka] branch trunk updated: KAFKA-6635; Producer close awaits pending transactions (#5971)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 7bd8162  KAFKA-6635; Producer close awaits pending transactions (#5971)
7bd8162 is described below

commit 7bd81628d9f6c21649e615e73ed507520fd74fd9
Author: Viktor Somogyi <vi...@gmail.com>
AuthorDate: Tue Apr 16 00:56:36 2019 +0200

    KAFKA-6635; Producer close awaits pending transactions (#5971)
    
    Currently close() only awaits completion of pending produce requests. If there is a transaction ongoing, it may be dropped. For example, if one thread is calling commitTransaction() and another calls close(), then the commit may never happen even if the caller is willing to wait for it (by using a long timeout). What's more, the thread blocking in commitTransaction() will be stuck, since its result will never be completed once the producer has shut down.
    
    This patch ensures that, before close() returns, 1) transactions that are already completing (a commit or abort is in progress) are awaited, 2) other ongoing transactions are aborted, and 3) pending callbacks are completed.
    
    Reviewers: Jason Gustafson <ja...@confluent.io>
---
 checkstyle/suppressions.xml                        |   2 +-
 .../kafka/clients/producer/KafkaProducer.java      |  10 +-
 .../kafka/clients/producer/internals/Sender.java   |  32 ++++-
 .../producer/internals/TransactionManager.java     |  18 ++-
 .../kafka/clients/producer/KafkaProducerTest.java  | 149 +++++++++++++++++----
 .../clients/producer/internals/SenderTest.java     | 126 ++++++++++++++++-
 .../producer/internals/TransactionManagerTest.java |   8 +-
 7 files changed, 306 insertions(+), 39 deletions(-)

diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
index f306aaa..ce2706d 100644
--- a/checkstyle/suppressions.xml
+++ b/checkstyle/suppressions.xml
@@ -54,7 +54,7 @@
               files="(ConsumerCoordinator|Fetcher|Sender|KafkaProducer|BufferPool|ConfigDef|RecordAccumulator|KerberosLogin|AbstractRequest|AbstractResponse|Selector|SslFactory|SslTransportLayer|SaslClientAuthenticator|SaslClientCallbackHandler|SaslServerAuthenticator|SchemaGenerator).java"/>
 
     <suppress checks="JavaNCSS"
-              files="AbstractRequest.java|KerberosLogin.java|WorkerSinkTaskTest.java|TransactionManagerTest.java"/>
+              files="AbstractRequest.java|KerberosLogin.java|WorkerSinkTaskTest.java|TransactionManagerTest.java|SenderTest.java"/>
 
     <suppress checks="NPathComplexity"
               files="(BufferPool|Fetcher|MetricName|Node|ConfigDef|RecordBatch|SslFactory|SslTransportLayer|MetadataResponse|KerberosLogin|Selector|Sender|Serdes|TokenInformation|Agent|Values|PluginUtils|MiniTrogdorCluster|TasksRequest).java"/>
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
index 3a0130f..e9d2626 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
@@ -611,6 +611,7 @@ public class KafkaProducer<K, V> implements Producer<K, V> {
      */
     public void initTransactions() {
         throwIfNoTransactionManager();
+        throwIfProducerClosed();
         TransactionalRequestResult result = transactionManager.initializeTransactions();
         sender.wakeup();
         result.await(maxBlockTimeMs, TimeUnit.MILLISECONDS);
@@ -631,6 +632,7 @@ public class KafkaProducer<K, V> implements Producer<K, V> {
      */
     public void beginTransaction() throws ProducerFencedException {
         throwIfNoTransactionManager();
+        throwIfProducerClosed();
         transactionManager.beginTransaction();
     }
 
@@ -661,6 +663,7 @@ public class KafkaProducer<K, V> implements Producer<K, V> {
     public void sendOffsetsToTransaction(Map<TopicPartition, OffsetAndMetadata> offsets,
                                          String consumerGroupId) throws ProducerFencedException {
         throwIfNoTransactionManager();
+        throwIfProducerClosed();
         TransactionalRequestResult result = transactionManager.sendOffsetsToTransaction(offsets, consumerGroupId);
         sender.wakeup();
         result.await();
@@ -691,6 +694,7 @@ public class KafkaProducer<K, V> implements Producer<K, V> {
      */
     public void commitTransaction() throws ProducerFencedException {
         throwIfNoTransactionManager();
+        throwIfProducerClosed();
         TransactionalRequestResult result = transactionManager.beginCommit();
         sender.wakeup();
         result.await(maxBlockTimeMs, TimeUnit.MILLISECONDS);
@@ -718,6 +722,7 @@ public class KafkaProducer<K, V> implements Producer<K, V> {
      */
     public void abortTransaction() throws ProducerFencedException {
         throwIfNoTransactionManager();
+        throwIfProducerClosed();
         TransactionalRequestResult result = transactionManager.beginAbort();
         sender.wakeup();
         result.await(maxBlockTimeMs, TimeUnit.MILLISECONDS);
@@ -848,7 +853,7 @@ public class KafkaProducer<K, V> implements Producer<K, V> {
     // Verify that this producer instance has not been closed. This method throws IllegalStateException if the producer
     // has already been closed.
     private void throwIfProducerClosed() {
-        if (ioThread == null || !ioThread.isAlive())
+        if (sender == null || !sender.isRunning())
             throw new IllegalStateException("Cannot perform operation after producer has been closed");
     }
 
@@ -1117,7 +1122,8 @@ public class KafkaProducer<K, V> implements Producer<K, V> {
      * This method waits up to <code>timeout</code> for the producer to complete the sending of all incomplete requests.
      * <p>
      * If the producer is unable to complete all requests before the timeout expires, this method will fail
-     * any unsent and unacknowledged records immediately.
+     * any unsent and unacknowledged records immediately. It will also abort the ongoing transaction if it's not
+     * already completing.
      * <p>
      * If invoked from within a {@link Callback} this method will not block and will be equivalent to
      * <code>close(Duration.ofMillis(0))</code>. This is done since no further sending will happen while
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
index 33bc496..326b938 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
@@ -224,6 +224,10 @@ public class Sender implements Runnable {
         }
     }
 
+    private boolean hasPendingTransactionalRequests() {
+        return transactionManager != null && transactionManager.hasPendingRequests() && transactionManager.hasOngoingTransaction();
+    }
+
     /**
      * The main run loop for the sender thread
      */
@@ -242,18 +246,36 @@ public class Sender implements Runnable {
         log.debug("Beginning shutdown of Kafka producer I/O thread, sending remaining records.");
 
         // okay we stopped accepting requests but there may still be
-        // requests in the accumulator or waiting for acknowledgment,
+        // requests in the transaction manager, accumulator or waiting for acknowledgment,
         // wait until these are completed.
-        while (!forceClose && (this.accumulator.hasUndrained() || this.client.inFlightRequestCount() > 0)) {
+        while (!forceClose && ((this.accumulator.hasUndrained() || this.client.inFlightRequestCount() > 0) || hasPendingTransactionalRequests())) {
+            try {
+                run(time.milliseconds());
+            } catch (Exception e) {
+                log.error("Uncaught error in kafka producer I/O thread: ", e);
+            }
+        }
+
+        // Abort the transaction if any commit or abort didn't go through the transaction manager's queue
+        while (!forceClose && transactionManager != null && transactionManager.hasOngoingTransaction()) {
+            if (!transactionManager.isCompleting()) {
+                log.info("Aborting incomplete transaction due to shutdown");
+                transactionManager.beginAbort();
+            }
             try {
                 run(time.milliseconds());
             } catch (Exception e) {
                 log.error("Uncaught error in kafka producer I/O thread: ", e);
             }
         }
+
         if (forceClose) {
-            // We need to fail all the incomplete batches and wake up the threads waiting on
+            // We need to fail all the incomplete transactional requests and batches and wake up the threads waiting on
             // the futures.
+            if (transactionManager != null) {
+                log.debug("Aborting incomplete transactional requests due to forced shutdown");
+                transactionManager.close();
+            }
             log.debug("Aborting incomplete batches due to forced shutdown");
             this.accumulator.abortIncompleteBatches();
         }
@@ -479,6 +501,10 @@ public class Sender implements Runnable {
         initiateClose();
     }
 
+    public boolean isRunning() {
+        return running;
+    }
+
     private ClientResponse sendAndAwaitInitProducerIdRequest(Node node) throws IOException {
         String nodeId = node.idString();
         InitProducerIdRequestData requestData = new InitProducerIdRequestData()
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
index b34cc98..024d882 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
@@ -310,7 +310,7 @@ public class TransactionManager {
 
         log.debug("Begin adding offsets {} for consumer group {} to transaction", offsets, consumerGroupId);
         AddOffsetsToTxnRequest.Builder builder = new AddOffsetsToTxnRequest.Builder(transactionalId,
-                producerIdAndEpoch.producerId, producerIdAndEpoch.epoch, consumerGroupId);
+            producerIdAndEpoch.producerId, producerIdAndEpoch.epoch, consumerGroupId);
         AddOffsetsToTxnHandler handler = new AddOffsetsToTxnHandler(builder, offsets);
         enqueueRequest(handler);
         return handler.result;
@@ -684,6 +684,16 @@ public class TransactionManager {
             request.fatalError(e);
     }
 
+    synchronized void close() {
+        KafkaException shutdownException = new KafkaException("The producer closed forcefully");
+        pendingRequests.forEach(handler ->
+                handler.fatalError(shutdownException));
+        if (pendingResult != null) {
+            pendingResult.setError(shutdownException);
+            pendingResult.done();
+        }
+    }
+
     Node coordinator(FindCoordinatorRequest.CoordinatorType type) {
         switch (type) {
             case GROUP:
@@ -731,6 +741,10 @@ public class TransactionManager {
         return !pendingTxnOffsetCommits.isEmpty();
     }
 
+    synchronized boolean hasPendingRequests() {
+        return !pendingRequests.isEmpty();
+    }
+
     // visible for testing
     synchronized boolean hasOngoingTransaction() {
         // transactions are considered ongoing once started until completion or a fatal error
@@ -799,7 +813,7 @@ public class TransactionManager {
 
         if (target == State.FATAL_ERROR || target == State.ABORTABLE_ERROR) {
             if (error == null)
-                throw new IllegalArgumentException("Cannot transition to " + target + " with an null exception");
+                throw new IllegalArgumentException("Cannot transition to " + target + " with a null exception");
             lastError = error;
         } else {
             lastError = null;
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java
index 8d74c6b..0465414 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java
@@ -35,6 +35,7 @@ import org.apache.kafka.common.internals.ClusterResourceListeners;
 import org.apache.kafka.common.metrics.Sensor;
 import org.apache.kafka.common.network.Selectable;
 import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.requests.FindCoordinatorResponse;
 import org.apache.kafka.common.requests.MetadataResponse;
 import org.apache.kafka.common.serialization.ByteArraySerializer;
 import org.apache.kafka.common.serialization.Serializer;
@@ -59,6 +60,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
@@ -72,6 +74,7 @@ import static java.util.Collections.emptyMap;
 import static java.util.Collections.singletonMap;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.any;
@@ -86,20 +89,21 @@ import static org.mockito.Mockito.when;
 
 public class KafkaProducerTest {
     private String topic = "topic";
-    private Collection<Node> nodes = Collections.singletonList(new Node(0, "host1", 1000));
+    private Node host1 = new Node(0, "host1", 1000);
+    private Collection<Node> nodes = Collections.singletonList(host1);
     private final Cluster emptyCluster = new Cluster(null, nodes,
             Collections.emptySet(),
             Collections.emptySet(),
             Collections.emptySet());
     private final Cluster onePartitionCluster = new Cluster(
             "dummy",
-            Collections.singletonList(new Node(0, "host1", 1000)),
+            Collections.singletonList(host1),
             Collections.singletonList(new PartitionInfo(topic, 0, null, null, null)),
             Collections.emptySet(),
             Collections.emptySet());
     private final Cluster threePartitionCluster = new Cluster(
             "dummy",
-            Collections.singletonList(new Node(0, "host1", 1000)),
+            Collections.singletonList(host1),
             Arrays.asList(
                     new PartitionInfo(topic, 0, null, null, null),
                     new PartitionInfo(topic, 1, null, null, null),
@@ -497,12 +501,7 @@ public class KafkaProducerTest {
                 }
             });
             t.start();
-            try {
-                producer.partitionsFor(topic);
-                fail("Expect TimeoutException");
-            } catch (TimeoutException e) {
-                // skip
-            }
+            assertThrows(TimeoutException.class, () -> producer.partitionsFor(topic));
             running.set(false);
             t.join();
         }
@@ -553,12 +552,7 @@ public class KafkaProducerTest {
         producer.send(record, null);
 
         //ensure headers are closed and cannot be mutated post send
-        try {
-            record.headers().add(new RecordHeader("test", "test".getBytes()));
-            fail("Expected IllegalStateException to be raised");
-        } catch (IllegalStateException ise) {
-            //expected
-        }
+        assertThrows(IllegalStateException.class, () -> record.headers().add(new RecordHeader("test", "test".getBytes())));
 
         //ensure existing headers are not changed, and last header for key is still original value
         assertArrayEquals(record.headers().lastHeader("test").value(), "header2".getBytes());
@@ -625,14 +619,11 @@ public class KafkaProducerTest {
         Properties props = new Properties();
         props.setProperty(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000");
         try (KafkaProducer<byte[], byte[]> producer = new KafkaProducer<>(props, new ByteArraySerializer(), new ByteArraySerializer())) {
-            producer.partitionsFor(null);
-            fail("Expected NullPointerException to be raised");
-        } catch (NullPointerException e) {
-            // expected
+            assertThrows(NullPointerException.class, () -> producer.partitionsFor(null));
         }
     }
 
-    @Test(expected = TimeoutException.class)
+    @Test
     public void testInitTransactionTimeout() {
         Map<String, Object> configs = new HashMap<>();
         configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "bad-transaction");
@@ -648,12 +639,11 @@ public class KafkaProducerTest {
 
         try (Producer<String, String> producer = new KafkaProducer<>(configs, new StringSerializer(),
                 new StringSerializer(), metadata, client, null, time)) {
-            producer.initTransactions();
-            fail("initTransactions() should have raised TimeoutException");
+            assertThrows(TimeoutException.class, producer::initTransactions);
         }
     }
 
-    @Test(expected = KafkaException.class)
+    @Test
     public void testOnlyCanExecuteCloseAfterInitTransactionsTimeout() {
         Map<String, Object> configs = new HashMap<>();
         configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "bad-transaction");
@@ -669,14 +659,10 @@ public class KafkaProducerTest {
 
         Producer<String, String> producer = new KafkaProducer<>(configs, new StringSerializer(), new StringSerializer(),
                 metadata, client, null, time);
-        try {
-            producer.initTransactions();
-        } catch (TimeoutException e) {
-            // expected
-        }
+        assertThrows(TimeoutException.class, producer::initTransactions);
         // other transactional operations should not be allowed if we catch the error after initTransactions failed
         try {
-            producer.beginTransaction();
+            assertThrows(KafkaException.class, producer::beginTransaction);
         } finally {
             producer.close(Duration.ofMillis(0));
         }
@@ -766,6 +752,111 @@ public class KafkaProducerTest {
         }
     }
 
+    @Test
+    public void testTransactionalMethodThrowsWhenSenderClosed() {
+        Map<String, Object> configs = new HashMap<>();
+        configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000");
+        configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "this-is-a-transactional-id");
+
+        Time time = new MockTime();
+        MetadataResponse initialUpdateResponse = TestUtils.metadataUpdateWith(1, emptyMap());
+        ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE);
+        metadata.update(initialUpdateResponse, time.milliseconds());
+
+        MockClient client = new MockClient(time, metadata);
+
+        Producer<String, String> producer = new KafkaProducer<>(configs, new StringSerializer(), new StringSerializer(),
+                metadata, client, null, time);
+        producer.close();
+        assertThrows(IllegalStateException.class, producer::initTransactions);
+    }
+
+    @Test(timeout = 5000)
+    public void testCloseIsForcedOnPendingFindCoordinator() throws InterruptedException {
+        Map<String, Object> configs = new HashMap<>();
+        configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000");
+        configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "this-is-a-transactional-id");
+
+        Time time = new MockTime();
+        MetadataResponse initialUpdateResponse = TestUtils.metadataUpdateWith(1, singletonMap("testTopic", 1));
+        ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE);
+        metadata.update(initialUpdateResponse, time.milliseconds());
+
+        MockClient client = new MockClient(time, metadata);
+
+        Producer<String, String> producer = new KafkaProducer<>(configs, new StringSerializer(), new StringSerializer(),
+                metadata, client, null, time);
+
+        ExecutorService executorService = Executors.newSingleThreadExecutor();
+        CountDownLatch assertionDoneLatch = new CountDownLatch(1);
+        executorService.submit(() -> {
+            assertThrows(KafkaException.class, producer::initTransactions);
+            assertionDoneLatch.countDown();
+        });
+
+        client.waitForRequests(1, 2000);
+        producer.close(Duration.ofMillis(1000));
+        assertionDoneLatch.await();
+    }
+
+    @Test(timeout = 5000)
+    public void testCloseIsForcedOnPendingInitProducerId() throws InterruptedException {
+        Map<String, Object> configs = new HashMap<>();
+        configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000");
+        configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "this-is-a-transactional-id");
+
+        Time time = new MockTime();
+        MetadataResponse initialUpdateResponse = TestUtils.metadataUpdateWith(1, singletonMap("testTopic", 1));
+        ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE);
+        metadata.update(initialUpdateResponse, time.milliseconds());
+
+        MockClient client = new MockClient(time, metadata);
+
+        Producer<String, String> producer = new KafkaProducer<>(configs, new StringSerializer(), new StringSerializer(),
+                metadata, client, null, time);
+
+        ExecutorService executorService = Executors.newSingleThreadExecutor();
+        CountDownLatch assertionDoneLatch = new CountDownLatch(1);
+        client.prepareResponse(new FindCoordinatorResponse(Errors.NONE, host1));
+        executorService.submit(() -> {
+            assertThrows(KafkaException.class, producer::initTransactions);
+            assertionDoneLatch.countDown();
+        });
+
+        client.waitForRequests(1, 2000);
+        producer.close(Duration.ofMillis(1000));
+        assertionDoneLatch.await();
+    }
+
+    @Test(timeout = 5000)
+    public void testCloseIsForcedOnPendingAddOffsetRequest() throws InterruptedException {
+        Map<String, Object> configs = new HashMap<>();
+        configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000");
+        configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "this-is-a-transactional-id");
+
+        Time time = new MockTime();
+        MetadataResponse initialUpdateResponse = TestUtils.metadataUpdateWith(1, singletonMap("testTopic", 1));
+        ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE);
+        metadata.update(initialUpdateResponse, time.milliseconds());
+
+        MockClient client = new MockClient(time, metadata);
+
+        Producer<String, String> producer = new KafkaProducer<>(configs, new StringSerializer(), new StringSerializer(),
+                metadata, client, null, time);
+
+        ExecutorService executorService = Executors.newSingleThreadExecutor();
+        CountDownLatch assertionDoneLatch = new CountDownLatch(1);
+        client.prepareResponse(new FindCoordinatorResponse(Errors.NONE, host1));
+        executorService.submit(() -> {
+            assertThrows(KafkaException.class, producer::initTransactions);
+            assertionDoneLatch.countDown();
+        });
+
+        client.waitForRequests(1, 2000);
+        producer.close(Duration.ofMillis(1000));
+        assertionDoneLatch.await();
+    }
+
     private ProducerMetadata newMetadata(long refreshBackoffMs, long expirationMs) {
         return new ProducerMetadata(refreshBackoffMs, expirationMs,
                 new LogContext(), new ClusterResourceListeners(), Time.SYSTEM);
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
index d397fd4..6436bcc 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
@@ -25,6 +25,7 @@ import org.apache.kafka.clients.NodeApiVersions;
 import org.apache.kafka.clients.producer.Callback;
 import org.apache.kafka.clients.producer.RecordMetadata;
 import org.apache.kafka.common.Cluster;
+import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.MetricName;
 import org.apache.kafka.common.MetricNameTemplate;
 import org.apache.kafka.common.Node;
@@ -55,6 +56,8 @@ import org.apache.kafka.common.record.RecordBatch;
 import org.apache.kafka.common.requests.AbstractRequest;
 import org.apache.kafka.common.requests.AddPartitionsToTxnResponse;
 import org.apache.kafka.common.requests.ApiVersionsResponse;
+import org.apache.kafka.common.requests.EndTxnRequest;
+import org.apache.kafka.common.requests.EndTxnResponse;
 import org.apache.kafka.common.requests.FindCoordinatorResponse;
 import org.apache.kafka.common.requests.InitProducerIdRequest;
 import org.apache.kafka.common.requests.InitProducerIdResponse;
@@ -62,6 +65,7 @@ import org.apache.kafka.common.requests.MetadataResponse;
 import org.apache.kafka.common.requests.ProduceRequest;
 import org.apache.kafka.common.requests.ProduceResponse;
 import org.apache.kafka.common.requests.ResponseHeader;
+import org.apache.kafka.common.requests.TransactionResult;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Time;
@@ -93,9 +97,11 @@ import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.AdditionalMatchers.geq;
 import static org.mockito.ArgumentMatchers.any;
@@ -2193,6 +2199,124 @@ public class SenderTest {
         }
     }
 
+    @Test
+    public void testTransactionalRequestsSentOnShutdown() {
+        // create a sender with retries = 1
+        int maxRetries = 1;
+        Metrics m = new Metrics();
+        SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m);
+        try {
+            TransactionManager txnManager = new TransactionManager(logContext, "testTransactionalRequestsSentOnShutdown", 6000, 100);
+            Sender sender = new Sender(logContext, client, metadata, this.accumulator, false, MAX_REQUEST_SIZE, ACKS_ALL,
+                    maxRetries, senderMetrics, time, REQUEST_TIMEOUT, 50, txnManager, apiVersions);
+
+            ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0);
+            TopicPartition tp = new TopicPartition("testTransactionalRequestsSentOnShutdown", 1);
+
+            setupWithTransactionState(txnManager);
+            doInitTransactions(txnManager, producerIdAndEpoch);
+
+            txnManager.beginTransaction();
+            txnManager.maybeAddPartitionToTransaction(tp);
+            client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp, Errors.NONE)));
+            sender.run(time.milliseconds());
+            sender.initiateClose();
+            txnManager.beginCommit();
+            AssertEndTxnRequestMatcher endTxnMatcher = new AssertEndTxnRequestMatcher(TransactionResult.COMMIT);
+            client.prepareResponse(endTxnMatcher, new EndTxnResponse(0, Errors.NONE));
+            sender.run();
+            assertTrue("Response didn't match in test", endTxnMatcher.matched);
+        } finally {
+            m.close();
+        }
+    }
+
+    @Test
+    public void testIncompleteTransactionAbortOnShutdown() {
+        // create a sender with retries = 1
+        int maxRetries = 1;
+        Metrics m = new Metrics();
+        SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m);
+        try {
+            TransactionManager txnManager = new TransactionManager(logContext, "testIncompleteTransactionAbortOnShutdown", 6000, 100);
+            Sender sender = new Sender(logContext, client, metadata, this.accumulator, false, MAX_REQUEST_SIZE, ACKS_ALL,
+                    maxRetries, senderMetrics, time, REQUEST_TIMEOUT, 50, txnManager, apiVersions);
+
+            ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0);
+            TopicPartition tp = new TopicPartition("testIncompleteTransactionAbortOnShutdown", 1);
+
+            setupWithTransactionState(txnManager);
+            doInitTransactions(txnManager, producerIdAndEpoch);
+
+            txnManager.beginTransaction();
+            txnManager.maybeAddPartitionToTransaction(tp);
+            client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp, Errors.NONE)));
+            sender.run(time.milliseconds());
+            sender.initiateClose();
+            AssertEndTxnRequestMatcher endTxnMatcher = new AssertEndTxnRequestMatcher(TransactionResult.ABORT);
+            client.prepareResponse(endTxnMatcher, new EndTxnResponse(0, Errors.NONE));
+            sender.run();
+            assertTrue("Response didn't match in test", endTxnMatcher.matched);
+        } finally {
+            m.close();
+        }
+    }
+
+    @Test(timeout = 10000L)
+    public void testForceShutdownWithIncompleteTransaction() {
+        // create a sender with retries = 1
+        int maxRetries = 1;
+        Metrics m = new Metrics();
+        SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m);
+        try {
+            TransactionManager txnManager = new TransactionManager(logContext, "testForceShutdownWithIncompleteTransaction", 6000, 100);
+            Sender sender = new Sender(logContext, client, metadata, this.accumulator, false, MAX_REQUEST_SIZE, ACKS_ALL,
+                    maxRetries, senderMetrics, time, REQUEST_TIMEOUT, 50, txnManager, apiVersions);
+
+            ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0);
+            TopicPartition tp = new TopicPartition("testForceShutdownWithIncompleteTransaction", 1);
+
+            setupWithTransactionState(txnManager);
+            doInitTransactions(txnManager, producerIdAndEpoch);
+
+            txnManager.beginTransaction();
+            txnManager.maybeAddPartitionToTransaction(tp);
+            client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp, Errors.NONE)));
+            sender.run(time.milliseconds());
+
+            // Try to commit the transaction but it won't happen as we'll forcefully close the sender
+            TransactionalRequestResult commitResult = txnManager.beginCommit();
+
+            sender.forceClose();
+            sender.run();
+            assertThrows("The test expected to throw a KafkaException for forcefully closing the sender",
+                    KafkaException.class, commitResult::await);
+        } finally {
+            m.close();
+        }
+    }
+
+    class AssertEndTxnRequestMatcher implements MockClient.RequestMatcher {
+
+        private TransactionResult requiredResult;
+        private boolean matched = false;
+
+        AssertEndTxnRequestMatcher(TransactionResult requiredResult) {
+            this.requiredResult = requiredResult;
+        }
+
+        @Override
+        public boolean matches(AbstractRequest body) {
+            if (body instanceof EndTxnRequest) {
+                assertSame(requiredResult, ((EndTxnRequest) body).command());
+                matched = true;
+                return true;
+            } else {
+                return false;
+            }
+        }
+    }
+
     private class MatchingBufferPool extends BufferPool {
         IdentityHashMap<ByteBuffer, Boolean> allocatedBuffers;
 
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
index 1c47b9d..e830c3b 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
@@ -138,7 +138,7 @@ public class TransactionManagerTest {
     }
 
     @Test
-    public void testSenderShutdownWithPendingAddPartitions() throws Exception {
+    public void testSenderShutdownWithPendingTransactions() throws Exception {
         long pid = 13131L;
         short epoch = 1;
         doInitTransactions(pid, epoch);
@@ -152,6 +152,12 @@ public class TransactionManagerTest {
         prepareProduceResponse(Errors.NONE, pid, epoch);
 
         sender.initiateClose();
+        sender.run(time.milliseconds());
+        TransactionalRequestResult result = transactionManager.beginCommit();
+        sender.run(time.milliseconds());
+        prepareEndTxnResponse(Errors.NONE, TransactionResult.COMMIT, pid, epoch);
+        sender.run(time.milliseconds());
+        assertTrue(result.isCompleted());
         sender.run();
 
         assertTrue(sendFuture.isDone());