You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pulsar.apache.org by si...@apache.org on 2018/09/11 19:11:45 UTC

[incubator-pulsar] branch master updated: [ecosystem] Flink pulsar source connector (#2555)

This is an automated email from the ASF dual-hosted git repository.

sijie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-pulsar.git


The following commit(s) were added to refs/heads/master by this push:
     new fc61056  [ecosystem] Flink pulsar source connector (#2555)
fc61056 is described below

commit fc61056ccab29bb068b65d71120364d14503f69f
Author: Sijie Guo <gu...@gmail.com>
AuthorDate: Tue Sep 11 12:11:41 2018 -0700

    [ecosystem] Flink pulsar source connector (#2555)
    
    *Motivation*
    
    This ports apache/flink#6200 to apache pulsar repo. It adds a pulsar source connector
    which will enable flink jobs to process messages from pulsar topics.
    
    *Changes*
    
    Add a PulsarConsumerSource connector
---
 .../connectors/pulsar/PulsarConsumerSource.java    | 204 ++++++++
 .../connectors/pulsar/PulsarSourceBase.java        |  31 ++
 .../connectors/pulsar/PulsarSourceBuilder.java     | 118 +++++
 .../pulsar/PulsarConsumerSourceTests.java          | 524 +++++++++++++++++++++
 4 files changed, 877 insertions(+)

diff --git a/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarConsumerSource.java b/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarConsumerSource.java
new file mode 100644
index 0000000..f1b2595
--- /dev/null
+++ b/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarConsumerSource.java
@@ -0,0 +1,204 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.flink.streaming.connectors.pulsar;
+
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.common.serialization.DeserializationSchema;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.source.MessageAcknowledgingSourceBase;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.util.IOUtils;
+
+import org.apache.pulsar.client.api.Consumer;
+import org.apache.pulsar.client.api.Message;
+import org.apache.pulsar.client.api.MessageId;
+import org.apache.pulsar.client.api.PulsarClient;
+import org.apache.pulsar.client.api.PulsarClientException;
+import org.apache.pulsar.client.api.SubscriptionType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Pulsar source (consumer) which receives messages from a topic and acknowledges messages.
+ * When checkpointing is enabled, it guarantees at least once processing semantics.
+ *
+ * <p>When checkpointing is disabled, it auto acknowledges messages based on the number of messages it has
+ * received. In this mode messages may be dropped.
+ */
+class PulsarConsumerSource<T> extends MessageAcknowledgingSourceBase<T, MessageId> implements PulsarSourceBase<T> {
+
+    private static final Logger LOG = LoggerFactory.getLogger(PulsarConsumerSource.class);
+
+    private final int messageReceiveTimeoutMs = 100;
+    private final String serviceUrl;
+    private final String topic;
+    private final String subscriptionName;
+    private final DeserializationSchema<T> deserializer;
+
+    // Runtime-only state assigned in open(); transient so it is never part of the serialized function.
+    private transient PulsarClient client;
+    private transient Consumer<byte[]> consumer;
+    private transient boolean isCheckpointingEnabled;
+
+    private final long acknowledgementBatchSize;
+    private long batchCount;
+    private long totalMessageCount;
+
+    private transient volatile boolean isRunning;
+
+    PulsarConsumerSource(PulsarSourceBuilder<T> builder) {
+        super(MessageId.class);
+        this.serviceUrl = builder.serviceUrl;
+        this.topic = builder.topic;
+        this.deserializer = builder.deserializationSchema;
+        this.subscriptionName = builder.subscriptionName;
+        this.acknowledgementBatchSize = builder.acknowledgementBatchSize;
+    }
+
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+
+        final RuntimeContext context = getRuntimeContext();
+        if (context instanceof StreamingRuntimeContext) {
+            isCheckpointingEnabled = ((StreamingRuntimeContext) context).isCheckpointingEnabled();
+        }
+
+        client = createClient();
+        consumer = createConsumer(client);
+
+        isRunning = true;
+    }
+
+    @Override
+    protected void acknowledgeIDs(long checkpointId, Set<MessageId> messageIds) {
+        if (consumer == null) {
+            LOG.error("null consumer unable to acknowledge messages");
+            throw new RuntimeException("null pulsar consumer unable to acknowledge messages");
+        }
+        if (messageIds.isEmpty()) {
+            LOG.info("no message ids to acknowledge");
+            return;
+        }
+        Map<String, CompletableFuture<Void>> futures = new HashMap<>(messageIds.size()); // send all acks first
+        for (MessageId id : messageIds) {
+            futures.put(id.toString(), consumer.acknowledgeAsync(id));
+        }
+        futures.forEach((k, f) -> {
+            try {
+                f.get(); // then wait for each acknowledgement to complete
+            } catch (Exception e) {
+                if (e instanceof InterruptedException) {
+                    Thread.currentThread().interrupt(); // get() cleared the flag; restore it for the caller
+                }
+                LOG.error("failed to acknowledge messageId " + k, e);
+                throw new RuntimeException("Messages could not be acknowledged during checkpoint creation.", e);
+            }
+        });
+    }
+
+    @Override
+    public void run(SourceContext<T> context) throws Exception {
+        while (isRunning) {
+            // the timed receive lets the loop re-check isRunning; it yields null when the timeout elapses
+            Message message = consumer.receive(messageReceiveTimeoutMs, TimeUnit.MILLISECONDS);
+            if (message == null) {
+                LOG.debug("no message received within {} ms", messageReceiveTimeoutMs);
+                continue;
+            }
+
+            if (isCheckpointingEnabled) {
+                emitCheckpointing(context, message);
+            } else {
+                emitAutoAcking(context, message);
+            }
+        }
+    }
+
+    private void emitCheckpointing(SourceContext<T> context, Message message) throws IOException {
+        // record the id and emit the element under the checkpoint lock; a false result from addId
+        // means the id was seen before, so the message is skipped instead of being emitted again
+        synchronized (context.getCheckpointLock()) {
+            if (!addId(message.getMessageId())) {
+                LOG.debug("messageId={} already processed.", message.getMessageId());
+                return;
+            }
+            context.collect(deserialize(message));
+            totalMessageCount++;
+        }
+    }
+
+    private void emitAutoAcking(SourceContext<T> context, Message message) throws IOException {
+        context.collect(deserialize(message));
+        batchCount++;
+        totalMessageCount++;
+        if (batchCount >= acknowledgementBatchSize) {
+            LOG.info("processed {} messages acknowledging messageId {}", batchCount, message.getMessageId());
+            consumer.acknowledgeCumulative(message.getMessageId());
+            batchCount = 0;
+        }
+    }
+
+    private T deserialize(Message message) throws IOException {
+        return deserializer.deserialize(message.getData());
+    }
+
+    @Override
+    public void cancel() {
+        isRunning = false;
+    }
+
+    @Override
+    public void close() throws Exception {
+        super.close();
+        IOUtils.cleanup(LOG, consumer);
+        IOUtils.cleanup(LOG, client);
+    }
+
+    @Override
+    public TypeInformation<T> getProducedType() {
+        return deserializer.getProducedType();
+    }
+
+    boolean isCheckpointingEnabled() {
+        return isCheckpointingEnabled;
+    }
+
+    PulsarClient createClient() throws PulsarClientException {
+        return PulsarClient.builder()
+            .serviceUrl(serviceUrl)
+            .build();
+    }
+
+    Consumer<byte[]> createConsumer(PulsarClient client) throws PulsarClientException {
+        return client.newConsumer()
+            .topic(topic)
+            .subscriptionName(subscriptionName)
+            .subscriptionType(SubscriptionType.Failover)
+            .subscribe();
+    }
+}
diff --git a/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarSourceBase.java b/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarSourceBase.java
new file mode 100644
index 0000000..9d44215
--- /dev/null
+++ b/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarSourceBase.java
@@ -0,0 +1,31 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.flink.streaming.connectors.pulsar;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
+import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
+
+/**
+ * Base interface for pulsar sources.
+ * @param <T> the type of the elements produced by the source
+ */
+@PublicEvolving
+interface PulsarSourceBase<T> extends ParallelSourceFunction<T>, ResultTypeQueryable<T> {
+}
diff --git a/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarSourceBuilder.java b/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarSourceBuilder.java
new file mode 100644
index 0000000..7f1ee9c
--- /dev/null
+++ b/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarSourceBuilder.java
@@ -0,0 +1,118 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.flink.streaming.connectors.pulsar;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.serialization.DeserializationSchema;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * A class for building a pulsar source.
+ */
+@PublicEvolving
+public class PulsarSourceBuilder<T> {
+
+    static final String SERVICE_URL = "pulsar://localhost:6650";
+    static final long ACKNOWLEDGEMENT_BATCH_SIZE = 100;
+    static final long MAX_ACKNOWLEDGEMENT_BATCH_SIZE = 1000;
+
+    final DeserializationSchema<T> deserializationSchema;
+    String serviceUrl = SERVICE_URL;
+    String topic;
+    String subscriptionName = "flink-sub";
+    long acknowledgementBatchSize = ACKNOWLEDGEMENT_BATCH_SIZE;
+
+    private PulsarSourceBuilder(DeserializationSchema<T> deserializationSchema) {
+        this.deserializationSchema = deserializationSchema;
+    }
+
+    /**
+     * Sets the pulsar service url to connect to. Defaults to pulsar://localhost:6650.
+     *
+     * @param serviceUrl service url to connect to
+     * @return this builder
+     */
+    public PulsarSourceBuilder<T> serviceUrl(String serviceUrl) {
+        Preconditions.checkNotNull(serviceUrl);
+        this.serviceUrl = serviceUrl;
+        return this;
+    }
+
+    /**
+     * Sets the topic to consume from. This is required.
+     *
+     * <p>Topic names (https://pulsar.incubator.apache.org/docs/latest/getting-started/ConceptsAndArchitecture/#Topics)
+     * are in the following format:
+     * {persistent|non-persistent}://tenant/namespace/topic
+     *
+     * @param topic the topic to consume from
+     * @return this builder
+     */
+    public PulsarSourceBuilder<T> topic(String topic) {
+        Preconditions.checkNotNull(topic);
+        this.topic = topic;
+        return this;
+    }
+
+    /**
+     * Sets the subscription name for the topic consumer. Defaults to flink-sub.
+     *
+     * @param subscriptionName the subscription name for the topic consumer
+     * @return this builder
+     */
+    public PulsarSourceBuilder<T> subscriptionName(String subscriptionName) {
+        Preconditions.checkNotNull(subscriptionName);
+        this.subscriptionName = subscriptionName;
+        return this;
+    }
+
+    /**
+     * Sets the number of messages to receive before acknowledging. This defaults to 100 and is only used
+     * when checkpointing is disabled. Sizes outside the range 1 to 1000 are ignored, keeping the prior value.
+     *
+     * @param size number of messages to receive before acknowledging
+     * @return this builder
+     */
+    public PulsarSourceBuilder<T> acknowledgementBatchSize(long size) {
+        if (size > 0 && size <= MAX_ACKNOWLEDGEMENT_BATCH_SIZE) {
+            acknowledgementBatchSize = size;
+        }
+        return this;
+    }
+
+    public SourceFunction<T> build() {
+        Preconditions.checkNotNull(serviceUrl, "a service url is required");
+        Preconditions.checkNotNull(topic, "a topic is required");
+        Preconditions.checkNotNull(subscriptionName, "a subscription name is required");
+
+        return new PulsarConsumerSource<>(this);
+    }
+
+    /**
+     * Creates a PulsarSourceBuilder.
+     *
+     * @param deserializationSchema the deserializer used to convert between Pulsar's byte messages and Flink's objects.
+     * @return a builder
+     */
+    public static <T> PulsarSourceBuilder<T> builder(DeserializationSchema<T> deserializationSchema) {
+        Preconditions.checkNotNull(deserializationSchema);
+        return new PulsarSourceBuilder<>(deserializationSchema);
+    }
+}
diff --git a/pulsar-flink/src/test/java/org/apache/flink/streaming/connectors/pulsar/PulsarConsumerSourceTests.java b/pulsar-flink/src/test/java/org/apache/flink/streaming/connectors/pulsar/PulsarConsumerSourceTests.java
new file mode 100644
index 0000000..97811da
--- /dev/null
+++ b/pulsar-flink/src/test/java/org/apache/flink/streaming/connectors/pulsar/PulsarConsumerSourceTests.java
@@ -0,0 +1,524 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.flink.streaming.connectors.pulsar;
+
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.common.serialization.SimpleStringSchema;
+import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.operators.StreamSource;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
+
+import org.apache.pulsar.client.api.Consumer;
+import org.apache.pulsar.client.api.ConsumerStats;
+import org.apache.pulsar.client.api.Message;
+import org.apache.pulsar.client.api.MessageId;
+import org.apache.pulsar.client.api.PulsarClient;
+import org.apache.pulsar.client.api.PulsarClientException;
+import org.apache.pulsar.client.api.Schema;
+import org.apache.pulsar.client.impl.MessageImpl;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.mockito.Matchers.any;
+
+/**
+ * Tests for the PulsarConsumerSource. The source supports two operation modes.
+ * 1) At-least-once (when checkpointed) with Pulsar message acknowledgements and the deduplication mechanism in
+ *    {@link org.apache.flink.streaming.api.functions.source.MessageAcknowledgingSourceBase}.
+ * 2) No strong delivery guarantees (without checkpointing) with Pulsar acknowledging messages
+ *    after it receives x number of messages.
+ *
+ * <p>This test assumes that the MessageIds are increasing monotonically. That doesn't have to be the
+ * case. The MessageId is used to uniquely identify messages.
+ */
+public class PulsarConsumerSourceTests {
+
+    private PulsarConsumerSource<String> source;
+
+    private TestConsumer consumer;
+
+    private TestSourceContext context;
+
+    private Thread sourceThread;
+
+    private Exception exception;
+
+    @Before
+    public void before() {
+        context = new TestSourceContext();
+
+        sourceThread = new Thread(() -> {
+            try {
+                source.run(context);
+            } catch (Exception e) {
+                exception = e;
+            }
+        });
+    }
+
+    @After
+    public void after() throws Exception {
+        if (source != null) {
+            source.cancel();
+        }
+        if (sourceThread != null) {
+            sourceThread.join();
+        }
+    }
+
+    @Test
+    public void testCheckpointing() throws Exception {
+        final int numMessages = 5;
+        consumer = new TestConsumer(numMessages);
+
+        source = createSource(consumer, 1, true);
+        source.open(new Configuration());
+
+        final StreamSource<String, PulsarConsumerSource<String>> src = new StreamSource<>(source);
+        final AbstractStreamOperatorTestHarness<String> testHarness =
+            new AbstractStreamOperatorTestHarness<>(src, 1, 1, 0);
+
+        testHarness.open();
+
+        sourceThread.start();
+
+        final Random random = new Random(System.currentTimeMillis());
+        for (int i = 0; i < 3; ++i) {
+
+            // wait and receive messages from the test consumer
+            receiveMessages();
+
+            final long snapshotId = random.nextLong();
+            OperatorSubtaskState data;
+            synchronized (context.getCheckpointLock()) {
+                data = testHarness.snapshot(snapshotId, System.currentTimeMillis());
+            }
+
+            final TestPulsarConsumerSource sourceCopy =
+                createSource(Mockito.mock(Consumer.class), 1, true);
+            final StreamSource<String, TestPulsarConsumerSource> srcCopy = new StreamSource<>(sourceCopy);
+            final AbstractStreamOperatorTestHarness<String> testHarnessCopy =
+                new AbstractStreamOperatorTestHarness<>(srcCopy, 1, 1, 0);
+
+            testHarnessCopy.setup();
+            testHarnessCopy.initializeState(data);
+            testHarnessCopy.open();
+
+            final ArrayDeque<Tuple2<Long, Set<MessageId>>> deque = sourceCopy.getRestoredState();
+            final Set<MessageId> messageIds = deque.getLast().f1;
+
+            final int start = consumer.currentMessage.get() - numMessages;
+            for (int mi = start; mi < (start + numMessages); ++mi) {
+                Assert.assertTrue(messageIds.contains(consumer.messages.get(mi).getMessageId()));
+            }
+
+            // check if the messages are being acknowledged
+            synchronized (context.getCheckpointLock()) {
+                source.notifyCheckpointComplete(snapshotId);
+
+                Assert.assertEquals(consumer.acknowledgedIds.keySet(), messageIds);
+                // clear acknowledgements for the next snapshot comparison
+                consumer.acknowledgedIds.clear();
+            }
+
+            final int lastMessageIndex = consumer.currentMessage.get();
+            consumer.addMessages(createMessages(lastMessageIndex, 5));
+        }
+    }
+
+    @Test
+    public void testCheckpointingDuplicatedIds() throws Exception {
+        consumer = new TestConsumer(5);
+
+        source = createSource(consumer, 1, true);
+        source.open(new Configuration());
+
+        sourceThread.start();
+
+        receiveMessages();
+
+        Assert.assertEquals(5, context.elements.size());
+
+        // try to reprocess the messages we should not collect any more elements
+        consumer.reset();
+
+        receiveMessages();
+
+        Assert.assertEquals(5, context.elements.size());
+    }
+
+    @Test
+    public void testCheckpointingDisabledMessagesEqualBatchSize() throws Exception {
+
+        consumer = new TestConsumer(5);
+
+        source = createSource(consumer, 5, false);
+        source.open(new Configuration());
+
+        sourceThread.start();
+
+        receiveMessages();
+
+        Assert.assertEquals(1, consumer.acknowledgedIds.size());
+    }
+
+    @Test
+    public void testCheckpointingDisabledMoreMessagesThanBatchSize() throws Exception {
+
+        consumer = new TestConsumer(6);
+
+        source = createSource(consumer, 5, false);
+        source.open(new Configuration());
+
+        sourceThread.start();
+
+        receiveMessages();
+
+        Assert.assertEquals(1, consumer.acknowledgedIds.size());
+    }
+
+    @Test
+    public void testCheckpointingDisabledLessMessagesThanBatchSize() throws Exception {
+
+        consumer = new TestConsumer(4);
+
+        source = createSource(consumer, 5, false);
+        source.open(new Configuration());
+
+        sourceThread.start();
+
+        receiveMessages();
+
+        Assert.assertEquals(0, consumer.acknowledgedIds.size());
+    }
+
+    @Test
+    public void testCheckpointingDisabledMessages2XBatchSize() throws Exception {
+
+        consumer = new TestConsumer(10);
+
+        source = createSource(consumer, 5, false);
+        source.open(new Configuration());
+
+        sourceThread.start();
+
+        receiveMessages();
+
+        Assert.assertEquals(2, consumer.acknowledgedIds.size());
+    }
+
+    private void receiveMessages() throws InterruptedException {
+        while (consumer.currentMessage.get() < consumer.messages.size()) {
+            Thread.sleep(5);
+        }
+    }
+
+    private TestPulsarConsumerSource createSource(Consumer<byte[]> testConsumer,
+                                                  long batchSize, boolean isCheckpointingEnabled) throws Exception {
+        PulsarSourceBuilder<String> builder =
+            PulsarSourceBuilder.builder(new SimpleStringSchema())
+                .acknowledgementBatchSize(batchSize);
+        TestPulsarConsumerSource source = new TestPulsarConsumerSource(builder, testConsumer, isCheckpointingEnabled);
+
+        OperatorStateStore mockStore = Mockito.mock(OperatorStateStore.class);
+        FunctionInitializationContext mockContext = Mockito.mock(FunctionInitializationContext.class);
+        Mockito.when(mockContext.getOperatorStateStore()).thenReturn(mockStore);
+        Mockito.when(mockStore.getSerializableListState(any(String.class))).thenReturn(null);
+
+        source.initializeState(mockContext);
+
+        return source;
+    }
+
+    private static class TestPulsarConsumerSource extends PulsarConsumerSource<String> {
+
+        private ArrayDeque<Tuple2<Long, Set<MessageId>>> restoredState;
+
+        private Consumer<byte[]> testConsumer;
+        private boolean isCheckpointingEnabled;
+
+        TestPulsarConsumerSource(PulsarSourceBuilder<String> builder,
+                                 Consumer<byte[]> testConsumer, boolean isCheckpointingEnabled) {
+            super(builder);
+            this.testConsumer = testConsumer;
+            this.isCheckpointingEnabled = isCheckpointingEnabled;
+        }
+
+        @Override
+        protected boolean addId(MessageId messageId) {
+            Assert.assertEquals(true, isCheckpointingEnabled());
+            return super.addId(messageId);
+        }
+
+        @Override
+        public RuntimeContext getRuntimeContext() {
+            StreamingRuntimeContext context = Mockito.mock(StreamingRuntimeContext.class);
+            Mockito.when(context.isCheckpointingEnabled()).thenReturn(isCheckpointingEnabled);
+            return context;
+        }
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) throws Exception {
+            super.initializeState(context);
+            this.restoredState = this.pendingCheckpoints;
+        }
+
+        public ArrayDeque<Tuple2<Long, Set<MessageId>>> getRestoredState() {
+            return this.restoredState;
+        }
+
+        @Override
+        PulsarClient createClient() {
+            return Mockito.mock(PulsarClient.class);
+        }
+
+        @Override
+        Consumer<byte[]> createConsumer(PulsarClient client) {
+            return testConsumer;
+        }
+    }
+
+    private static class TestSourceContext implements SourceFunction.SourceContext<String> {
+        // Minimal SourceContext for the tests: records collected elements and exposes one shared lock.
+        private static final Object lock = new Object();
+        // written by the source thread, read by the test thread, hence the synchronized wrapper
+        private final List<String> elements = Collections.synchronizedList(new ArrayList<>());
+
+        @Override
+        public void collect(String element) {
+            elements.add(element);
+        }
+
+        @Override
+        public void collectWithTimestamp(String element, long timestamp) {
+            // no-op
+        }
+
+        @Override
+        public void emitWatermark(Watermark mark) {
+            // no-op
+        }
+
+        @Override
+        public void markAsTemporarilyIdle() {
+            // no-op
+        }
+
+        @Override
+        public Object getCheckpointLock() {
+            return lock;
+        }
+
+        @Override
+        public void close() {
+            // no-op
+        }
+    }
+
+    private static class TestConsumer implements Consumer<byte[]> {
+
+        private final List<Message> messages = new ArrayList<>();
+
+        private AtomicInteger currentMessage = new AtomicInteger();
+
+        private final Map<MessageId, MessageId> acknowledgedIds = new ConcurrentHashMap<>();
+
+        private TestConsumer(int numMessages) {
+            messages.addAll(createMessages(0, numMessages));
+        }
+
+        private void reset() {
+            currentMessage.set(0);
+        }
+
+        @Override
+        public String getTopic() {
+            return null;
+        }
+
+        @Override
+        public String getSubscription() {
+            return null;
+        }
+
+        @Override
+        public void unsubscribe() throws PulsarClientException {
+
+        }
+
+        @Override
+        public CompletableFuture<Void> unsubscribeAsync() {
+            return null;
+        }
+
+        @Override
+        public Message<byte[]> receive() throws PulsarClientException {
+            return null;
+        }
+
+        public synchronized void addMessages(List<Message> messages) {
+            this.messages.addAll(messages);
+        }
+
+        @Override
+        public CompletableFuture<Message<byte[]>> receiveAsync() {
+            return null;
+        }
+
+        @Override
+        public Message<byte[]> receive(int i, TimeUnit timeUnit) throws PulsarClientException {
+            synchronized (this) {
+                if (currentMessage.get() == messages.size()) {
+                    try {
+                        Thread.sleep(10);
+                    } catch (InterruptedException e) {
+                        System.out.println("no more messages sleeping index: " + currentMessage.get());
+                    }
+                    return null;
+                }
+                return messages.get(currentMessage.getAndIncrement());
+            }
+        }
+
+        @Override
+        public void acknowledge(Message<?> message) throws PulsarClientException {
+
+        }
+
+        @Override
+        public void acknowledge(MessageId messageId) throws PulsarClientException {
+
+        }
+
+        @Override
+        public void acknowledgeCumulative(Message<?> message) throws PulsarClientException {
+
+        }
+
+        @Override
+        public void acknowledgeCumulative(MessageId messageId) throws PulsarClientException {
+            acknowledgedIds.put(messageId, messageId);
+        }
+
+        @Override
+        public CompletableFuture<Void> acknowledgeAsync(Message<?> message) {
+            return null;
+        }
+
+        @Override
+        public CompletableFuture<Void> acknowledgeAsync(MessageId messageId) {
+            acknowledgedIds.put(messageId, messageId);
+            return CompletableFuture.completedFuture(null);
+        }
+
+        @Override
+        public CompletableFuture<Void> acknowledgeCumulativeAsync(Message<?> message) {
+            return null;
+        }
+
+        @Override
+        public CompletableFuture<Void> acknowledgeCumulativeAsync(MessageId messageId) {
+            return null;
+        }
+
+        @Override
+        public ConsumerStats getStats() {
+            return null;
+        }
+
+        @Override
+        public void close() throws PulsarClientException {
+
+        }
+
+        @Override
+        public CompletableFuture<Void> closeAsync() {
+            return null;
+        }
+
+        @Override
+        public boolean hasReachedEndOfTopic() {
+            return false;
+        }
+
+        @Override
+        public void redeliverUnacknowledgedMessages() {
+
+        }
+
+        @Override
+        public void seek(MessageId messageId) throws PulsarClientException {
+
+        }
+
+        @Override
+        public CompletableFuture<Void> seekAsync(MessageId messageId) {
+            return null;
+        }
+
+        @Override
+        public boolean isConnected() {
+            return true;
+        }
+
+        @Override
+        public String getConsumerName() {
+            return "test-consumer-0";
+        }
+    }
+
+    private static List<Message> createMessages(int startIndex, int numMessages) {
+        final List<Message> messages = new ArrayList<>();
+        for (int i = startIndex; i < (startIndex + numMessages); ++i) {
+            String content = "message-" + i;
+            messages.add(createMessage(content, createMessageId(1, i + 1, 1)));
+        }
+        return messages;
+    }
+
+    private static Message<byte[]> createMessage(String content, String messageId) {
+        return new MessageImpl<>(messageId, Collections.emptyMap(), content.getBytes(), Schema.BYTES);
+    }
+
+    private static String createMessageId(long ledgerId, long entryId, long partitionIndex) {
+        return String.format("%d:%d:%d", ledgerId, entryId, partitionIndex);
+    }
+}