You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by yh...@apache.org on 2024/02/07 21:32:29 UTC

(beam) branch master updated: Fix Jms drop record (#30218)

This is an automated email from the ASF dual-hosted git repository.

yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new c72a9f83fed Fix Jms drop record (#30218)
c72a9f83fed is described below

commit c72a9f83fedb5d558a138ff01199a5fc16f605db
Author: Yi Hu <ya...@google.com>
AuthorDate: Wed Feb 7 16:32:21 2024 -0500

    Fix Jms drop record (#30218)
    
    * Fix JmsIO read dropping messages
    
    * Close consumer before finalizing
    
    * Close last consumer and session on finalizeCheckpoint
    
    * Acknowledge a single message per session
    
    * Remove unnecessary write lock in JmsCheckpointMark (it's immutable)
    
    * Add unit test
    
    * Add message when test fails
    
    * remove test leftover
    
    * address comments
    
    * enable integration test on amqp
---
 .../apache/beam/sdk/io/jms/JmsCheckpointMark.java  | 201 ++++++++++++++-------
 .../java/org/apache/beam/sdk/io/jms/JmsIO.java     |  84 ++++++---
 .../java/org/apache/beam/sdk/io/jms/JmsIOIT.java   |  45 ++++-
 .../java/org/apache/beam/sdk/io/jms/JmsIOTest.java | 109 +++++++----
 4 files changed, 301 insertions(+), 138 deletions(-)

diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsCheckpointMark.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsCheckpointMark.java
index f9f382dc3cd..e213561917d 100644
--- a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsCheckpointMark.java
+++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsCheckpointMark.java
@@ -19,11 +19,12 @@ package org.apache.beam.sdk.io.jms;
 
 import java.io.IOException;
 import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.locks.ReentrantReadWriteLock;
+import javax.jms.JMSException;
 import javax.jms.Message;
+import javax.jms.MessageConsumer;
+import javax.jms.Session;
 import org.apache.beam.sdk.io.UnboundedSource;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.checkerframework.checker.nullness.qual.Nullable;
@@ -39,78 +40,58 @@ class JmsCheckpointMark implements UnboundedSource.CheckpointMark, Serializable
 
   private static final Logger LOG = LoggerFactory.getLogger(JmsCheckpointMark.class);
 
-  private Instant oldestMessageTimestamp = Instant.now();
-  private transient List<Message> messages = new ArrayList<>();
+  private Instant oldestMessageTimestamp;
+  private transient @Nullable Message lastMessage;
+  private transient @Nullable MessageConsumer consumer;
+  private transient @Nullable Session session;
 
-  @VisibleForTesting transient boolean discarded = false;
-
-  @VisibleForTesting final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
-
-  JmsCheckpointMark() {}
+  private JmsCheckpointMark(
+      Instant oldestMessageTimestamp,
+      @Nullable Message lastMessage,
+      @Nullable MessageConsumer consumer,
+      @Nullable Session session) {
+    this.oldestMessageTimestamp = oldestMessageTimestamp;
+    this.lastMessage = lastMessage;
+    this.consumer = consumer;
+    this.session = session;
+  }
 
-  void add(Message message) throws Exception {
-    lock.writeLock().lock();
+  /** Acknowledges the messages read in this checkpoint, then closes its consumer and session. */
+  @Override
+  public void finalizeCheckpoint() {
     try {
-      if (discarded) {
-        throw new IllegalStateException(
-            String.format(
-                "Attempting to add message %s to checkpoint that is discarded.", message));
-      }
-      Instant currentMessageTimestamp = new Instant(message.getJMSTimestamp());
-      if (currentMessageTimestamp.isBefore(oldestMessageTimestamp)) {
-        oldestMessageTimestamp = currentMessageTimestamp;
+      // In CLIENT_ACKNOWLEDGE mode, acknowledging one message acknowledges _all_ messages the
+      // session has delivered so far, so acknowledging the last message is sufficient.
+      if (lastMessage != null) {
+        lastMessage.acknowledge();
       }
-      messages.add(message);
-    } finally {
-      lock.writeLock().unlock();
+    } catch (JMSException e) {
+      // The messages stay unacknowledged and will be redelivered (possibly as duplicates). This
+      // is not fatal, so only log it; failures to close below are likewise logged, not thrown.
+      LOG.error(
+          "Failed to acknowledge the message. Will redeliver and might cause duplication.", e);
     }
-  }
 
-  Instant getOldestMessageTimestamp() {
-    lock.readLock().lock();
-    try {
-      return this.oldestMessageTimestamp;
-    } finally {
-      lock.readLock().unlock();
+    // Close the consumer only after acknowledging; otherwise its unacknowledged messages may be
+    // redelivered to another consumer as duplicates.
+    if (consumer != null) {
+      try {
+        consumer.close();
+        consumer = null;
+      } catch (JMSException e) {
+        LOG.info("Error closing JMS consumer. It may have already been closed.");
+      }
     }
-  }
 
-  void discard() {
-    lock.writeLock().lock();
-    try {
-      this.discarded = true;
-    } finally {
-      lock.writeLock().unlock();
-    }
-  }
-
-  /**
-   * Acknowledge all outstanding message. Since we believe that messages will be delivered in
-   * timestamp order, and acknowledged messages will not be retried, the newest message in this
-   * batch is a good bound for future messages.
-   */
-  @Override
-  public void finalizeCheckpoint() {
-    lock.writeLock().lock();
-    try {
-      if (discarded) {
-        messages.clear();
-        return;
+    // Close the session last: acknowledging a message (above) requires its session to still be
+    // open.
+    if (session != null) {
+      try {
+        session.close();
+        session = null;
+      } catch (JMSException e) {
+        LOG.info("Error closing JMS session. It may have already been closed.");
       }
-      for (Message message : messages) {
-        try {
-          message.acknowledge();
-          Instant currentMessageTimestamp = new Instant(message.getJMSTimestamp());
-          if (currentMessageTimestamp.isAfter(oldestMessageTimestamp)) {
-            oldestMessageTimestamp = currentMessageTimestamp;
-          }
-        } catch (Exception e) {
-          LOG.error("Exception while finalizing message: ", e);
-        }
-      }
-      messages.clear();
-    } finally {
-      lock.writeLock().unlock();
     }
   }
 
@@ -118,8 +99,10 @@ class JmsCheckpointMark implements UnboundedSource.CheckpointMark, Serializable
   private void readObject(java.io.ObjectInputStream stream)
       throws IOException, ClassNotFoundException {
     stream.defaultReadObject();
-    messages = new ArrayList<>();
-    discarded = false;
+    // JMS handles are transient; a restored mark has nothing to acknowledge or close.
+    lastMessage = null;
+    consumer = null;
+    session = null;
   }
 
   @Override
@@ -138,4 +119,120 @@ class JmsCheckpointMark implements UnboundedSource.CheckpointMark, Serializable
   public int hashCode() {
     return Objects.hash(oldestMessageTimestamp);
   }
+
+  /** Returns a new, empty {@link Preparer} for accumulating the next checkpoint's messages. */
+  static Preparer newPreparer() {
+    return new Preparer();
+  }
+
+  /**
+   * Mutable accumulator for the messages read since the last checkpoint. {@link #newCheckpoint}
+   * snapshots it into an immutable {@link JmsCheckpointMark} and resets it. All fields are
+   * guarded by {@code lock}.
+   */
+  static class Preparer {
+    private Instant oldestMessageTimestamp = Instant.now();
+    private transient @Nullable Message lastMessage = null;
+
+    @VisibleForTesting transient boolean discarded = false;
+
+    @VisibleForTesting final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
+
+    private Preparer() {}
+
+    /**
+     * Records {@code message} as read, lowering the oldest-message timestamp if the message is older.
+     *
+     * @throws IllegalStateException if this preparer has been discarded
+     */
+    void add(Message message) throws Exception {
+      lock.writeLock().lock();
+      try {
+        if (discarded) {
+          throw new IllegalStateException(
+              String.format(
+                  "Attempting to add message %s to checkpoint that is discarded.", message));
+        }
+        Instant currentMessageTimestamp = new Instant(message.getJMSTimestamp());
+        if (currentMessageTimestamp.isBefore(oldestMessageTimestamp)) {
+          oldestMessageTimestamp = currentMessageTimestamp;
+        }
+        lastMessage = message;
+      } finally {
+        lock.writeLock().unlock();
+      }
+    }
+
+    /**
+     * Returns the earliest of: the time this preparer was created or last reset, and the JMS
+     * timestamps of all messages added since then.
+     */
+    Instant getOldestMessageTimestamp() {
+      lock.readLock().lock();
+      try {
+        return this.oldestMessageTimestamp;
+      } finally {
+        lock.readLock().unlock();
+      }
+    }
+
+    /** Marks this preparer as discarded: later adds fail and new checkpoints are empty. */
+    void discard() {
+      lock.writeLock().lock();
+      try {
+        this.discarded = true;
+      } finally {
+        lock.writeLock().unlock();
+      }
+    }
+
+    /**
+     * Snapshots the current state into a new checkpoint mark, which takes ownership of {@code
+     * consumer} and {@code session} and closes them when finalized, then resets this preparer.
+     *
+     * <p>The caller must read subsequent messages from a new JMS session: acknowledging the
+     * snapshot's last message acknowledges every message its session has delivered. If this
+     * preparer was discarded, an empty mark is returned and {@code consumer} and {@code session}
+     * are not attached to it.
+     */
+    JmsCheckpointMark newCheckpoint(@Nullable MessageConsumer consumer, @Nullable Session session) {
+      JmsCheckpointMark checkpointMark;
+      lock.writeLock().lock();
+      try {
+        if (discarded) {
+          // The write-lock holder may also take the read lock used by emptyCheckpoint().
+          checkpointMark = emptyCheckpoint();
+        } else {
+          checkpointMark =
+              new JmsCheckpointMark(oldestMessageTimestamp, lastMessage, consumer, session);
+          oldestMessageTimestamp = Instant.now();
+        }
+        // Messages already snapshotted (or discarded) must not carry over to the next checkpoint.
+        lastMessage = null;
+      } finally {
+        lock.writeLock().unlock();
+      }
+      return checkpointMark;
+    }
+
+    /** Returns a mark holding only the current timestamp; it acknowledges and closes nothing. */
+    JmsCheckpointMark emptyCheckpoint() {
+      lock.readLock().lock();
+      try {
+        return new JmsCheckpointMark(oldestMessageTimestamp, null, null, null);
+      } finally {
+        lock.readLock().unlock();
+      }
+    }
+
+    /** Returns whether no message has been added since creation or the last snapshot. */
+    boolean isEmpty() {
+      lock.readLock().lock();
+      try {
+        return lastMessage == null;
+      } finally {
+        lock.readLock().unlock();
+      }
+    }
+  }
 }
diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java
index 2344dec449a..7ffea24f913 100644
--- a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java
+++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java
@@ -66,6 +66,7 @@ import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TupleTagList;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects;
 import org.checkerframework.checker.initialization.qual.Initialized;
 import org.checkerframework.checker.nullness.qual.Nullable;
@@ -494,7 +495,7 @@ public class JmsIO {
   static class UnboundedJmsReader<T> extends UnboundedReader<T> {
 
     private UnboundedJmsSource<T> source;
-    private JmsCheckpointMark checkpointMark;
+    @VisibleForTesting JmsCheckpointMark.Preparer checkpointMarkPreparer;
     private Connection connection;
     private Session session;
     private MessageConsumer consumer;
@@ -506,11 +507,54 @@ public class JmsIO {
 
     public UnboundedJmsReader(UnboundedJmsSource<T> source, PipelineOptions options) {
       this.source = source;
-      this.checkpointMark = new JmsCheckpointMark();
+      this.checkpointMarkPreparer = JmsCheckpointMark.newPreparer();
       this.currentMessage = null;
       this.options = options;
     }
 
+    /**
+     * Opens a new {@code CLIENT_ACKNOWLEDGE} session and a consumer on it for the configured
+     * topic or queue, and installs both on this reader. The previous session and consumer are not
+     * closed here; callers that still need them (e.g. for a pending checkpoint) must keep their
+     * own references.
+     *
+     * <p>If the consumer cannot be created, the new session is closed and the reader keeps its
+     * previous session and consumer.
+     *
+     * @throws IOException if the session or the consumer cannot be created
+     */
+    private synchronized void recreateSession() throws IOException {
+      Session newSession;
+      try {
+        newSession = this.connection.createSession(false, Session.CLIENT_ACKNOWLEDGE);
+      } catch (Exception e) {
+        throw new IOException("Error creating JMS session", e);
+      }
+
+      Read<T> spec = source.spec;
+
+      MessageConsumer newConsumer;
+      try {
+        if (spec.getTopic() != null) {
+          newConsumer = newSession.createConsumer(newSession.createTopic(spec.getTopic()));
+        } else {
+          newConsumer = newSession.createConsumer(newSession.createQueue(spec.getQueue()));
+        }
+      } catch (Exception e) {
+        // Best effort: don't leak the session we just opened.
+        try {
+          newSession.close();
+        } catch (Exception closeError) {
+          e.addSuppressed(closeError);
+        }
+        throw new IOException("Error creating JMS consumer", e);
+      }
+
+      // Install both together so the reader's consumer always belongs to the reader's session.
+      this.session = newSession;
+      this.consumer = newConsumer;
+    }
+
     @Override
     public boolean start() throws IOException {
       Read<T> spec = source.spec;
@@ -534,21 +556,7 @@ public class JmsIO {
         throw new IOException("Error connecting to JMS", e);
       }
 
-      try {
-        this.session = this.connection.createSession(false, Session.CLIENT_ACKNOWLEDGE);
-      } catch (Exception e) {
-        throw new IOException("Error creating JMS session", e);
-      }
-
-      try {
-        if (spec.getTopic() != null) {
-          this.consumer = this.session.createConsumer(this.session.createTopic(spec.getTopic()));
-        } else {
-          this.consumer = this.session.createConsumer(this.session.createQueue(spec.getQueue()));
-        }
-      } catch (Exception e) {
-        throw new IOException("Error creating JMS consumer", e);
-      }
+      recreateSession();
 
       return advance();
     }
@@ -556,15 +564,19 @@ public class JmsIO {
     @Override
     public boolean advance() throws IOException {
       try {
-        Message message = this.consumer.receiveNoWait();
-
+        Message message;
+        synchronized (this) {
+          message = this.consumer.receiveNoWait();
+          // put add in synchronized to make sure all messages in preparer are in same session
+          if (message != null) {
+            checkpointMarkPreparer.add(message);
+          }
+        }
         if (message == null) {
           currentMessage = null;
           return false;
         }
 
-        checkpointMark.add(message);
-
         currentMessage = this.source.spec.getMessageMapper().mapMessage(message);
         currentTimestamp = new Instant(message.getJMSTimestamp());
 
@@ -584,7 +596,7 @@ public class JmsIO {
 
     @Override
     public Instant getWatermark() {
-      return checkpointMark.getOldestMessageTimestamp();
+      return checkpointMarkPreparer.getOldestMessageTimestamp();
     }
 
     @Override
@@ -597,7 +609,23 @@ public class JmsIO {
 
     @Override
     public CheckpointMark getCheckpointMark() {
-      return checkpointMark;
+      // Hold this reader's monitor across the whole swap: advance() receives and records each
+      // message under the same monitor, so no message from the new session can end up in the
+      // checkpoint that acknowledges and closes the old session.
+      synchronized (this) {
+        if (checkpointMarkPreparer.isEmpty()) {
+          return checkpointMarkPreparer.emptyCheckpoint();
+        }
+
+        MessageConsumer consumerToClose = consumer;
+        Session sessionToFinalize = session;
+        try {
+          recreateSession();
+        } catch (IOException e) {
+          throw new RuntimeException(e);
+        }
+        return checkpointMarkPreparer.newCheckpoint(consumerToClose, sessionToFinalize);
+      }
     }
 
     @Override
@@ -617,7 +644,6 @@ public class JmsIO {
 
     @SuppressWarnings("FutureReturnValueIgnored")
     private void doClose() {
-
       try {
         closeAutoscaler();
         closeConsumer();
@@ -625,16 +651,14 @@ public class JmsIO {
             options.as(ExecutorOptions.class).getScheduledExecutorService();
         executorService.schedule(
             () -> {
-              LOG.debug(
-                  "Closing session and connection after delay {}", source.spec.getCloseTimeout());
+              LOG.debug("Closing connection after delay {}", source.spec.getCloseTimeout());
               // Discard the checkpoints and set the reader as inactive
-              checkpointMark.discard();
+              checkpointMarkPreparer.discard();
               closeSession();
               closeConnection();
             },
             source.spec.getCloseTimeout().getMillis(),
             TimeUnit.MILLISECONDS);
-
       } catch (Exception e) {
         LOG.error("Error closing reader", e);
       }
@@ -652,7 +676,7 @@ public class JmsIO {
       }
     }
 
-    private void closeSession() {
+    private synchronized void closeSession() {
       try {
         if (session != null) {
           session.close();
@@ -663,7 +687,7 @@ public class JmsIO {
       }
     }
 
-    private void closeConsumer() {
+    private synchronized void closeConsumer() {
       try {
         if (consumer != null) {
           consumer.close();
diff --git a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java
index fb6a08384f2..6f5a00760e3 100644
--- a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java
+++ b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java
@@ -27,12 +27,17 @@ import java.io.IOException;
 import java.io.Serializable;
 import java.time.Instant;
 import java.util.Collection;
-import java.util.Collections;
+import java.util.Enumeration;
 import java.util.HashSet;
 import java.util.Set;
 import java.util.UUID;
 import java.util.function.Function;
+import javax.jms.Connection;
 import javax.jms.ConnectionFactory;
+import javax.jms.JMSException;
+import javax.jms.Message;
+import javax.jms.QueueBrowser;
+import javax.jms.Session;
 import javax.jms.TextMessage;
 import org.apache.activemq.ActiveMQConnectionFactory;
 import org.apache.activemq.command.ActiveMQTextMessage;
@@ -54,6 +59,7 @@ import org.apache.beam.sdk.testutils.metrics.TimeMonitor;
 import org.apache.beam.sdk.testutils.publishing.InfluxDBSettings;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import org.apache.qpid.jms.JmsConnectionFactory;
 import org.joda.time.Duration;
 import org.junit.After;
@@ -84,7 +90,6 @@ import org.junit.runners.Parameterized;
  */
 @RunWith(Parameterized.class)
 public class JmsIOIT implements Serializable {
-
   private static final String NAMESPACE = JmsIOIT.class.getName();
   private static final String READ_TIME_METRIC = "read_time";
   private static final String WRITE_TIME_METRIC = "write_time";
@@ -132,12 +137,12 @@ public class JmsIOIT implements Serializable {
 
   @Parameterized.Parameters(name = "with client class {3}")
   public static Collection<Object[]> connectionFactories() {
-    return Collections.singletonList(
+    return ImmutableList.of(
         new Object[] {
           "vm://localhost", 5672, "jms.sendAcksAsync=false", ActiveMQConnectionFactory.class
         });
     // TODO(https://github.com/apache/beam/issues/26175) Test failure on direct runner due to
-    //  JmsIO read on amqp slow on Jenkins.
+    //  JmsIO read on amqp slow on CI (passed locally)
     // new Object[] {
     //   "amqp://localhost", 5672, "jms.forceAsyncAcks=false", JmsConnectionFactory.class
     // });
@@ -181,7 +186,7 @@ public class JmsIOIT implements Serializable {
   }
 
   @Test
-  public void testPublishingThenReadingAll() throws IOException {
+  public void testPublishingThenReadingAll() throws IOException, JMSException {
     PipelineResult writeResult = publishingMessages();
     PipelineResult.State writeState = writeResult.waitUntilFinish();
     assertNotEquals(PipelineResult.State.FAILED, writeState);
@@ -196,11 +201,21 @@ public class JmsIOIT implements Serializable {
     MetricsReader metricsReader = new MetricsReader(readResult, NAMESPACE);
     long actualRecords = metricsReader.getCounterMetric(READ_ELEMENT_METRIC_NAME);
 
+    // TODO(yathu) resolve pending messages with direct runner then we can simply assert
+    //   actual-records == total-records.
+    //   Due to direct runner only finalize checkpoint at very end, there are open consumers (may
+    //   with buffer) and O(open_consumer) message won't get delivered to other session.
+    int unackRecords = countRemain(QUEUE);
+    assertTrue(
+        String.format("Too many unacknowledged messages: %d", unackRecords),
+        unackRecords < OPTIONS.getNumberOfRecords() * 0.002);
+
+    // acknowledged records
+    int ackRecords = OPTIONS.getNumberOfRecords() - unackRecords;
     assertTrue(
         String.format(
-            "actual number of records %d smaller than expected: %d.",
-            actualRecords, OPTIONS.getNumberOfRecords()),
-        OPTIONS.getNumberOfRecords() <= actualRecords);
+            "actual number of records %d smaller than expected: %d.", actualRecords, ackRecords),
+        ackRecords <= actualRecords);
     collectAndPublishMetrics(writeResult, readResult);
   }
 
@@ -279,6 +294,26 @@ public class JmsIOIT implements Serializable {
     };
   }
 
+  /** Counts the messages still on {@code queue} by browsing it, i.e. without consuming them. */
+  private int countRemain(String queue) throws JMSException {
+    Connection connection = connectionFactory.createConnection(USERNAME, PASSWORD);
+    try {
+      connection.start();
+      Session session = connection.createSession(false, Session.AUTO_ACKNOWLEDGE);
+      QueueBrowser browser = session.createBrowser(session.createQueue(queue));
+      Enumeration<Message> messages = browser.getEnumeration();
+      int count = 0;
+      while (messages.hasMoreElements()) {
+        messages.nextElement();
+        count++;
+      }
+      return count;
+    } finally {
+      // Closing the connection also closes the session and browser created from it.
+      connection.close();
+    }
+  }
+
   static class ToString extends DoFn<Long, String> {
     @ProcessElement
     public void processElement(@Element Long element, OutputReceiver<String> outputReceiver) {
diff --git a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java
index 8eba8c494f0..d82873b37cf 100644
--- a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java
+++ b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java
@@ -78,6 +78,8 @@ import org.apache.activemq.util.Callback;
 import org.apache.beam.sdk.PipelineResult;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.SerializableCoder;
+import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark;
+import org.apache.beam.sdk.io.jms.JmsIO.UnboundedJmsReader;
 import org.apache.beam.sdk.metrics.MetricNameFilter;
 import org.apache.beam.sdk.metrics.MetricQueryResults;
 import org.apache.beam.sdk.metrics.MetricsFilter;
@@ -420,33 +422,14 @@ public class JmsIOTest {
     // that the consumer will poll for message, which is exactly what we want for the test.
     // We are also sending message acknowledgements synchronously to ensure that they are
     // processed before any subsequent assertions.
-    Connection connection =
-        connectionFactoryWithSyncAcksAndWithoutPrefetch.createConnection(USERNAME, PASSWORD);
-    connection.start();
-    Session session = connection.createSession(false, Session.AUTO_ACKNOWLEDGE);
-    MessageProducer producer = session.createProducer(session.createQueue(QUEUE));
-    for (int i = 0; i < 10; i++) {
-      producer.send(session.createTextMessage("test " + i));
-    }
-    producer.close();
-    session.close();
-    connection.close();
-
-    JmsIO.Read spec =
-        JmsIO.read()
-            .withConnectionFactory(connectionFactoryWithSyncAcksAndWithoutPrefetch)
-            .withUsername(USERNAME)
-            .withPassword(PASSWORD)
-            .withQueue(QUEUE);
-    JmsIO.UnboundedJmsSource source = new JmsIO.UnboundedJmsSource(spec);
-    JmsIO.UnboundedJmsReader reader = source.createReader(null, null);
+    UnboundedJmsReader reader = setupReaderForTest();
 
     // start the reader and move to the first record
     assertTrue(reader.start());
 
     // consume 3 messages (NB: start already consumed the first message)
     for (int i = 0; i < 3; i++) {
-      assertTrue(reader.advance());
+      assertTrue(String.format("Failed at %d-th message", i), reader.advance());
     }
 
     // the messages are still pending in the queue (no ACK yet)
@@ -460,7 +443,7 @@ public class JmsIOTest {
 
     // we read the 6 pending messages
     for (int i = 0; i < 6; i++) {
-      assertTrue(reader.advance());
+      assertTrue(String.format("Failed at %d-th message", i), reader.advance());
     }
 
     // still 6 pending messages as we didn't finalize the checkpoint
@@ -472,6 +455,66 @@ public class JmsIOTest {
     assertEquals(0, count(QUEUE));
   }
 
+  @Test
+  public void testCheckpointMarkAndFinalizeSeparately() throws Exception {
+    UnboundedJmsReader reader = setupReaderForTest();
+
+    // start the reader and move to the first record
+    assertTrue(reader.start());
+
+    // consume 2 more messages (NB: start already consumed the first message)
+    assertTrue(reader.advance());
+    assertTrue(reader.advance());
+
+    // take a checkpoint mark after consuming 3 messages; this switches to a new session
+    CheckpointMark mark = reader.getCheckpointMark();
+
+    // consume two more messages after the checkpoint is taken (not part of the checkpoint)
+    reader.advance();
+    reader.advance();
+
+    // the messages are still pending in the queue (no ACK yet)
+    assertEquals(10, count(QUEUE));
+
+    // we finalize the checkpoint
+    mark.finalizeCheckpoint();
+
+    // finalizing acks only the 3 messages read before the checkpoint; later reads stay pending
+    assertEquals(7, count(QUEUE));
+  }
+
+  private JmsIO.UnboundedJmsReader setupReaderForTest() throws JMSException {
+    // Publishes 10 messages to QUEUE and returns an unstarted reader. We use no prefetch here:
+    // prefetch is an ActiveMQ feature: to make efficient use of network resources the broker
+    // utilizes a 'push' model to dispatch messages to consumers. However, in the case of our
+    // test, it means that we can have some latency between the receiveNoWait() method used by
+    // the consumer and the prefetch buffer populated by the broker. Setting the prefetch to 0 means
+    // that the consumer will poll for messages, which is exactly what we want for the test.
+    // We are also sending message acknowledgements synchronously to ensure that they are
+    // processed before any subsequent assertions.
+    Connection connection =
+        connectionFactoryWithSyncAcksAndWithoutPrefetch.createConnection(USERNAME, PASSWORD);
+    connection.start();
+    Session session = connection.createSession(false, Session.AUTO_ACKNOWLEDGE);
+    MessageProducer producer = session.createProducer(session.createQueue(QUEUE));
+    for (int i = 0; i < 10; i++) {
+      producer.send(session.createTextMessage("test " + i));
+    }
+    producer.close();
+    session.close();
+    connection.close();
+
+    JmsIO.Read spec =
+        JmsIO.read()
+            .withConnectionFactory(connectionFactoryWithSyncAcksAndWithoutPrefetch)
+            .withUsername(USERNAME)
+            .withPassword(PASSWORD)
+            .withQueue(QUEUE);
+    JmsIO.UnboundedJmsSource source = new JmsIO.UnboundedJmsSource(spec);
+    JmsIO.UnboundedJmsReader reader = source.createReader(PipelineOptionsFactory.create(), null);
+    return reader;
+  }
+
   private Function<?, ?> getJmsMessageAck(Class connectorClass) {
     final int delay = 10;
     return connectorClass == JmsConnectionFactory.class
@@ -545,7 +588,7 @@ public class JmsIOTest {
             .withPassword(PASSWORD)
             .withQueue(QUEUE);
     JmsIO.UnboundedJmsSource source = new JmsIO.UnboundedJmsSource(spec);
-    JmsIO.UnboundedJmsReader reader = source.createReader(null, null);
+    JmsIO.UnboundedJmsReader reader = source.createReader(PipelineOptionsFactory.create(), null);
 
     // start the reader and move to the first record
     assertTrue(reader.start());
@@ -583,7 +626,7 @@ public class JmsIOTest {
   /** Test the checkpoint mark default coder, which is actually AvroCoder. */
   @Test
   public void testCheckpointMarkDefaultCoder() throws Exception {
-    JmsCheckpointMark jmsCheckpointMark = new JmsCheckpointMark();
+    JmsCheckpointMark jmsCheckpointMark = JmsCheckpointMark.newPreparer().newCheckpoint(null, null);
     Coder coder = new JmsIO.UnboundedJmsSource(null).getCheckpointMarkCoder();
     CoderProperties.coderSerializable(coder);
     CoderProperties.coderDecodeEncodeEqual(coder, jmsCheckpointMark);
@@ -598,7 +641,7 @@ public class JmsIOTest {
             .withPassword(PASSWORD)
             .withQueue(QUEUE);
     JmsIO.UnboundedJmsSource source = new JmsIO.UnboundedJmsSource(spec);
-    JmsIO.UnboundedJmsReader reader = source.createReader(null, null);
+    JmsIO.UnboundedJmsReader reader = source.createReader(PipelineOptionsFactory.create(), null);
 
     // start the reader and check getSplitBacklogBytes and getTotalBacklogBytes values
     reader.start();
@@ -622,7 +665,7 @@ public class JmsIOTest {
             .withAutoScaler(autoScaler);
 
     JmsIO.UnboundedJmsSource source = new JmsIO.UnboundedJmsSource(spec);
-    JmsIO.UnboundedJmsReader reader = source.createReader(null, null);
+    JmsIO.UnboundedJmsReader reader = source.createReader(PipelineOptionsFactory.create(), null);
 
     // start the reader and check getSplitBacklogBytes and getTotalBacklogBytes values
     reader.start();
@@ -668,12 +711,12 @@ public class JmsIOTest {
   }
 
   private boolean getDiscardedValue(JmsIO.UnboundedJmsReader reader) {
-    JmsCheckpointMark checkpoint = (JmsCheckpointMark) reader.getCheckpointMark();
-    checkpoint.lock.readLock().lock();
+    JmsCheckpointMark.Preparer preparer = reader.checkpointMarkPreparer;
+    preparer.lock.readLock().lock();
     try {
-      return checkpoint.discarded;
+      return preparer.discarded;
     } finally {
-      checkpoint.lock.readLock().unlock();
+      preparer.lock.readLock().unlock();
     }
   }
 
@@ -698,7 +741,7 @@ public class JmsIOTest {
             .withPassword(PASSWORD)
             .withQueue(QUEUE);
     JmsIO.UnboundedJmsSource source = new JmsIO.UnboundedJmsSource(spec);
-    JmsIO.UnboundedJmsReader reader = source.createReader(null, null);
+    JmsIO.UnboundedJmsReader reader = source.createReader(PipelineOptionsFactory.create(), null);
 
     // start the reader and move to the first record
     assertTrue(reader.start());
@@ -725,8 +768,8 @@ public class JmsIOTest {
     // still 6 pending messages as we didn't finalize the checkpoint
     assertEquals(6, count(QUEUE));
 
-    // But here we discard the checkpoint
-    ((JmsCheckpointMark) reader.getCheckpointMark()).discard();
+    // But here we discard the pending checkpoint
+    reader.checkpointMarkPreparer.discard();
     // we finalize the checkpoint: no messages should be acked
     reader.getCheckpointMark().finalizeCheckpoint();