Posted to commits@beam.apache.org by bo...@apache.org on 2021/01/21 18:24:51 UTC

[beam] branch master updated: [BEAM-11325] ReadFromKafkaDoFn should stop reading when topic/partition is removed or marked as stopped.

This is an automated email from the ASF dual-hosted git repository.

boyuanz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new e8eed0b  [BEAM-11325] ReadFromKafkaDoFn should stop reading when topic/partition is removed or marked as stopped.
     new d2e1f69  Merge pull request #13710 from [BEAM-11325] ReadFromKafkaDoFn should stop reading when topic/partition is removed or marked as stopped
e8eed0b is described below

commit e8eed0bf70c334fe59327f0d70453302935410ee
Author: Boyuan Zhang <bo...@google.com>
AuthorDate: Fri Jan 8 13:43:06 2021 -0800

    [BEAM-11325] ReadFromKafkaDoFn should stop reading when topic/partition is removed or marked as stopped.
---
 .../java/org/apache/beam/sdk/io/kafka/KafkaIO.java |  36 ++-
 .../beam/sdk/io/kafka/ReadFromKafkaDoFn.java       |  43 ++-
 .../beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java   | 344 +++++++++++++++++++++
 3 files changed, 421 insertions(+), 2 deletions(-)
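
For context, here is a minimal, illustrative sketch of a pipeline that opts into the new stop-reading signal through KafkaIO.Read. Only withCheckStopReadingFn comes from this commit; the topic, bootstrap servers, the StopReadingExample class name, and the stoppedPartitions deny-list are hypothetical placeholders.

    import java.util.Collections;
    import java.util.Set;
    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.io.kafka.KafkaIO;
    import org.apache.beam.sdk.values.KV;
    import org.apache.beam.sdk.values.PCollection;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.serialization.LongDeserializer;
    import org.apache.kafka.common.serialization.StringDeserializer;

    public class StopReadingExample {
      public static void main(String[] args) {
        Pipeline pipeline = Pipeline.create();

        // Hypothetical deny-list; a real pipeline would derive this from its own control logic.
        Set<TopicPartition> stoppedPartitions =
            Collections.singleton(new TopicPartition("my_topic", 1));

        PCollection<KV<Long, String>> records =
            pipeline.apply(
                KafkaIO.<Long, String>read()
                    .withBootstrapServers("broker_1:9092,broker_2:9092")
                    .withTopic("my_topic")
                    .withKeyDeserializer(LongDeserializer.class)
                    .withValueDeserializer(StringDeserializer.class)
                    // New in this commit: reading from a TopicPartition stops once the
                    // function returns true for it; only ReadFromKafkaDoFn respects the signal.
                    .withCheckStopReadingFn(
                        topicPartition -> stoppedPartitions.contains(topicPartition))
                    .withoutMetadata());

        pipeline.run();
      }
    }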

diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
index 60608e0..93759e6 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
@@ -141,6 +141,11 @@ import org.slf4j.LoggerFactory;
  *      // offset consumed by the pipeline can be committed back.
  *      .commitOffsetsInFinalize()
  *
+ *      // Specify a serializable function that determines whether to stop reading from the given
+ *      // TopicPartition at runtime. Note that only {@link ReadFromKafkaDoFn} respects the
+ *      // signal.
+ *      .withCheckStopReadingFn(new SerializableFunction<TopicPartition, Boolean>() {})
+ *
 *      // finally, if you don't need Kafka metadata, you can drop it.
  *      .withoutMetadata() // PCollection<KV<Long, String>>
  *   )
@@ -514,6 +519,8 @@ public class KafkaIO {
 
     abstract @Nullable DeserializerProvider getValueDeserializerProvider();
 
+    abstract @Nullable SerializableFunction<TopicPartition, Boolean> getCheckStopReadingFn();
+
     abstract Builder<K, V> toBuilder();
 
     @Experimental(Kind.PORTABILITY)
@@ -553,6 +560,9 @@ public class KafkaIO {
       abstract Builder<K, V> setValueDeserializerProvider(
           DeserializerProvider deserializerProvider);
 
+      abstract Builder<K, V> setCheckStopReadingFn(
+          SerializableFunction<TopicPartition, Boolean> checkStopReadingFn);
+
       abstract Read<K, V> build();
 
       @Override
@@ -998,6 +1008,15 @@ public class KafkaIO {
       return toBuilder().setConsumerConfig(config).build();
     }
 
+    /**
+     * A custom {@link SerializableFunction} that determines whether the {@link ReadFromKafkaDoFn}
+     * should stop reading from the given {@link TopicPartition}.
+     */
+    public Read<K, V> withCheckStopReadingFn(
+        SerializableFunction<TopicPartition, Boolean> checkStopReadingFn) {
+      return toBuilder().setCheckStopReadingFn(checkStopReadingFn).build();
+    }
+
     /** Returns a {@link PTransform} for PCollection of {@link KV}, dropping Kafka metadata. */
     public PTransform<PBegin, PCollection<KV<K, V>>> withoutMetadata() {
       return new TypedWithoutMetadata<>(this);
@@ -1080,7 +1099,8 @@ public class KafkaIO {
               .withKeyDeserializerProvider(getKeyDeserializerProvider())
               .withValueDeserializerProvider(getValueDeserializerProvider())
               .withManualWatermarkEstimator()
-              .withTimestampPolicyFactory(getTimestampPolicyFactory());
+              .withTimestampPolicyFactory(getTimestampPolicyFactory())
+              .withCheckStopReadingFn(getCheckStopReadingFn());
       if (isCommitOffsetsInFinalizeEnabled()) {
         readTransform = readTransform.commitOffsets();
       }
@@ -1267,6 +1287,8 @@ public class KafkaIO {
     abstract SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>>
         getConsumerFactoryFn();
 
+    abstract @Nullable SerializableFunction<TopicPartition, Boolean> getCheckStopReadingFn();
+
     abstract @Nullable SerializableFunction<KafkaRecord<K, V>, Instant>
         getExtractOutputTimestampFn();
 
@@ -1289,6 +1311,9 @@ public class KafkaIO {
       abstract ReadSourceDescriptors.Builder<K, V> setConsumerFactoryFn(
           SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> consumerFactoryFn);
 
+      abstract ReadSourceDescriptors.Builder<K, V> setCheckStopReadingFn(
+          SerializableFunction<TopicPartition, Boolean> checkStopReadingFn);
+
       abstract ReadSourceDescriptors.Builder<K, V> setKeyDeserializerProvider(
           DeserializerProvider deserializerProvider);
 
@@ -1403,6 +1428,15 @@ public class KafkaIO {
     }
 
     /**
+     * A custom {@link SerializableFunction} that determines whether the {@link ReadFromKafkaDoFn}
+     * should stop reading from the given {@link TopicPartition}.
+     */
+    public ReadSourceDescriptors<K, V> withCheckStopReadingFn(
+        SerializableFunction<TopicPartition, Boolean> checkStopReadingFn) {
+      return toBuilder().setCheckStopReadingFn(checkStopReadingFn).build();
+    }
+
+    /**
      * Updates configuration for the main consumer. This method merges updates from the provided map
      * with any prior updates using {@link KafkaIOUtils#DEFAULT_CONSUMER_PROPERTIES} as the starting
      * configuration.
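
The same hook exists on ReadSourceDescriptors above. As a hedged sketch of what the user-provided callback might look like when written as a named class instead of an anonymous one (the class name and the deny-list are illustrative; only the SerializableFunction<TopicPartition, Boolean> contract is defined by this commit):

    import java.util.Collections;
    import java.util.Set;
    import org.apache.beam.sdk.transforms.SerializableFunction;
    import org.apache.kafka.common.TopicPartition;

    /** Returns true for partitions that the pipeline should stop reading from. */
    public class StopDeniedPartitionsFn implements SerializableFunction<TopicPartition, Boolean> {

      // Hypothetical static deny-list; a production callback would more likely consult
      // configuration that can change while the pipeline is running.
      private final Set<TopicPartition> denied =
          Collections.singleton(new TopicPartition("my_topic", 3));

      @Override
      public Boolean apply(TopicPartition topicPartition) {
        return denied.contains(topicPartition);
      }
    }

The callback is attached the same way in either entry point, e.g. KafkaIO.Read#withCheckStopReadingFn(new StopDeniedPartitionsFn()) or ReadSourceDescriptors#withCheckStopReadingFn(new StopDeniedPartitionsFn()).
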
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java
index 08a590b..d12332a 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java
@@ -20,8 +20,11 @@ package org.apache.beam.sdk.io.kafka;
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
 
 import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
 import java.util.concurrent.TimeUnit;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.io.kafka.KafkaIO.ReadSourceDescriptors;
@@ -52,6 +55,7 @@ import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.serialization.Deserializer;
 import org.joda.time.Duration;
@@ -116,6 +120,23 @@ import org.slf4j.LoggerFactory;
  * extractTimestampFn} and {@link
  * ReadSourceDescriptors#withMonotonicallyIncreasingWatermarkEstimator()} as the {@link
  * WatermarkEstimator}.
+ *
+ * <h4>Stop Reading from Removed {@link TopicPartition}</h4>
+ *
+ * {@link ReadFromKafkaDoFn} automatically stops reading from any removed {@link TopicPartition}
+ * by querying the Kafka {@link Consumer} APIs. Note that reading may not stop as soon as the
+ * {@link TopicPartition} is removed. For example, the removal could happen at the same time as
+ * {@link ReadFromKafkaDoFn} performs a {@link Consumer#poll(java.time.Duration)}. In that case,
+ * {@link ReadFromKafkaDoFn} will still output the records fetched by that poll.
+ *
+ * <h4>Stop Reading from Stopped {@link TopicPartition}</h4>
+ *
+ * {@link ReadFromKafkaDoFn} will also stop reading from a given {@link TopicPartition} if the
+ * {@link ReadFromKafkaDoFn#checkStopReadingFn} indicates that it should. {@link
+ * ReadFromKafkaDoFn#checkStopReadingFn} is a user-provided callback which is used to determine
+ * whether to stop reading from the given {@link TopicPartition}. As with the mechanism for
+ * stopping reading from a removed {@link TopicPartition}, reading may not stop immediately in
+ * this case either.
  */
 @UnboundedPerElement
 @SuppressWarnings({
@@ -134,12 +155,15 @@ class ReadFromKafkaDoFn<K, V>
     this.extractOutputTimestampFn = transform.getExtractOutputTimestampFn();
     this.createWatermarkEstimatorFn = transform.getCreateWatermarkEstimatorFn();
     this.timestampPolicyFactory = transform.getTimestampPolicyFactory();
+    this.checkStopReadingFn = transform.getCheckStopReadingFn();
   }
 
   private static final Logger LOG = LoggerFactory.getLogger(ReadFromKafkaDoFn.class);
 
   private final Map<String, Object> offsetConsumerConfig;
 
+  private final SerializableFunction<TopicPartition, Boolean> checkStopReadingFn;
+
   private final SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>>
       consumerFactoryFn;
   private final SerializableFunction<KafkaRecord<K, V>, Instant> extractOutputTimestampFn;
@@ -275,7 +299,11 @@ class ReadFromKafkaDoFn<K, V>
       RestrictionTracker<OffsetRange, Long> tracker,
       WatermarkEstimator watermarkEstimator,
       OutputReceiver<KV<KafkaSourceDescriptor, KafkaRecord<K, V>>> receiver) {
-    // If there is no future work, resume with max timeout and move to the next element.
+    // Stop processing the current TopicPartition when checkStopReadingFn says so.
+    if (checkStopReadingFn != null
+        && checkStopReadingFn.apply(kafkaSourceDescriptor.getTopicPartition())) {
+      return ProcessContinuation.stop();
+    }
     Map<String, Object> updatedConsumerConfig =
         overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor);
     // If there is a timestampPolicyFactory, create the TimestampPolicy for current
@@ -288,6 +316,19 @@ class ReadFromKafkaDoFn<K, V>
               Optional.ofNullable(watermarkEstimator.currentWatermark()));
     }
     try (Consumer<byte[], byte[]> consumer = consumerFactoryFn.apply(updatedConsumerConfig)) {
+      // Check whether the current TopicPartition is still available to read.
+      Set<TopicPartition> existingTopicPartitions = new HashSet<>();
+      for (List<PartitionInfo> topicPartitionList : consumer.listTopics().values()) {
+        topicPartitionList.forEach(
+            partitionInfo -> {
+              existingTopicPartitions.add(
+                  new TopicPartition(partitionInfo.topic(), partitionInfo.partition()));
+            });
+      }
+      if (!existingTopicPartitions.contains(kafkaSourceDescriptor.getTopicPartition())) {
+        return ProcessContinuation.stop();
+      }
+
       consumerSpEL.evaluateAssign(
           consumer, ImmutableList.of(kafkaSourceDescriptor.getTopicPartition()));
       long startOffset = tracker.currentRestriction().getFrom();
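
Restated outside the diff, the availability check added to processElement above amounts to the following helper-style sketch (the class and method names are illustrative; the Consumer#listTopics call and the TopicPartition construction mirror the hunk above):

    import java.util.HashSet;
    import java.util.List;
    import java.util.Set;
    import org.apache.kafka.clients.consumer.Consumer;
    import org.apache.kafka.common.PartitionInfo;
    import org.apache.kafka.common.TopicPartition;

    final class TopicPartitionChecks {

      /** Returns true if the broker still reports the given TopicPartition. */
      static boolean topicPartitionExists(
          Consumer<byte[], byte[]> consumer, TopicPartition topicPartition) {
        Set<TopicPartition> existing = new HashSet<>();
        // listTopics() maps each topic visible to the consumer to its partition metadata.
        for (List<PartitionInfo> partitions : consumer.listTopics().values()) {
          for (PartitionInfo partitionInfo : partitions) {
            existing.add(new TopicPartition(partitionInfo.topic(), partitionInfo.partition()));
          }
        }
        return existing.contains(topicPartition);
      }

      private TopicPartitionChecks() {}
    }

When the current TopicPartition is no longer reported, processElement returns ProcessContinuation.stop() for that source descriptor instead of polling it.
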
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java
new file mode 100644
index 0000000..62c28b2
--- /dev/null
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kafka;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.io.kafka.KafkaIO.ReadSourceDescriptors;
+import org.apache.beam.sdk.io.range.OffsetRange;
+import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
+import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Charsets;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.MockConsumer;
+import org.apache.kafka.clients.consumer.OffsetAndTimestamp;
+import org.apache.kafka.clients.consumer.OffsetResetStrategy;
+import org.apache.kafka.common.PartitionInfo;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.header.internals.RecordHeaders;
+import org.apache.kafka.common.serialization.StringDeserializer;
+import org.checkerframework.checker.initialization.qual.Initialized;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.checkerframework.checker.nullness.qual.UnknownKeyFor;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.Test;
+
+@SuppressWarnings({
+  "rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
+  "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
+})
+public class ReadFromKafkaDoFnTest {
+
+  private final TopicPartition topicPartition = new TopicPartition("topic", 0);
+
+  private final SimpleMockKafkaConsumer consumer =
+      new SimpleMockKafkaConsumer(OffsetResetStrategy.NONE, topicPartition);
+
+  private final ReadFromKafkaDoFn<String, String> dofnInstance =
+      new ReadFromKafkaDoFn(makeReadSourceDescriptor(consumer));
+
+  private ReadSourceDescriptors<String, String> makeReadSourceDescriptor(
+      Consumer kafkaMockConsumer) {
+    return ReadSourceDescriptors.<String, String>read()
+        .withKeyDeserializer(StringDeserializer.class)
+        .withValueDeserializer(StringDeserializer.class)
+        .withConsumerFactoryFn(
+            new SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>>() {
+              @Override
+              public Consumer<byte[], byte[]> apply(Map<String, Object> input) {
+                return kafkaMockConsumer;
+              }
+            })
+        .withBootstrapServers("bootstrap_server");
+  }
+
+  private static class SimpleMockKafkaConsumer extends MockConsumer<byte[], byte[]> {
+
+    private final TopicPartition topicPartition;
+    private boolean isRemoved = false;
+    private long currentPos = 0L;
+    private long startOffset = 0L;
+    private long startOffsetForTime = 0L;
+    private long numOfRecordsPerPoll;
+
+    public SimpleMockKafkaConsumer(
+        OffsetResetStrategy offsetResetStrategy, TopicPartition topicPartition) {
+      super(offsetResetStrategy);
+      this.topicPartition = topicPartition;
+    }
+
+    public void reset() {
+      this.isRemoved = false;
+      this.currentPos = 0L;
+      this.startOffset = 0L;
+      this.startOffsetForTime = 0L;
+      this.numOfRecordsPerPoll = 0L;
+    }
+
+    public void setRemoved() {
+      this.isRemoved = true;
+    }
+
+    public void setNumOfRecordsPerPoll(long num) {
+      this.numOfRecordsPerPoll = num;
+    }
+
+    public void setCurrentPos(long pos) {
+      this.currentPos = pos;
+    }
+
+    public void setStartOffsetForTime(long pos) {
+      this.startOffsetForTime = pos;
+    }
+
+    @Override
+    public synchronized Map<String, List<PartitionInfo>> listTopics() {
+      if (this.isRemoved) {
+        return ImmutableMap.of();
+      }
+      return ImmutableMap.of(
+          topicPartition.topic(),
+          ImmutableList.of(
+              new PartitionInfo(
+                  topicPartition.topic(), topicPartition.partition(), null, null, null)));
+    }
+
+    @Override
+    public synchronized void assign(Collection<TopicPartition> partitions) {
+      assertTrue(Iterables.getOnlyElement(partitions).equals(this.topicPartition));
+    }
+
+    @Override
+    public synchronized void seek(TopicPartition partition, long offset) {
+      assertTrue(partition.equals(this.topicPartition));
+      this.startOffset = offset;
+    }
+
+    @Override
+    public synchronized ConsumerRecords<byte[], byte[]> poll(long timeout) {
+      if (topicPartition == null) {
+        return ConsumerRecords.empty();
+      }
+      String key = "key";
+      String value = "value";
+      List<ConsumerRecord<byte[], byte[]>> records = new ArrayList<>();
+      for (long i = 0; i <= numOfRecordsPerPoll; i++) {
+        records.add(
+            new ConsumerRecord<byte[], byte[]>(
+                topicPartition.topic(),
+                topicPartition.partition(),
+                startOffset + i,
+                key.getBytes(Charsets.UTF_8),
+                value.getBytes(Charsets.UTF_8)));
+      }
+      if (records.isEmpty()) {
+        return ConsumerRecords.empty();
+      }
+      return new ConsumerRecords(ImmutableMap.of(topicPartition, records));
+    }
+
+    @Override
+    public synchronized Map<TopicPartition, OffsetAndTimestamp> offsetsForTimes(
+        Map<TopicPartition, Long> timestampsToSearch) {
+      assertTrue(
+          Iterables.getOnlyElement(
+                  timestampsToSearch.keySet().stream().collect(Collectors.toList()))
+              .equals(this.topicPartition));
+      return ImmutableMap.of(
+          topicPartition,
+          new OffsetAndTimestamp(
+              this.startOffsetForTime, Iterables.getOnlyElement(timestampsToSearch.values())));
+    }
+
+    @Override
+    public synchronized long position(TopicPartition partition) {
+      assertTrue(partition.equals(this.topicPartition));
+      return this.currentPos;
+    }
+  }
+
+  private static class MockOutputReceiver
+      implements OutputReceiver<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> {
+
+    private final List<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> records =
+        new ArrayList<>();
+
+    @Override
+    public void output(KV<KafkaSourceDescriptor, KafkaRecord<String, String>> output) {}
+
+    @Override
+    public void outputWithTimestamp(
+        KV<KafkaSourceDescriptor, KafkaRecord<String, String>> output,
+        @UnknownKeyFor @NonNull @Initialized Instant timestamp) {
+      records.add(output);
+    }
+
+    public List<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> getOutputs() {
+      return this.records;
+    }
+  }
+
+  private List<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> createExpectedRecords(
+      KafkaSourceDescriptor descriptor,
+      long startOffset,
+      int numRecords,
+      String key,
+      String value) {
+    List<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> records = new ArrayList<>();
+    for (int i = 0; i < numRecords; i++) {
+      records.add(
+          KV.of(
+              descriptor,
+              new KafkaRecord<String, String>(
+                  topicPartition.topic(),
+                  topicPartition.partition(),
+                  startOffset + i,
+                  -1L,
+                  KafkaTimestampType.NO_TIMESTAMP_TYPE,
+                  new RecordHeaders(),
+                  KV.of(key, value))));
+    }
+    return records;
+  }
+
+  @Before
+  public void setUp() throws Exception {
+    dofnInstance.setup();
+    consumer.reset();
+  }
+
+  @Test
+  public void testInitialRestrictionWhenHasStartOffset() throws Exception {
+    long expectedStartOffset = 10L;
+    consumer.setStartOffsetForTime(15L);
+    consumer.setCurrentPos(5L);
+    OffsetRange result =
+        dofnInstance.initialRestriction(
+            KafkaSourceDescriptor.of(
+                topicPartition, expectedStartOffset, Instant.now(), ImmutableList.of()));
+    assertEquals(new OffsetRange(expectedStartOffset, Long.MAX_VALUE), result);
+  }
+
+  @Test
+  public void testInitialRestrictionWhenHasStartTime() throws Exception {
+    long expectedStartOffset = 10L;
+    consumer.setStartOffsetForTime(expectedStartOffset);
+    consumer.setCurrentPos(5L);
+    OffsetRange result =
+        dofnInstance.initialRestriction(
+            KafkaSourceDescriptor.of(topicPartition, null, Instant.now(), ImmutableList.of()));
+    assertEquals(new OffsetRange(expectedStartOffset, Long.MAX_VALUE), result);
+  }
+
+  @Test
+  public void testInitialRestrictionWithConsumerPosition() throws Exception {
+    long expectedStartOffset = 5L;
+    consumer.setCurrentPos(5L);
+    OffsetRange result =
+        dofnInstance.initialRestriction(
+            KafkaSourceDescriptor.of(topicPartition, null, null, ImmutableList.of()));
+    assertEquals(new OffsetRange(expectedStartOffset, Long.MAX_VALUE), result);
+  }
+
+  @Test
+  public void testProcessElement() throws Exception {
+    MockOutputReceiver receiver = new MockOutputReceiver();
+    consumer.setNumOfRecordsPerPoll(3L);
+    long startOffset = 5L;
+    OffsetRangeTracker tracker =
+        new OffsetRangeTracker(new OffsetRange(startOffset, startOffset + 3));
+    KafkaSourceDescriptor descriptor = KafkaSourceDescriptor.of(topicPartition, null, null, null);
+    ProcessContinuation result =
+        dofnInstance.processElement(descriptor, tracker, null, (OutputReceiver) receiver);
+    assertEquals(ProcessContinuation.stop(), result);
+    assertEquals(
+        createExpectedRecords(descriptor, startOffset, 3, "key", "value"), receiver.getOutputs());
+  }
+
+  @Test
+  public void testProcessElementWithEmptyPoll() throws Exception {
+    MockOutputReceiver receiver = new MockOutputReceiver();
+    consumer.setNumOfRecordsPerPoll(-1);
+    OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, Long.MAX_VALUE));
+    ProcessContinuation result =
+        dofnInstance.processElement(
+            KafkaSourceDescriptor.of(topicPartition, null, null, null),
+            tracker,
+            null,
+            (OutputReceiver) receiver);
+    assertEquals(ProcessContinuation.resume(), result);
+    assertTrue(receiver.getOutputs().isEmpty());
+  }
+
+  @Test
+  public void testProcessElementWhenTopicPartitionIsRemoved() throws Exception {
+    MockOutputReceiver receiver = new MockOutputReceiver();
+    consumer.setRemoved();
+    consumer.setNumOfRecordsPerPoll(10);
+    OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, Long.MAX_VALUE));
+    ProcessContinuation result =
+        dofnInstance.processElement(
+            KafkaSourceDescriptor.of(topicPartition, null, null, null),
+            tracker,
+            null,
+            (OutputReceiver) receiver);
+    assertEquals(ProcessContinuation.stop(), result);
+  }
+
+  @Test
+  public void testProcessElementWhenTopicPartitionIsStopped() throws Exception {
+    MockOutputReceiver receiver = new MockOutputReceiver();
+    ReadFromKafkaDoFn<String, String> instance =
+        new ReadFromKafkaDoFn(
+            makeReadSourceDescriptor(consumer)
+                .toBuilder()
+                .setCheckStopReadingFn(
+                    new SerializableFunction<TopicPartition, Boolean>() {
+                      @Override
+                      public Boolean apply(TopicPartition input) {
+                        assertTrue(input.equals(topicPartition));
+                        return true;
+                      }
+                    })
+                .build());
+    consumer.setNumOfRecordsPerPoll(10);
+    OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, Long.MAX_VALUE));
+    ProcessContinuation result =
+        instance.processElement(
+            KafkaSourceDescriptor.of(topicPartition, null, null, null),
+            tracker,
+            null,
+            (OutputReceiver) receiver);
+    assertEquals(ProcessContinuation.stop(), result);
+  }
+}