Posted to commits@kafka.apache.org by jg...@apache.org on 2018/09/14 18:56:52 UTC

[kafka] branch 2.0 updated: KAFKA-7280; Synchronize consumer fetch request/response handling (#5495)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch 2.0
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.0 by this push:
     new 00f2ce2  KAFKA-7280; Synchronize consumer fetch request/response handling (#5495)
00f2ce2 is described below

commit 00f2ce27cd5cba03c021c4a4da2bf408f51d0ac9
Author: Rajini Sivaram <ra...@googlemail.com>
AuthorDate: Fri Sep 14 19:51:46 2018 +0100

    KAFKA-7280; Synchronize consumer fetch request/response handling (#5495)
    
    This patch fixes unsafe concurrent access to the fetch session state in `FetchSessionHandler`, which may be touched both by the heartbeat thread and by the thread calling `poll()`.
    
    Reviewers: Viktor Somogyi <vi...@gmail.com>, Jason Gustafson <ja...@confluent.io>
---
 .../kafka/clients/consumer/internals/Fetcher.java  |  85 ++++++----
 .../clients/consumer/internals/FetcherTest.java    | 173 ++++++++++++++++++---
 2 files changed, 209 insertions(+), 49 deletions(-)
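
To see the hazard and the fix pattern in isolation, here is a minimal,
self-contained sketch (the MiniFetcher class and its names are illustrative
only, not the real Kafka code). Without the synchronization shown, the
polling thread traverses the session-state LinkedHashMap while the heartbeat
thread mutates it, which typically surfaces as a
ConcurrentModificationException:

    import java.util.LinkedHashMap;
    import java.util.Map;

    public class MiniFetcher {
        // Shared fetch-session state, analogous to the sessionPartitions
        // map inside FetchSessionHandler.
        private final Map<String, Long> sessionPartitions = new LinkedHashMap<>();

        // Polling thread: builds the next fetch request by traversing the
        // session state. Synchronized on the fetcher instance, like the
        // patched sendFetches().
        public synchronized void sendFetches() {
            for (Map.Entry<String, Long> entry : sessionPartitions.entrySet())
                Thread.yield(); // widen the race window, as the new test does
            sessionPartitions.put("tp-" + sessionPartitions.size(), 0L);
        }

        // Heartbeat thread: handles a fetch response by mutating the same
        // state. Synchronized on the same lock, like the patched listeners.
        public void handleResponse() {
            synchronized (this) {
                sessionPartitions.clear();
                sessionPartitions.put("tp-0", 1L);
            }
        }

        public static void main(String[] args) throws InterruptedException {
            MiniFetcher fetcher = new MiniFetcher();
            Thread heartbeat = new Thread(() -> {
                for (int i = 0; i < 100_000; i++)
                    fetcher.handleResponse();
            });
            heartbeat.start();
            // Drop "synchronized" from sendFetches() to reproduce the bug.
            for (int i = 0; i < 100_000; i++)
                fetcher.sendFetches();
            heartbeat.join();
            System.out.println("no ConcurrentModificationException");
        }
    }

The patch below applies exactly this pattern: sendFetches() becomes
synchronized, and both response listeners wrap their bodies in
synchronized (Fetcher.this), so the two threads never touch the fetch
session state concurrently.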

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
index d1ec117..af18d09 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
@@ -91,7 +91,21 @@ import static java.util.Collections.emptyList;
 import static org.apache.kafka.common.serialization.ExtendedDeserializer.Wrapper.ensureExtended;
 
 /**
- * This class manage the fetching process with the brokers.
+ * This class manages the fetching process with the brokers.
+ * <p>
+ * Thread-safety:
+ * Requests and responses of Fetcher may be processed by different threads, since the
+ * heartbeat thread may process responses. Other operations are single-threaded and invoked
+ * only from the thread polling the consumer.
+ * <ul>
+ *     <li>If a response handler accesses any shared state of the Fetcher (e.g. FetchSessionHandler),
+ *     all access to that state must be synchronized on the Fetcher instance.</li>
+ *     <li>If a response handler accesses any shared state of the coordinator (e.g. SubscriptionState),
+ *     it is assumed that all access to that state is synchronized on the coordinator instance by
+ *     the caller.</li>
+ *     <li>Responses that collate partial responses from multiple brokers (e.g. to list offsets) are
+ *     synchronized on the response future.</li>
+ * </ul>
  */
 public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
     private final Logger log;
@@ -187,7 +201,7 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
      * an in-flight fetch or pending fetch data.
      * @return number of fetches sent
      */
-    public int sendFetches() {
+    public synchronized int sendFetches() {
         Map<Node, FetchSessionHandler.FetchRequestData> fetchRequestMap = prepareFetchRequests();
         for (Map.Entry<Node, FetchSessionHandler.FetchRequestData> entry : fetchRequestMap.entrySet()) {
             final Node fetchTarget = entry.getKey();
@@ -205,39 +219,43 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
                     .addListener(new RequestFutureListener<ClientResponse>() {
                         @Override
                         public void onSuccess(ClientResponse resp) {
-                            FetchResponse<Records> response = (FetchResponse<Records>) resp.responseBody();
-                            FetchSessionHandler handler = sessionHandlers.get(fetchTarget.id());
-                            if (handler == null) {
-                                log.error("Unable to find FetchSessionHandler for node {}. Ignoring fetch response.",
-                                    fetchTarget.id());
-                                return;
+                            synchronized (Fetcher.this) {
+                                FetchResponse<Records> response = (FetchResponse<Records>) resp.responseBody();
+                                FetchSessionHandler handler = sessionHandler(fetchTarget.id());
+                                if (handler == null) {
+                                    log.error("Unable to find FetchSessionHandler for node {}. Ignoring fetch response.",
+                                            fetchTarget.id());
+                                    return;
+                                }
+                                if (!handler.handleResponse(response)) {
+                                    return;
+                                }
+
+                                Set<TopicPartition> partitions = new HashSet<>(response.responseData().keySet());
+                                FetchResponseMetricAggregator metricAggregator = new FetchResponseMetricAggregator(sensors, partitions);
+
+                                for (Map.Entry<TopicPartition, FetchResponse.PartitionData<Records>> entry : response.responseData().entrySet()) {
+                                    TopicPartition partition = entry.getKey();
+                                    long fetchOffset = data.sessionPartitions().get(partition).fetchOffset;
+                                    FetchResponse.PartitionData fetchData = entry.getValue();
+
+                                    log.debug("Fetch {} at offset {} for partition {} returned fetch data {}",
+                                            isolationLevel, fetchOffset, partition, fetchData);
+                                    completedFetches.add(new CompletedFetch(partition, fetchOffset, fetchData, metricAggregator,
+                                            resp.requestHeader().apiVersion()));
+                                }
+
+                                sensors.fetchLatency.record(resp.requestLatencyMs());
                             }
-                            if (!handler.handleResponse(response)) {
-                                return;
-                            }
-
-                            Set<TopicPartition> partitions = new HashSet<>(response.responseData().keySet());
-                            FetchResponseMetricAggregator metricAggregator = new FetchResponseMetricAggregator(sensors, partitions);
-
-                            for (Map.Entry<TopicPartition, FetchResponse.PartitionData<Records>> entry : response.responseData().entrySet()) {
-                                TopicPartition partition = entry.getKey();
-                                long fetchOffset = data.sessionPartitions().get(partition).fetchOffset;
-                                FetchResponse.PartitionData fetchData = entry.getValue();
-
-                                log.debug("Fetch {} at offset {} for partition {} returned fetch data {}",
-                                        isolationLevel, fetchOffset, partition, fetchData);
-                                completedFetches.add(new CompletedFetch(partition, fetchOffset, fetchData, metricAggregator,
-                                        resp.requestHeader().apiVersion()));
-                            }
-
-                            sensors.fetchLatency.record(resp.requestLatencyMs());
                         }
 
                         @Override
                         public void onFailure(RuntimeException e) {
-                            FetchSessionHandler handler = sessionHandlers.get(fetchTarget.id());
-                            if (handler != null) {
-                                handler.handleError(e);
+                            synchronized (Fetcher.this) {
+                                FetchSessionHandler handler = sessionHandler(fetchTarget.id());
+                                if (handler != null) {
+                                    handler.handleError(e);
+                                }
                             }
                         }
                     });
@@ -880,7 +898,7 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
                 // if there is a leader and no in-flight requests, issue a new fetch
                 FetchSessionHandler.Builder builder = fetchable.get(node);
                 if (builder == null) {
-                    FetchSessionHandler handler = sessionHandlers.get(node.id());
+                    FetchSessionHandler handler = sessionHandler(node.id());
                     if (handler == null) {
                         handler = new FetchSessionHandler(logContext, node.id());
                         sessionHandlers.put(node.id(), handler);
@@ -1037,6 +1055,11 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
         sensors.updatePartitionLagAndLeadSensors(assignment);
     }
 
+    // Visibility for testing
+    protected FetchSessionHandler sessionHandler(int node) {
+        return sessionHandlers.get(node);
+    }
+
     public static Sensor throttleTimeSensor(Metrics metrics, FetcherMetricsRegistry metricsRegistry) {
         Sensor fetchThrottleTimeSensor = metrics.sensor("fetch-throttle-time");
         fetchThrottleTimeSensor.add(metrics.metricInstance(metricsRegistry.fetchThrottleTimeAvg), new Avg());
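
The new Javadoc above also notes that responses collating partial results
from multiple brokers (e.g. listing offsets) synchronize on the response
future. A hedged sketch of that rule, with illustrative names
(PartialCollector is not a real Kafka class):

    import java.util.HashMap;
    import java.util.Map;
    import java.util.concurrent.CompletableFuture;

    public class PartialCollector {
        private final Map<String, Long> collected = new HashMap<>();
        private final CompletableFuture<Map<String, Long>> future = new CompletableFuture<>();
        private int remaining;

        public PartialCollector(int brokerCount) {
            this.remaining = brokerCount;
        }

        // Invoked from network threads as each broker's partial response
        // arrives; all mutation of the shared map is guarded by the future.
        public void onBrokerResponse(Map<String, Long> partial) {
            synchronized (future) {
                collected.putAll(partial);
                if (--remaining == 0)
                    future.complete(collected);
            }
        }

        public CompletableFuture<Map<String, Long>> future() {
            return future;
        }
    }
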
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
index b67d48e..7f550d3 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
@@ -19,6 +19,7 @@ package org.apache.kafka.clients.consumer.internals;
 import org.apache.kafka.clients.ApiVersions;
 import org.apache.kafka.clients.ClientRequest;
 import org.apache.kafka.clients.ClientUtils;
+import org.apache.kafka.clients.FetchSessionHandler;
 import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.MockClient;
 import org.apache.kafka.clients.NetworkClient;
@@ -89,6 +90,7 @@ import org.junit.Before;
 import org.junit.Test;
 
 import java.io.DataOutputStream;
+import java.lang.reflect.Field;
 import java.nio.ByteBuffer;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
@@ -101,6 +103,13 @@ import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Function;
+import java.util.stream.Collectors;
 
 import static java.util.Collections.singleton;
 import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID;
@@ -112,6 +121,7 @@ import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
+
 @SuppressWarnings("deprecation")
 public class FetcherTest {
     private ConsumerRebalanceListener listener = new NoOpConsumerRebalanceListener();
@@ -149,38 +159,30 @@ public class FetcherTest {
     private Fetcher<byte[], byte[]> fetcher = createFetcher(subscriptions, metrics);
     private Metrics fetcherMetrics = new Metrics(time);
     private Fetcher<byte[], byte[]> fetcherNoAutoReset = createFetcher(subscriptionsNoAutoReset, fetcherMetrics);
+    private ExecutorService executorService;
 
     @Before
     public void setup() throws Exception {
         metadata.update(cluster, Collections.<String>emptySet(), time.milliseconds());
         client.setNode(node);
 
-        MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE, TimestampType.CREATE_TIME, 1L);
-        builder.append(0L, "key".getBytes(), "value-1".getBytes());
-        builder.append(0L, "key".getBytes(), "value-2".getBytes());
-        builder.append(0L, "key".getBytes(), "value-3".getBytes());
-        records = builder.build();
-
-        builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE, TimestampType.CREATE_TIME, 4L);
-        builder.append(0L, "key".getBytes(), "value-4".getBytes());
-        builder.append(0L, "key".getBytes(), "value-5".getBytes());
-        nextRecords = builder.build();
-
-        builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE, TimestampType.CREATE_TIME, 0L);
-        emptyRecords = builder.build();
-
-        builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE, TimestampType.CREATE_TIME, 4L);
-        builder.append(0L, "key".getBytes(), "value-0".getBytes());
-        partialRecords = builder.build();
+        records = buildRecords(1L, 3, 1);
+        nextRecords = buildRecords(4L, 2, 4);
+        emptyRecords = buildRecords(0L, 0, 0);
+        partialRecords = buildRecords(4L, 1, 0);
         partialRecords.buffer().putInt(Records.SIZE_OFFSET, 10000);
     }
 
     @After
-    public void teardown() {
+    public void teardown() throws Exception {
         this.metrics.close();
         this.fetcherMetrics.close();
         this.fetcher.close();
         this.fetcherMetrics.close();
+        if (executorService != null) {
+            executorService.shutdownNow();
+            assertTrue(executorService.awaitTermination(5, TimeUnit.SECONDS));
+        }
     }
 
     @Test
@@ -2408,6 +2410,141 @@ public class FetcherTest {
         assertEquals(5, records.get(1).offset());
     }
 
+    @Test
+    public void testFetcherConcurrency() throws Exception {
+        int numPartitions = 20;
+        Set<TopicPartition> topicPartitions = new HashSet<>();
+        for (int i = 0; i < numPartitions; i++)
+            topicPartitions.add(new TopicPartition(topicName, i));
+        cluster = TestUtils.singletonCluster(topicName, numPartitions);
+        metadata.update(cluster, Collections.emptySet(), time.milliseconds());
+        client.setNode(node);
+        fetchSize = 10000;
+
+        Fetcher<byte[], byte[]> fetcher = new Fetcher<byte[], byte[]>(
+                new LogContext(),
+                consumerClient,
+                minBytes,
+                maxBytes,
+                maxWaitMs,
+                fetchSize,
+                2 * numPartitions,
+                true,
+                new ByteArrayDeserializer(),
+                new ByteArrayDeserializer(),
+                metadata,
+                subscriptions,
+                metrics,
+                metricsRegistry,
+                time,
+                retryBackoffMs,
+                requestTimeoutMs,
+                IsolationLevel.READ_UNCOMMITTED) {
+            @Override
+            protected FetchSessionHandler sessionHandler(int id) {
+                final FetchSessionHandler handler = super.sessionHandler(id);
+                if (handler == null)
+                    return null;
+                else {
+                    return new FetchSessionHandler(new LogContext(), id) {
+                        @Override
+                        public Builder newBuilder() {
+                            verifySessionPartitions();
+                            return handler.newBuilder();
+                        }
+
+                        @Override
+                        public boolean handleResponse(FetchResponse response) {
+                            verifySessionPartitions();
+                            return handler.handleResponse(response);
+                        }
+
+                        @Override
+                        public void handleError(Throwable t) {
+                            verifySessionPartitions();
+                            handler.handleError(t);
+                        }
+
+                        // Verify that session partitions can be traversed safely.
+                        private void verifySessionPartitions() {
+                            try {
+                                Field field = FetchSessionHandler.class.getDeclaredField("sessionPartitions");
+                                field.setAccessible(true);
+                                LinkedHashMap<TopicPartition, FetchRequest.PartitionData> sessionPartitions =
+                                        (LinkedHashMap<TopicPartition, FetchRequest.PartitionData>) field.get(handler);
+                                for (Map.Entry<TopicPartition, FetchRequest.PartitionData> entry : sessionPartitions.entrySet()) {
+                                    // If `sessionPartitions` are modified on another thread, Thread.yield will increase the
+                                    // possibility of ConcurrentModificationException if appropriate synchronization is not used.
+                                    Thread.yield();
+                                }
+                            } catch (Exception e) {
+                                throw new RuntimeException(e);
+                            }
+                        }
+                    };
+                }
+            }
+        };
+
+        subscriptions.assignFromUser(topicPartitions);
+        topicPartitions.forEach(tp -> subscriptions.seek(tp, 0L));
+
+        AtomicInteger fetchesRemaining = new AtomicInteger(1000);
+        executorService = Executors.newSingleThreadExecutor();
+        Future<?> future = executorService.submit(() -> {
+            while (fetchesRemaining.get() > 0) {
+                synchronized (consumerClient) {
+                    if (!client.requests().isEmpty()) {
+                        ClientRequest request = client.requests().peek();
+                        FetchRequest fetchRequest = (FetchRequest) request.requestBuilder().build();
+                        LinkedHashMap<TopicPartition, FetchResponse.PartitionData<MemoryRecords>> responseMap = new LinkedHashMap<>();
+                        for (Map.Entry<TopicPartition, FetchRequest.PartitionData> entry : fetchRequest.fetchData().entrySet()) {
+                            TopicPartition tp = entry.getKey();
+                            long offset = entry.getValue().fetchOffset;
+                            responseMap.put(tp, new FetchResponse.PartitionData<>(Errors.NONE, offset + 2L, offset + 2,
+                                    0L, null, buildRecords(offset, 2, offset)));
+                        }
+                        client.respondToRequest(request, new FetchResponse<>(Errors.NONE, responseMap, 0, 123));
+                        consumerClient.poll(0);
+                    }
+                }
+            }
+            return fetchesRemaining.get();
+        });
+        Map<TopicPartition, Long> nextFetchOffsets = topicPartitions.stream()
+                .collect(Collectors.toMap(Function.identity(), t -> 0L));
+        while (fetchesRemaining.get() > 0 && !future.isDone()) {
+            if (fetcher.sendFetches() == 1) {
+                synchronized (consumerClient) {
+                    consumerClient.poll(0);
+                }
+            }
+            if (fetcher.hasCompletedFetches()) {
+                Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
+                if (!fetchedRecords.isEmpty()) {
+                    fetchesRemaining.decrementAndGet();
+                    fetchedRecords.entrySet().forEach(entry -> {
+                        TopicPartition tp = entry.getKey();
+                        List<ConsumerRecord<byte[], byte[]>> records = entry.getValue();
+                        assertEquals(2, records.size());
+                        long nextOffset = nextFetchOffsets.get(tp);
+                        assertEquals(nextOffset, records.get(0).offset());
+                        assertEquals(nextOffset + 1, records.get(1).offset());
+                        nextFetchOffsets.put(tp, nextOffset + 2);
+                    });
+                }
+            }
+        }
+        assertEquals(0, future.get());
+    }
+
+    private MemoryRecords buildRecords(long baseOffset, int count, long firstMessageId) {
+        MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE, TimestampType.CREATE_TIME, baseOffset);
+        for (int i = 0; i < count; i++)
+            builder.append(0L, "key".getBytes(), ("value-" + (firstMessageId + i)).getBytes());
+        return builder.build();
+    }
+
     private int appendTransactionalRecords(ByteBuffer buffer, long pid, long baseOffset, int baseSequence, SimpleRecord... records) {
         MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE,
                 TimestampType.CREATE_TIME, baseOffset, time.milliseconds(), pid, (short) 0, baseSequence, true,