You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pulsar.apache.org by si...@apache.org on 2018/09/11 19:11:45 UTC
[incubator-pulsar] branch master updated: [ecosystem] Flink pulsar
source connector (#2555)
This is an automated email from the ASF dual-hosted git repository.
sijie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-pulsar.git
The following commit(s) were added to refs/heads/master by this push:
new fc61056 [ecosystem] Flink pulsar source connector (#2555)
fc61056 is described below
commit fc61056ccab29bb068b65d71120364d14503f69f
Author: Sijie Guo <gu...@gmail.com>
AuthorDate: Tue Sep 11 12:11:41 2018 -0700
[ecosystem] Flink pulsar source connector (#2555)
*Motivation*
This ports apache/flink#6200 to the Apache Pulsar repo. It adds a Pulsar source connector
that enables Flink jobs to process messages from Pulsar topics.
*Changes*
Add a PulsarConsumerSource connector
---
.../connectors/pulsar/PulsarConsumerSource.java | 204 ++++++++
.../connectors/pulsar/PulsarSourceBase.java | 31 ++
.../connectors/pulsar/PulsarSourceBuilder.java | 118 +++++
.../pulsar/PulsarConsumerSourceTests.java | 524 +++++++++++++++++++++
4 files changed, 877 insertions(+)
diff --git a/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarConsumerSource.java b/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarConsumerSource.java
new file mode 100644
index 0000000..f1b2595
--- /dev/null
+++ b/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarConsumerSource.java
@@ -0,0 +1,204 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.flink.streaming.connectors.pulsar;
+
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.common.serialization.DeserializationSchema;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.source.MessageAcknowledgingSourceBase;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.util.IOUtils;
+
+import org.apache.pulsar.client.api.Consumer;
+import org.apache.pulsar.client.api.Message;
+import org.apache.pulsar.client.api.MessageId;
+import org.apache.pulsar.client.api.PulsarClient;
+import org.apache.pulsar.client.api.PulsarClientException;
+import org.apache.pulsar.client.api.SubscriptionType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Pulsar source (consumer) which receives messages from a topic and acknowledges messages.
+ * When checkpointing is enabled, it guarantees at least once processing semantics.
+ *
+ * <p>When checkpointing is disabled, it auto acknowledges messages based on the number of messages it has
+ * received. In this mode messages may be dropped.
+ */
+class PulsarConsumerSource<T> extends MessageAcknowledgingSourceBase<T, MessageId> implements PulsarSourceBase<T> {
+
+    private static final Logger LOG = LoggerFactory.getLogger(PulsarConsumerSource.class);
+
+    /** Maximum time a single receive() call blocks, so that run() re-checks {@code isRunning} regularly. */
+    private final int messageReceiveTimeoutMs = 100;
+    private final String serviceUrl;
+    private final String topic;
+    private final String subscriptionName;
+    private final DeserializationSchema<T> deserializer;
+
+    // Created in open(); transient so these runtime handles are never part of the serialized function.
+    private transient PulsarClient client;
+    private transient Consumer<byte[]> consumer;
+
+    private boolean isCheckpointingEnabled;
+
+    /** Number of messages received between cumulative acknowledgements (only used without checkpointing). */
+    private final long acknowledgementBatchSize;
+    private long batchCount;
+    private long totalMessageCount;
+
+    private transient volatile boolean isRunning;
+
+    PulsarConsumerSource(PulsarSourceBuilder<T> builder) {
+        super(MessageId.class);
+        this.serviceUrl = builder.serviceUrl;
+        this.topic = builder.topic;
+        this.deserializer = builder.deserializationSchema;
+        this.subscriptionName = builder.subscriptionName;
+        this.acknowledgementBatchSize = builder.acknowledgementBatchSize;
+    }
+
+    /**
+     * Detects whether checkpointing is enabled for this task and creates the Pulsar client and consumer.
+     */
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+
+        final RuntimeContext context = getRuntimeContext();
+        if (context instanceof StreamingRuntimeContext) {
+            isCheckpointingEnabled = ((StreamingRuntimeContext) context).isCheckpointingEnabled();
+        }
+
+        client = createClient();
+        consumer = createConsumer(client);
+
+        isRunning = true;
+    }
+
+    /**
+     * Acknowledges the message ids recorded for a completed checkpoint. All acknowledgements are
+     * issued asynchronously first and then awaited one by one.
+     *
+     * @param checkpointId the id of the completed checkpoint
+     * @param messageIds the message ids to acknowledge
+     * @throws RuntimeException if the consumer has not been created or any acknowledgement fails
+     */
+    @Override
+    protected void acknowledgeIDs(long checkpointId, Set<MessageId> messageIds) {
+        if (consumer == null) {
+            LOG.error("null consumer unable to acknowledge messages");
+            throw new RuntimeException("null pulsar consumer unable to acknowledge messages");
+        }
+
+        if (messageIds.isEmpty()) {
+            LOG.info("no message ids to acknowledge");
+            return;
+        }
+
+        Map<String, CompletableFuture<Void>> futures = new HashMap<>(messageIds.size());
+        for (MessageId id : messageIds) {
+            futures.put(id.toString(), consumer.acknowledgeAsync(id));
+        }
+
+        futures.forEach((k, f) -> {
+            try {
+                f.get();
+            } catch (Exception e) {
+                LOG.error("failed to acknowledge messageId " + k, e);
+                throw new RuntimeException("Messages could not be acknowledged during checkpoint creation.", e);
+            }
+        });
+    }
+
+    /**
+     * Polls the consumer until {@link #cancel()} is called and emits every received message,
+     * either through the checkpoint-aware path or the auto-acknowledging path.
+     */
+    @Override
+    public void run(SourceContext<T> context) throws Exception {
+        while (isRunning) {
+            final Message<byte[]> message = consumer.receive(messageReceiveTimeoutMs, TimeUnit.MILLISECONDS);
+            if (message == null) {
+                // receive() timed out with no message available. This is the normal idle case, so poll
+                // again silently; logging here would emit a line on every timeout while the topic is idle.
+                continue;
+            }
+
+            if (isCheckpointingEnabled) {
+                emitCheckpointing(context, message);
+            } else {
+                emitAutoAcking(context, message);
+            }
+        }
+    }
+
+    /**
+     * Emits a message under the checkpoint lock, skipping ids that were already recorded
+     * (de-duplication of redelivered messages).
+     */
+    private void emitCheckpointing(SourceContext<T> context, Message<byte[]> message) throws IOException {
+        synchronized (context.getCheckpointLock()) {
+            if (!addId(message.getMessageId())) {
+                LOG.debug("messageId={} already processed.", message.getMessageId());
+                return;
+            }
+            context.collect(deserialize(message));
+            totalMessageCount++;
+        }
+    }
+
+    /**
+     * Emits a message and cumulatively acknowledges it once {@code acknowledgementBatchSize}
+     * messages have been emitted since the last acknowledgement.
+     */
+    private void emitAutoAcking(SourceContext<T> context, Message<byte[]> message) throws IOException {
+        context.collect(deserialize(message));
+        batchCount++;
+        totalMessageCount++;
+        if (batchCount >= acknowledgementBatchSize) {
+            LOG.debug("processed {} messages acknowledging messageId {}", batchCount, message.getMessageId());
+            consumer.acknowledgeCumulative(message.getMessageId());
+            batchCount = 0;
+        }
+    }
+
+    private T deserialize(Message<byte[]> message) throws IOException {
+        return deserializer.deserialize(message.getData());
+    }
+
+    @Override
+    public void cancel() {
+        isRunning = false;
+    }
+
+    /**
+     * Closes the consumer and then the client; failures while closing are logged, not rethrown.
+     */
+    @Override
+    public void close() throws Exception {
+        super.close();
+        IOUtils.cleanup(LOG, consumer);
+        IOUtils.cleanup(LOG, client);
+    }
+
+    @Override
+    public TypeInformation<T> getProducedType() {
+        return deserializer.getProducedType();
+    }
+
+    boolean isCheckpointingEnabled() {
+        return isCheckpointingEnabled;
+    }
+
+    PulsarClient createClient() throws PulsarClientException {
+        return PulsarClient.builder()
+            .serviceUrl(serviceUrl)
+            .build();
+    }
+
+    Consumer<byte[]> createConsumer(PulsarClient client) throws PulsarClientException {
+        return client.newConsumer()
+            .topic(topic)
+            .subscriptionName(subscriptionName)
+            .subscriptionType(SubscriptionType.Failover)
+            .subscribe();
+    }
+}
diff --git a/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarSourceBase.java b/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarSourceBase.java
new file mode 100644
index 0000000..9d44215
--- /dev/null
+++ b/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarSourceBase.java
@@ -0,0 +1,31 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.flink.streaming.connectors.pulsar;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
+import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
+
+/**
+ * Base interface for pulsar sources: a parallel Flink source function that also reports the
+ * type of the records it produces.
+ *
+ * @param <T> the type of records emitted by the source
+ */
+@PublicEvolving
+interface PulsarSourceBase<T> extends ParallelSourceFunction<T>, ResultTypeQueryable<T> {
+}
diff --git a/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarSourceBuilder.java b/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarSourceBuilder.java
new file mode 100644
index 0000000..7f1ee9c
--- /dev/null
+++ b/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarSourceBuilder.java
@@ -0,0 +1,118 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.flink.streaming.connectors.pulsar;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.serialization.DeserializationSchema;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * A class for building a pulsar source.
+ *
+ * @param <T> the type of records produced by the built source
+ */
+@PublicEvolving
+public class PulsarSourceBuilder<T> {
+
+    static final String SERVICE_URL = "pulsar://localhost:6650";
+    static final long ACKNOWLEDGEMENT_BATCH_SIZE = 100;
+    static final long MAX_ACKNOWLEDGEMENT_BATCH_SIZE = 1000;
+
+    final DeserializationSchema<T> deserializationSchema;
+    String serviceUrl = SERVICE_URL;
+    String topic;
+    String subscriptionName = "flink-sub";
+    long acknowledgementBatchSize = ACKNOWLEDGEMENT_BATCH_SIZE;
+
+    private PulsarSourceBuilder(DeserializationSchema<T> deserializationSchema) {
+        this.deserializationSchema = deserializationSchema;
+    }
+
+    /**
+     * Sets the pulsar service url to connect to. Defaults to pulsar://localhost:6650.
+     *
+     * @param serviceUrl service url to connect to
+     * @return this builder
+     * @throws NullPointerException if {@code serviceUrl} is null
+     */
+    public PulsarSourceBuilder<T> serviceUrl(String serviceUrl) {
+        Preconditions.checkNotNull(serviceUrl, "serviceUrl must not be null");
+        this.serviceUrl = serviceUrl;
+        return this;
+    }
+
+    /**
+     * Sets the topic to consume from. This is required.
+     *
+     * <p>Topic names (https://pulsar.incubator.apache.org/docs/latest/getting-started/ConceptsAndArchitecture/#Topics)
+     * are in the following format:
+     * {persistent|non-persistent}://tenant/namespace/topic
+     *
+     * @param topic the topic to consume from
+     * @return this builder
+     * @throws NullPointerException if {@code topic} is null
+     */
+    public PulsarSourceBuilder<T> topic(String topic) {
+        Preconditions.checkNotNull(topic, "topic must not be null");
+        this.topic = topic;
+        return this;
+    }
+
+    /**
+     * Sets the subscription name for the topic consumer. Defaults to flink-sub.
+     *
+     * @param subscriptionName the subscription name for the topic consumer
+     * @return this builder
+     * @throws NullPointerException if {@code subscriptionName} is null
+     */
+    public PulsarSourceBuilder<T> subscriptionName(String subscriptionName) {
+        Preconditions.checkNotNull(subscriptionName, "subscriptionName must not be null");
+        this.subscriptionName = subscriptionName;
+        return this;
+    }
+
+    /**
+     * Sets the number of messages to receive before acknowledging. This defaults to 100. This
+     * value is only used when checkpointing is disabled.
+     *
+     * <p>Only values in the range (0, 1000] are accepted; any other value is ignored and the
+     * previously configured batch size is kept.
+     *
+     * @param size number of messages to receive before acknowledging
+     * @return this builder
+     */
+    public PulsarSourceBuilder<T> acknowledgementBatchSize(long size) {
+        if (size > 0 && size <= MAX_ACKNOWLEDGEMENT_BATCH_SIZE) {
+            acknowledgementBatchSize = size;
+        }
+        return this;
+    }
+
+    /**
+     * Builds the pulsar source from the current configuration.
+     *
+     * @return a new source function consuming from the configured topic
+     * @throws NullPointerException if the service url, topic or subscription name is missing
+     */
+    public SourceFunction<T> build() {
+        Preconditions.checkNotNull(serviceUrl, "a service url is required");
+        Preconditions.checkNotNull(topic, "a topic is required");
+        Preconditions.checkNotNull(subscriptionName, "a subscription name is required");
+
+        return new PulsarConsumerSource<>(this);
+    }
+
+    /**
+     * Creates a PulsarSourceBuilder.
+     *
+     * @param deserializationSchema the deserializer used to convert between Pulsar's byte messages and Flink's objects.
+     * @return a builder
+     * @throws NullPointerException if {@code deserializationSchema} is null
+     */
+    public static <T> PulsarSourceBuilder<T> builder(DeserializationSchema<T> deserializationSchema) {
+        Preconditions.checkNotNull(deserializationSchema, "deserializationSchema must not be null");
+        return new PulsarSourceBuilder<>(deserializationSchema);
+    }
+}
diff --git a/pulsar-flink/src/test/java/org/apache/flink/streaming/connectors/pulsar/PulsarConsumerSourceTests.java b/pulsar-flink/src/test/java/org/apache/flink/streaming/connectors/pulsar/PulsarConsumerSourceTests.java
new file mode 100644
index 0000000..97811da
--- /dev/null
+++ b/pulsar-flink/src/test/java/org/apache/flink/streaming/connectors/pulsar/PulsarConsumerSourceTests.java
@@ -0,0 +1,524 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.flink.streaming.connectors.pulsar;
+
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.common.serialization.SimpleStringSchema;
+import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.operators.StreamSource;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
+
+import org.apache.pulsar.client.api.Consumer;
+import org.apache.pulsar.client.api.ConsumerStats;
+import org.apache.pulsar.client.api.Message;
+import org.apache.pulsar.client.api.MessageId;
+import org.apache.pulsar.client.api.PulsarClient;
+import org.apache.pulsar.client.api.PulsarClientException;
+import org.apache.pulsar.client.api.Schema;
+import org.apache.pulsar.client.impl.MessageImpl;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.mockito.Matchers.any;
+
+/**
+ * Tests for the PulsarConsumerSource. The source supports two operation modes.
+ * 1) At-least-once (when checkpointed) with Pulsar message acknowledgements and the deduplication mechanism in
+ * {@link org.apache.flink.streaming.api.functions.source.MessageAcknowledgingSourceBase}.
+ * 2) No strong delivery guarantees (without checkpointing) with the source cumulatively acknowledging
+ * messages after it receives a configured number of messages.
+ *
+ * <p>These tests assume that the MessageIds increase monotonically. That does not have to be the
+ * case in general; the MessageId is only used to uniquely identify messages.
+ */
+public class PulsarConsumerSourceTests {
+
+    private PulsarConsumerSource<String> source;
+
+    private TestConsumer consumer;
+
+    private TestSourceContext context;
+
+    private Thread sourceThread;
+
+    // Anything thrown by source.run() on the source thread (including AssertionErrors); checked in
+    // after() and receiveMessages() so such failures fail the test instead of being lost.
+    private volatile Throwable sourceThreadFailure;
+
+    @Before
+    public void before() {
+        context = new TestSourceContext();
+
+        sourceThread = new Thread(() -> {
+            try {
+                source.run(context);
+            } catch (Throwable t) {
+                sourceThreadFailure = t;
+            }
+        });
+    }
+
+    @After
+    public void after() throws Exception {
+        if (source != null) {
+            source.cancel();
+        }
+        if (sourceThread != null) {
+            sourceThread.join();
+        }
+        if (sourceThreadFailure != null) {
+            throw new AssertionError("source thread failed", sourceThreadFailure);
+        }
+    }
+
+    @Test
+    public void testCheckpointing() throws Exception {
+        final int numMessages = 5;
+        consumer = new TestConsumer(numMessages);
+
+        source = createSource(consumer, 1, true);
+        source.open(new Configuration());
+
+        final StreamSource<String, PulsarConsumerSource<String>> src = new StreamSource<>(source);
+        final AbstractStreamOperatorTestHarness<String> testHarness =
+            new AbstractStreamOperatorTestHarness<>(src, 1, 1, 0);
+
+        testHarness.open();
+
+        sourceThread.start();
+
+        final Random random = new Random(System.currentTimeMillis());
+        for (int i = 0; i < 3; ++i) {
+
+            // wait and receive messages from the test consumer
+            receiveMessages();
+
+            final long snapshotId = random.nextLong();
+            OperatorSubtaskState data;
+            synchronized (context.getCheckpointLock()) {
+                data = testHarness.snapshot(snapshotId, System.currentTimeMillis());
+            }
+
+            final TestPulsarConsumerSource sourceCopy =
+                createSource(Mockito.mock(Consumer.class), 1, true);
+            final StreamSource<String, TestPulsarConsumerSource> srcCopy = new StreamSource<>(sourceCopy);
+            final AbstractStreamOperatorTestHarness<String> testHarnessCopy =
+                new AbstractStreamOperatorTestHarness<>(srcCopy, 1, 1, 0);
+
+            testHarnessCopy.setup();
+            testHarnessCopy.initializeState(data);
+            testHarnessCopy.open();
+
+            final ArrayDeque<Tuple2<Long, Set<MessageId>>> deque = sourceCopy.getRestoredState();
+            final Set<MessageId> messageIds = deque.getLast().f1;
+
+            final int start = consumer.currentMessage.get() - numMessages;
+            for (int mi = start; mi < (start + numMessages); ++mi) {
+                Assert.assertTrue(messageIds.contains(consumer.messages.get(mi).getMessageId()));
+            }
+
+            // check if the messages are being acknowledged
+            synchronized (context.getCheckpointLock()) {
+                source.notifyCheckpointComplete(snapshotId);
+
+                Assert.assertEquals(consumer.acknowledgedIds.keySet(), messageIds);
+                // clear acknowledgements for the next snapshot comparison
+                consumer.acknowledgedIds.clear();
+            }
+
+            final int lastMessageIndex = consumer.currentMessage.get();
+            consumer.addMessages(createMessages(lastMessageIndex, 5));
+        }
+    }
+
+    @Test
+    public void testCheckpointingDuplicatedIds() throws Exception {
+        consumer = new TestConsumer(5);
+
+        source = createSource(consumer, 1, true);
+        source.open(new Configuration());
+
+        sourceThread.start();
+
+        receiveMessages();
+
+        Assert.assertEquals(5, context.elements.size());
+
+        // try to reprocess the messages we should not collect any more elements
+        consumer.reset();
+
+        receiveMessages();
+
+        Assert.assertEquals(5, context.elements.size());
+    }
+
+    @Test
+    public void testCheckpointingDisabledMessagesEqualBatchSize() throws Exception {
+
+        consumer = new TestConsumer(5);
+
+        source = createSource(consumer, 5, false);
+        source.open(new Configuration());
+
+        sourceThread.start();
+
+        receiveMessages();
+
+        Assert.assertEquals(1, consumer.acknowledgedIds.size());
+    }
+
+    @Test
+    public void testCheckpointingDisabledMoreMessagesThanBatchSize() throws Exception {
+
+        consumer = new TestConsumer(6);
+
+        source = createSource(consumer, 5, false);
+        source.open(new Configuration());
+
+        sourceThread.start();
+
+        receiveMessages();
+
+        Assert.assertEquals(1, consumer.acknowledgedIds.size());
+    }
+
+    @Test
+    public void testCheckpointingDisabledLessMessagesThanBatchSize() throws Exception {
+
+        consumer = new TestConsumer(4);
+
+        source = createSource(consumer, 5, false);
+        source.open(new Configuration());
+
+        sourceThread.start();
+
+        receiveMessages();
+
+        Assert.assertEquals(0, consumer.acknowledgedIds.size());
+    }
+
+    @Test
+    public void testCheckpointingDisabledMessages2XBatchSize() throws Exception {
+
+        consumer = new TestConsumer(10);
+
+        source = createSource(consumer, 5, false);
+        source.open(new Configuration());
+
+        sourceThread.start();
+
+        receiveMessages();
+
+        Assert.assertEquals(2, consumer.acknowledgedIds.size());
+    }
+
+    /**
+     * Blocks until the source thread has received AND fully processed every message currently queued
+     * in the test consumer.
+     *
+     * <p>Waiting on the consumer's read index alone is racy: the index advances inside receive(),
+     * before the source has collected or acknowledged that message. The source only polls again after
+     * it has finished handling the previous message, and the consumer counts an empty poll (under its
+     * monitor) only once every queued message was handed out. So an empty poll counted after this
+     * method was entered proves all queued messages were processed.
+     */
+    private void receiveMessages() throws InterruptedException {
+        final int emptyPollsBefore = consumer.emptyPolls.get();
+        while (consumer.emptyPolls.get() <= emptyPollsBefore) {
+            if (sourceThreadFailure != null) {
+                throw new AssertionError("source thread failed", sourceThreadFailure);
+            }
+            Thread.sleep(5);
+        }
+    }
+
+    private TestPulsarConsumerSource createSource(Consumer<byte[]> testConsumer,
+                                                  long batchSize, boolean isCheckpointingEnabled) throws Exception {
+        PulsarSourceBuilder<String> builder =
+            PulsarSourceBuilder.builder(new SimpleStringSchema())
+                .acknowledgementBatchSize(batchSize);
+        TestPulsarConsumerSource source = new TestPulsarConsumerSource(builder, testConsumer, isCheckpointingEnabled);
+
+        OperatorStateStore mockStore = Mockito.mock(OperatorStateStore.class);
+        FunctionInitializationContext mockContext = Mockito.mock(FunctionInitializationContext.class);
+        Mockito.when(mockContext.getOperatorStateStore()).thenReturn(mockStore);
+        Mockito.when(mockStore.getSerializableListState(any(String.class))).thenReturn(null);
+
+        source.initializeState(mockContext);
+
+        return source;
+    }
+
+    /**
+     * Source under test with the Pulsar client/consumer replaced by test doubles and a mocked
+     * runtime context that reports the requested checkpointing mode.
+     */
+    private static class TestPulsarConsumerSource extends PulsarConsumerSource<String> {
+
+        private ArrayDeque<Tuple2<Long, Set<MessageId>>> restoredState;
+
+        private Consumer<byte[]> testConsumer;
+        private boolean isCheckpointingEnabled;
+
+        TestPulsarConsumerSource(PulsarSourceBuilder<String> builder,
+                                 Consumer<byte[]> testConsumer, boolean isCheckpointingEnabled) {
+            super(builder);
+            this.testConsumer = testConsumer;
+            this.isCheckpointingEnabled = isCheckpointingEnabled;
+        }
+
+        @Override
+        protected boolean addId(MessageId messageId) {
+            Assert.assertEquals(true, isCheckpointingEnabled());
+            return super.addId(messageId);
+        }
+
+        @Override
+        public RuntimeContext getRuntimeContext() {
+            StreamingRuntimeContext context = Mockito.mock(StreamingRuntimeContext.class);
+            Mockito.when(context.isCheckpointingEnabled()).thenReturn(isCheckpointingEnabled);
+            return context;
+        }
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) throws Exception {
+            super.initializeState(context);
+            this.restoredState = this.pendingCheckpoints;
+        }
+
+        public ArrayDeque<Tuple2<Long, Set<MessageId>>> getRestoredState() {
+            return this.restoredState;
+        }
+
+        @Override
+        PulsarClient createClient() {
+            return Mockito.mock(PulsarClient.class);
+        }
+
+        @Override
+        Consumer<byte[]> createConsumer(PulsarClient client) {
+            return testConsumer;
+        }
+    }
+
+    /**
+     * Source context that records collected elements in a list.
+     */
+    private static class TestSourceContext implements SourceFunction.SourceContext<String> {
+
+        // One lock per context instance, so separate tests never share a checkpoint lock.
+        private final Object lock = new Object();
+
+        private final List<String> elements = new ArrayList<>();
+
+        @Override
+        public void collect(String element) {
+            elements.add(element);
+        }
+
+        @Override
+        public void collectWithTimestamp(String element, long timestamp) {
+
+        }
+
+        @Override
+        public void emitWatermark(Watermark mark) {
+
+        }
+
+        @Override
+        public void markAsTemporarilyIdle() {
+
+        }
+
+        @Override
+        public Object getCheckpointLock() {
+            return lock;
+        }
+
+        @Override
+        public void close() {
+
+        }
+    }
+
+    /**
+     * In-memory consumer that hands out a fixed, growable list of messages and records acknowledged ids.
+     */
+    private static class TestConsumer implements Consumer<byte[]> {
+
+        private final List<Message<byte[]>> messages = new ArrayList<>();
+
+        private final AtomicInteger currentMessage = new AtomicInteger();
+
+        // Number of receive() calls that found no message left; incremented under this consumer's monitor.
+        private final AtomicInteger emptyPolls = new AtomicInteger();
+
+        private final Map<MessageId, MessageId> acknowledgedIds = new ConcurrentHashMap<>();
+
+        private TestConsumer(int numMessages) {
+            messages.addAll(createMessages(0, numMessages));
+        }
+
+        // Synchronized with receive() so an empty poll can never straddle the reset.
+        private synchronized void reset() {
+            currentMessage.set(0);
+        }
+
+        @Override
+        public String getTopic() {
+            return null;
+        }
+
+        @Override
+        public String getSubscription() {
+            return null;
+        }
+
+        @Override
+        public void unsubscribe() throws PulsarClientException {
+
+        }
+
+        @Override
+        public CompletableFuture<Void> unsubscribeAsync() {
+            return null;
+        }
+
+        @Override
+        public Message<byte[]> receive() throws PulsarClientException {
+            return null;
+        }
+
+        public synchronized void addMessages(List<Message<byte[]>> messages) {
+            this.messages.addAll(messages);
+        }
+
+        @Override
+        public CompletableFuture<Message<byte[]>> receiveAsync() {
+            return null;
+        }
+
+        @Override
+        public Message<byte[]> receive(int i, TimeUnit timeUnit) throws PulsarClientException {
+            synchronized (this) {
+                if (currentMessage.get() < messages.size()) {
+                    return messages.get(currentMessage.getAndIncrement());
+                }
+                emptyPolls.incrementAndGet();
+            }
+            // Simulate the receive timeout elapsing without a message.
+            try {
+                Thread.sleep(10);
+            } catch (InterruptedException e) {
+                Thread.currentThread().interrupt();
+            }
+            return null;
+        }
+
+        @Override
+        public void acknowledge(Message<?> message) throws PulsarClientException {
+
+        }
+
+        @Override
+        public void acknowledge(MessageId messageId) throws PulsarClientException {
+
+        }
+
+        @Override
+        public void acknowledgeCumulative(Message<?> message) throws PulsarClientException {
+
+        }
+
+        @Override
+        public void acknowledgeCumulative(MessageId messageId) throws PulsarClientException {
+            acknowledgedIds.put(messageId, messageId);
+        }
+
+        @Override
+        public CompletableFuture<Void> acknowledgeAsync(Message<?> message) {
+            return null;
+        }
+
+        @Override
+        public CompletableFuture<Void> acknowledgeAsync(MessageId messageId) {
+            acknowledgedIds.put(messageId, messageId);
+            return CompletableFuture.completedFuture(null);
+        }
+
+        @Override
+        public CompletableFuture<Void> acknowledgeCumulativeAsync(Message<?> message) {
+            return null;
+        }
+
+        @Override
+        public CompletableFuture<Void> acknowledgeCumulativeAsync(MessageId messageId) {
+            return null;
+        }
+
+        @Override
+        public ConsumerStats getStats() {
+            return null;
+        }
+
+        @Override
+        public void close() throws PulsarClientException {
+
+        }
+
+        @Override
+        public CompletableFuture<Void> closeAsync() {
+            return null;
+        }
+
+        @Override
+        public boolean hasReachedEndOfTopic() {
+            return false;
+        }
+
+        @Override
+        public void redeliverUnacknowledgedMessages() {
+
+        }
+
+        @Override
+        public void seek(MessageId messageId) throws PulsarClientException {
+
+        }
+
+        @Override
+        public CompletableFuture<Void> seekAsync(MessageId messageId) {
+            return null;
+        }
+
+        @Override
+        public boolean isConnected() {
+            return true;
+        }
+
+        @Override
+        public String getConsumerName() {
+            return "test-consumer-0";
+        }
+    }
+
+    private static List<Message<byte[]>> createMessages(int startIndex, int numMessages) {
+        final List<Message<byte[]>> messages = new ArrayList<>();
+        for (int i = startIndex; i < (startIndex + numMessages); ++i) {
+            String content = "message-" + i;
+            messages.add(createMessage(content, createMessageId(1, i + 1, 1)));
+        }
+        return messages;
+    }
+
+    private static Message<byte[]> createMessage(String content, String messageId) {
+        return new MessageImpl<>(messageId, Collections.emptyMap(), content.getBytes(), Schema.BYTES);
+    }
+
+    private static String createMessageId(long ledgerId, long entryId, long partitionIndex) {
+        return String.format("%d:%d:%d", ledgerId, entryId, partitionIndex);
+    }
+}