You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by pa...@apache.org on 2023/02/14 22:16:53 UTC

[beam] branch bigtable-cdc-feature-branch updated: Stream changes and handle Heartbeat responses (#25458)

This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch bigtable-cdc-feature-branch
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/bigtable-cdc-feature-branch by this push:
     new 37a18966f5f Stream changes and handle Heartbeat responses (#25458)
37a18966f5f is described below

commit 37a18966f5f5e052d9168f5f7d7e8075d86131ec
Author: Tony Tang <nf...@gmail.com>
AuthorDate: Tue Feb 14 17:16:44 2023 -0500

    Stream changes and handle Heartbeat responses (#25458)
---
 .../changestreams/ChangeStreamMetrics.java         | 20 +++++
 .../changestreams/action/ChangeStreamAction.java   | 37 ++++++++
 .../action/ReadChangeStreamPartitionAction.java    | 51 ++++++++++-
 .../changestreams/dao/ChangeStreamDao.java         | 60 +++++++++++++
 .../changestreams/dao/MetadataTableDao.java        | 54 ++++++++++++
 .../encoder/MetadataTableEncoder.java              | 62 ++++++++++++++
 .../package-info.java}                             | 22 ++---
 .../changestreams/model/PartitionRecord.java       | 53 ++++++++++--
 .../changestreams/restriction/StreamProgress.java  | 50 +++++++++++
 .../action/ChangeStreamActionTest.java             | 89 ++++++++++++++++++++
 .../changestreams/dao/MetadataTableDaoTest.java    | 96 +++++++++++++++++++++
 .../changestreams/dofn/MetadataTableDaoTest.java   | 98 ++++++++++++++++++++++
 ...adChangeStreamPartitionProgressTrackerTest.java | 80 ++++++++++++++++++
 13 files changed, 747 insertions(+), 25 deletions(-)

diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ChangeStreamMetrics.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ChangeStreamMetrics.java
index f8177adf873..2c534020d39 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ChangeStreamMetrics.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ChangeStreamMetrics.java
@@ -35,6 +35,17 @@ public class ChangeStreamMetrics implements Serializable {
           org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics.class,
           "list_partitions_count");
 
+  // -------------------
+  // Read change stream metrics
+
+  /**
+   * Counter for the total number of heartbeats identified during the execution of the Connector.
+   */
+  public static final Counter HEARTBEAT_COUNT =
+      Metrics.counter(
+          org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics.class,
+          "heartbeat_count");
+
   /**
    * Increments the {@link
    * org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics#LIST_PARTITIONS_COUNT} by
@@ -44,6 +55,15 @@ public class ChangeStreamMetrics implements Serializable {
     inc(LIST_PARTITIONS_COUNT);
   }
 
+  /**
+   * Increments the {@link
+   * org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics#HEARTBEAT_COUNT} by 1 if
+   * the metric is enabled.
+   */
+  public void incHeartbeatCount() {
+    inc(HEARTBEAT_COUNT);
+  }
+
   private void inc(Counter counter) {
     counter.inc();
   }
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamAction.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamAction.java
index 012173a89cd..8a2cbd6ed84 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamAction.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamAction.java
@@ -17,11 +17,15 @@
  */
 package org.apache.beam.sdk.io.gcp.bigtable.changestreams.action;
 
+import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeHelper.formatByteStringRange;
+
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamMutation;
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamRecord;
+import com.google.cloud.bigtable.data.v2.models.Heartbeat;
 import com.google.protobuf.ByteString;
 import java.util.Optional;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.TimestampConverter;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.model.PartitionRecord;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.restriction.StreamProgress;
 import org.apache.beam.sdk.transforms.DoFn;
@@ -100,6 +104,39 @@ public class ChangeStreamAction {
       DoFn.OutputReceiver<KV<ByteString, ChangeStreamMutation>> receiver,
       ManualWatermarkEstimator<Instant> watermarkEstimator,
       boolean shouldDebug) {
+    if (record instanceof Heartbeat) {
+      Heartbeat heartbeat = (Heartbeat) record;
+      StreamProgress streamProgress =
+          new StreamProgress(
+              heartbeat.getChangeStreamContinuationToken(), heartbeat.getLowWatermark());
+      final Instant watermark = TimestampConverter.toInstant(heartbeat.getLowWatermark());
+      watermarkEstimator.setWatermark(watermark);
+
+      if (shouldDebug) {
+        LOG.info(
+            "RCSP {}: Heartbeat partition: {} token: {} watermark: {}",
+            formatByteStringRange(partitionRecord.getPartition()),
+            formatByteStringRange(heartbeat.getChangeStreamContinuationToken().getPartition()),
+            heartbeat.getChangeStreamContinuationToken().getToken(),
+            heartbeat.getLowWatermark());
+      }
+      // If the tracker fails to claim the streamProgress, it most likely means the runner initiated
+      // a checkpoint. See {@link
+      // org.apache.beam.sdk.io.gcp.bigtable.changestreams.restriction.ReadChangeStreamPartitionProgressTracker}
+      // for more information regarding runner initiated checkpoints.
+      if (!tracker.tryClaim(streamProgress)) {
+        if (shouldDebug) {
+          LOG.info(
+              "RCSP {}: Failed to claim heart beat tracker",
+              formatByteStringRange(partitionRecord.getPartition()));
+        }
+        return Optional.of(DoFn.ProcessContinuation.stop());
+      }
+      metrics.incHeartbeatCount();
+    } else {
+      LOG.warn(
+          "RCSP {}: Invalid response type", formatByteStringRange(partitionRecord.getPartition()));
+    }
     return Optional.empty();
   }
 }
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ReadChangeStreamPartitionAction.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ReadChangeStreamPartitionAction.java
index d55838e5b85..321a3c1a1ca 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ReadChangeStreamPartitionAction.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ReadChangeStreamPartitionAction.java
@@ -17,9 +17,12 @@
  */
 package org.apache.beam.sdk.io.gcp.bigtable.changestreams.action;
 
+import com.google.api.gax.rpc.ServerStream;
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamMutation;
+import com.google.cloud.bigtable.data.v2.models.ChangeStreamRecord;
 import com.google.protobuf.ByteString;
 import java.io.IOException;
+import java.util.Optional;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.ChangeStreamDao;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.MetadataTableDao;
@@ -117,7 +120,53 @@ public class ReadChangeStreamPartitionAction {
       DoFn.OutputReceiver<KV<ByteString, ChangeStreamMutation>> receiver,
       ManualWatermarkEstimator<Instant> watermarkEstimator)
       throws IOException {
+    // Watermark being delayed beyond 5 minutes signals a possible problem.
+    boolean shouldDebug =
+        watermarkEstimator.getState().plus(Duration.standardMinutes(5)).isBeforeNow();
 
-    return ProcessContinuation.stop();
+    if (shouldDebug) {
+      LOG.info(
+          "RCSP: Partition: "
+              + partitionRecord
+              + "\n Watermark: "
+              + watermarkEstimator.getState()
+              + "\n RestrictionTracker: "
+              + tracker.currentRestriction());
+    }
+
+    // Update the metadata table with the watermark
+    metadataTableDao.updateWatermark(
+        partitionRecord.getPartition(),
+        watermarkEstimator.getState(),
+        tracker.currentRestriction().getCurrentToken());
+
+    // Start to stream the partition.
+    ServerStream<ChangeStreamRecord> stream = null;
+    try {
+      stream =
+          changeStreamDao.readChangeStreamPartition(
+              partitionRecord,
+              tracker.currentRestriction(),
+              partitionRecord.getEndTime(),
+              heartbeatDurationSeconds,
+              shouldDebug);
+      for (ChangeStreamRecord record : stream) {
+        Optional<ProcessContinuation> result =
+            changeStreamAction.run(
+                partitionRecord, record, tracker, receiver, watermarkEstimator, shouldDebug);
+        // changeStreamAction will usually return Optional.empty() except for when a checkpoint
+        // (either runner or pipeline initiated) is required.
+        if (result.isPresent()) {
+          return result.get();
+        }
+      }
+    } catch (Exception e) {
+      throw e;
+    } finally {
+      if (stream != null) {
+        stream.cancel();
+      }
+    }
+    return ProcessContinuation.resume();
   }
 }
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/ChangeStreamDao.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/ChangeStreamDao.java
index c890193a168..98902866cb1 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/ChangeStreamDao.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/ChangeStreamDao.java
@@ -17,9 +17,22 @@
  */
 package org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao;
 
+import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeHelper.formatByteStringRange;
+
+import com.google.api.gax.rpc.ServerStream;
+import com.google.cloud.Timestamp;
 import com.google.cloud.bigtable.data.v2.BigtableDataClient;
+import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
+import com.google.cloud.bigtable.data.v2.models.ChangeStreamRecord;
 import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange;
+import com.google.cloud.bigtable.data.v2.models.ReadChangeStreamQuery;
+import java.io.IOException;
+import java.util.Collections;
 import java.util.List;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.model.PartitionRecord;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.restriction.StreamProgress;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.Duration;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -44,4 +57,51 @@ public class ChangeStreamDao {
   public List<ByteStringRange> generateInitialChangeStreamPartitions() {
     return dataClient.generateInitialChangeStreamPartitionsCallable().all().call(tableId);
   }
+
+  /**
+   * Streams a partition. The request is logged when {@code shouldDebug} is true.
+   *
+   * @param partition the partition to stream
+   * @param streamProgress may contain a continuation token for the stream request
+   * @param endTime time to end the stream, may be null
+   * @param heartbeatDurationSeconds period between heartbeat messages
+   * @return stream of {@link ChangeStreamRecord}s read from the partition
+   * @throws IOException if no continuation token or start time is available to start the stream
+   */
+  public ServerStream<ChangeStreamRecord> readChangeStreamPartition(
+      PartitionRecord partition,
+      StreamProgress streamProgress,
+      @Nullable Timestamp endTime,
+      Duration heartbeatDurationSeconds,
+      boolean shouldDebug)
+      throws IOException {
+    ReadChangeStreamQuery query =
+        ReadChangeStreamQuery.create(tableId).streamPartition(partition.getPartition());
+
+    ChangeStreamContinuationToken currentToken = streamProgress.getCurrentToken();
+    Timestamp startTime = partition.getStartTime();
+    List<ChangeStreamContinuationToken> changeStreamContinuationTokenList =
+        partition.getChangeStreamContinuationTokens();
+    if (currentToken != null) {
+      query.continuationTokens(Collections.singletonList(currentToken));
+    } else if (startTime != null) {
+      // The tracker has no continuation token yet, so start from the partition's start time.
+      query.startTime(startTime.toProto());
+    } else if (changeStreamContinuationTokenList != null) {
+      query.continuationTokens(changeStreamContinuationTokenList);
+    } else {
+      throw new IOException("Partition has no continuation token or start time to stream from");
+    }
+    if (endTime != null) {
+      query.endTime(endTime.toProto());
+    }
+    query.heartbeatDuration(heartbeatDurationSeconds.getStandardSeconds());
+    if (shouldDebug) {
+      LOG.info(
+          "RCSP {} ReadChangeStreamRequest: {}",
+          formatByteStringRange(partition.getPartition()),
+          query);
+    }
+    return dataClient.readChangeStream(query);
+  }
 }
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDao.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDao.java
index c4889018c89..3c3e828c5bb 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDao.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDao.java
@@ -22,8 +22,12 @@ import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.MetadataTabl
 import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.MetadataTableAdminDao.STREAM_PARTITION_PREFIX;
 
 import com.google.cloud.bigtable.data.v2.BigtableDataClient;
+import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
+import com.google.cloud.bigtable.data.v2.models.Range;
 import com.google.cloud.bigtable.data.v2.models.RowMutation;
 import com.google.protobuf.ByteString;
+import javax.annotation.Nullable;
+import org.joda.time.Instant;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -80,6 +84,56 @@ public class MetadataTableDao {
     return changeStreamNamePrefix.concat(DETECT_NEW_PARTITION_SUFFIX);
   }
 
+  /**
+   * Convert partition to a Stream Partition row key to query for metadata of partitions that are
+   * currently being streamed.
+   *
+   * @param partition convert to row key
+   * @return row key to insert to Cloud Bigtable.
+   */
+  public ByteString convertPartitionToStreamPartitionRowKey(Range.ByteStringRange partition) {
+    return getFullStreamPartitionPrefix().concat(Range.ByteStringRange.toByteString(partition));
+  }
+
+  /**
+   * Update the metadata for the rowKey. The row key is used as is; callers add any prefixes.
+   *
+   * @param rowKey row key of the row to update
+   * @param watermark watermark value to set for the cell
+   * @param currentToken continuation token to set for the cell
+   */
+  private void writeToMdTableWatermarkHelper(
+      ByteString rowKey, Instant watermark, @Nullable ChangeStreamContinuationToken currentToken) {
+    RowMutation rowMutation =
+        RowMutation.create(tableId, rowKey)
+            .setCell(
+                MetadataTableAdminDao.CF_WATERMARK,
+                MetadataTableAdminDao.QUALIFIER_DEFAULT,
+                watermark.getMillis());
+    if (currentToken != null) {
+      rowMutation.setCell(
+          MetadataTableAdminDao.CF_CONTINUATION_TOKEN,
+          MetadataTableAdminDao.QUALIFIER_DEFAULT,
+          currentToken.getToken());
+    }
+    dataClient.mutateRow(rowMutation);
+  }
+
+  /**
+   * Update the metadata for the row key represented by the partition.
+   *
+   * @param partition forms the row key of the row to update
+   * @param watermark watermark value to set for the cell
+   * @param currentToken continuation token to set for the cell
+   */
+  public void updateWatermark(
+      Range.ByteStringRange partition,
+      Instant watermark,
+      @Nullable ChangeStreamContinuationToken currentToken) {
+    writeToMdTableWatermarkHelper(
+        convertPartitionToStreamPartitionRowKey(partition), watermark, currentToken);
+  }
+
   /**
    * Set the version number for DetectNewPartition. This value can be checked later to verify that
    * the existing metadata table is compatible with current beam connector code.
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/encoder/MetadataTableEncoder.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/encoder/MetadataTableEncoder.java
new file mode 100644
index 00000000000..cbd2895efe1
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/encoder/MetadataTableEncoder.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigtable.changestreams.encoder;
+
+import com.google.cloud.bigtable.data.v2.models.Row;
+import com.google.cloud.bigtable.data.v2.models.RowCell;
+import java.util.List;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.MetadataTableAdminDao;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.Longs;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.Instant;
+
+/** Helper methods that simplify conversion and extraction of metadata table content. */
+@Internal
+public class MetadataTableEncoder {
+  /**
+   * Read the watermark cell of a row from ReadRows.
+   *
+   * @param row row to extract the watermark from
+   * @return the watermark of the row
+   */
+  public static @Nullable Instant parseWatermarkFromRow(Row row) {
+    List<RowCell> cells =
+        row.getCells(MetadataTableAdminDao.CF_WATERMARK, MetadataTableAdminDao.QUALIFIER_DEFAULT);
+    if (cells.size() == 0) {
+      return null;
+    }
+    return Instant.ofEpochMilli(Longs.fromByteArray(cells.get(0).getValue().toByteArray()));
+  }
+
+  /**
+   * Read the continuation token cell of a row from ReadRows.
+   *
+   * @param row to extract the token from
+   * @return the token of the row
+   */
+  public static @Nullable String getTokenFromRow(Row row) {
+    List<RowCell> cells =
+        row.getCells(
+            MetadataTableAdminDao.CF_CONTINUATION_TOKEN, MetadataTableAdminDao.QUALIFIER_DEFAULT);
+    if (cells.size() == 0) {
+      return null;
+    }
+    return cells.get(0).getValue().toStringUtf8();
+  }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/StreamProgress.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/encoder/package-info.java
similarity index 52%
copy from sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/StreamProgress.java
copy to sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/encoder/package-info.java
index bcf032b0aa4..29067faf29b 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/StreamProgress.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/encoder/package-info.java
@@ -15,22 +15,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.beam.sdk.io.gcp.bigtable.changestreams.restriction;
-
-import java.io.Serializable;
-
 /**
- * Position for {@link ReadChangeStreamPartitionProgressTracker}. This represents contains
- * information that allows a stream, along with the {@link
- * org.apache.beam.sdk.io.gcp.bigtable.changestreams.model.PartitionRecord} to resume from a
- * checkpoint.
- *
- * <p>It should contain either a continuation token which represents a position in the stream, or it
- * can contain a close stream message which represents an end to the stream and the DoFn needs to
- * stop.
+ * Encoders for writing and reading from Metadata Table for Google Cloud Bigtable Change Streams.
  */
-public class StreamProgress implements Serializable {
-  private static final long serialVersionUID = -5384329262726188695L;
+@Internal
+@Experimental
+package org.apache.beam.sdk.io.gcp.bigtable.changestreams.encoder;
 
-  public StreamProgress() {}
-}
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.annotations.Internal;
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/model/PartitionRecord.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/model/PartitionRecord.java
index 7956a919bf2..15ebf4c8fd6 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/model/PartitionRecord.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/model/PartitionRecord.java
@@ -20,21 +20,26 @@ package org.apache.beam.sdk.io.gcp.bigtable.changestreams.model;
 import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeHelper.formatByteStringRange;
 
 import com.google.cloud.Timestamp;
+import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
 import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange;
 import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 import javax.annotation.Nullable;
+import org.apache.beam.sdk.annotations.Internal;
 
 /**
  * Output result of {@link
  * org.apache.beam.sdk.io.gcp.bigtable.changestreams.dofn.DetectNewPartitionsDoFn} containing
  * information required to stream a partition.
  */
+@Internal
 public class PartitionRecord implements Serializable {
   private static final long serialVersionUID = -7613861834142734474L;
 
   private ByteStringRange partition;
   @Nullable private Timestamp startTime;
+  @Nullable private List<ChangeStreamContinuationToken> changeStreamContinuationTokens;
   @Nullable private Timestamp endTime;
   private String uuid;
   private Timestamp parentLowWatermark;
@@ -52,13 +57,17 @@ public class PartitionRecord implements Serializable {
     this.endTime = endTime;
   }
 
-  @Nullable
-  public Timestamp getStartTime() {
-    return startTime;
-  }
-
-  public void setStartTime(@Nullable Timestamp startTime) {
-    this.startTime = startTime;
+  public PartitionRecord(
+      ByteStringRange partition,
+      List<ChangeStreamContinuationToken> changeStreamContinuationTokens,
+      String uuid,
+      Timestamp parentLowWatermark,
+      @Nullable Timestamp endTime) {
+    this.partition = partition;
+    this.changeStreamContinuationTokens = changeStreamContinuationTokens;
+    this.uuid = uuid;
+    this.parentLowWatermark = parentLowWatermark;
+    this.endTime = endTime;
   }
 
   @Nullable
@@ -70,6 +79,15 @@ public class PartitionRecord implements Serializable {
     this.endTime = endTime;
   }
 
+  @Nullable
+  public Timestamp getStartTime() {
+    return startTime;
+  }
+
+  public void setStartTime(@Nullable Timestamp startTime) {
+    this.startTime = startTime;
+  }
+
   public String getUuid() {
     return uuid;
   }
@@ -94,6 +112,16 @@ public class PartitionRecord implements Serializable {
     this.partition = partition;
   }
 
+  @Nullable
+  public List<ChangeStreamContinuationToken> getChangeStreamContinuationTokens() {
+    return changeStreamContinuationTokens;
+  }
+
+  public void setChangeStreamContinuationTokens(
+      @Nullable List<ChangeStreamContinuationToken> changeStreamContinuationTokens) {
+    this.changeStreamContinuationTokens = changeStreamContinuationTokens;
+  }
+
   @Override
   public boolean equals(@Nullable Object o) {
     if (this == o) {
@@ -105,6 +133,8 @@ public class PartitionRecord implements Serializable {
     PartitionRecord that = (PartitionRecord) o;
     return getPartition().equals(that.getPartition())
         && Objects.equals(getStartTime(), that.getStartTime())
+        && Objects.equals(
+            getChangeStreamContinuationTokens(), that.getChangeStreamContinuationTokens())
         && Objects.equals(getEndTime(), that.getEndTime())
         && getUuid().equals(that.getUuid())
         && Objects.equals(getParentLowWatermark(), that.getParentLowWatermark());
@@ -113,7 +143,12 @@ public class PartitionRecord implements Serializable {
   @Override
   public int hashCode() {
     return Objects.hash(
-        getPartition(), getStartTime(), getEndTime(), getUuid(), getParentLowWatermark());
+        getPartition(),
+        getStartTime(),
+        getChangeStreamContinuationTokens(),
+        getEndTime(),
+        getUuid(),
+        getParentLowWatermark());
   }
 
   @Override
@@ -123,6 +158,8 @@ public class PartitionRecord implements Serializable {
         + formatByteStringRange(partition)
         + ", startTime="
         + startTime
+        + ", changeStreamContinuationTokens="
+        + changeStreamContinuationTokens
         + ", endTime="
         + endTime
         + ", uuid='"
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/StreamProgress.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/StreamProgress.java
index bcf032b0aa4..ef35a040ee8 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/StreamProgress.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/StreamProgress.java
@@ -17,7 +17,12 @@
  */
 package org.apache.beam.sdk.io.gcp.bigtable.changestreams.restriction;
 
+import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
+import com.google.protobuf.Timestamp;
 import java.io.Serializable;
+import java.util.Objects;
+import org.apache.beam.sdk.annotations.Internal;
+import org.checkerframework.checker.nullness.qual.Nullable;
 
 /**
  * Position for {@link ReadChangeStreamPartitionProgressTracker}. This represents contains
@@ -29,8 +34,53 @@ import java.io.Serializable;
  * can contain a close stream message which represents an end to the stream and the DoFn needs to
  * stop.
  */
+@Internal
 public class StreamProgress implements Serializable {
   private static final long serialVersionUID = -5384329262726188695L;
 
+  private @Nullable ChangeStreamContinuationToken currentToken;
+  private @Nullable Timestamp lowWatermark;
+
+  public @Nullable ChangeStreamContinuationToken getCurrentToken() {
+    return currentToken;
+  }
+
+  public @Nullable Timestamp getLowWatermark() {
+    return lowWatermark;
+  }
+
   public StreamProgress() {}
+
+  public StreamProgress(@Nullable ChangeStreamContinuationToken token, Timestamp lowWatermark) {
+    this.currentToken = token;
+    this.lowWatermark = lowWatermark;
+  }
+
+  @Override
+  public boolean equals(@Nullable Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (!(o instanceof StreamProgress)) {
+      return false;
+    }
+    StreamProgress that = (StreamProgress) o;
+    return Objects.equals(getCurrentToken(), that.getCurrentToken())
+        && Objects.equals(getLowWatermark(), that.getLowWatermark());
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(getCurrentToken(), getLowWatermark()); // same fields as equals()
+  }
+
+  @Override
+  public String toString() {
+    return "StreamProgress{"
+        + "currentToken="
+        + currentToken
+        + ", lowWatermark="
+        + lowWatermark
+        + '}';
+  }
 }
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamActionTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamActionTest.java
new file mode 100644
index 00000000000..a6f9f934b7c
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamActionTest.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigtable.changestreams.action;
+
+import static org.junit.Assert.assertFalse;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
+import com.google.cloud.bigtable.data.v2.models.ChangeStreamMutation;
+import com.google.cloud.bigtable.data.v2.models.Heartbeat;
+import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.Timestamp;
+import java.util.Optional;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.TimestampConverter;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.model.PartitionRecord;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.restriction.ReadChangeStreamPartitionProgressTracker;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.restriction.StreamProgress;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.values.KV;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+public class ChangeStreamActionTest {
+
+  private ChangeStreamMetrics metrics;
+  private ChangeStreamAction action;
+
+  private RestrictionTracker<StreamProgress, StreamProgress> tracker;
+  private PartitionRecord partitionRecord;
+  private DoFn.OutputReceiver<KV<ByteString, ChangeStreamMutation>> receiver;
+  private ManualWatermarkEstimator<Instant> watermarkEstimator;
+
+  @Before
+  public void setUp() {
+    metrics = mock(ChangeStreamMetrics.class);
+    tracker = mock(ReadChangeStreamPartitionProgressTracker.class);
+    partitionRecord = mock(PartitionRecord.class);
+    receiver = mock(DoFn.OutputReceiver.class);
+    watermarkEstimator = mock(ManualWatermarkEstimator.class);
+
+    action = new ChangeStreamAction(metrics);
+    when(tracker.tryClaim(any())).thenReturn(true);
+  }
+
+  @Test
+  public void testHeartBeat() {
+    final Timestamp lowWatermark = Timestamp.newBuilder().setSeconds(1000).build();
+    ChangeStreamContinuationToken changeStreamContinuationToken =
+        new ChangeStreamContinuationToken(ByteStringRange.create("a", "b"), "1234");
+    Heartbeat mockHeartBeat = Mockito.mock(Heartbeat.class);
+    Mockito.when(mockHeartBeat.getLowWatermark()).thenReturn(lowWatermark);
+    Mockito.when(mockHeartBeat.getChangeStreamContinuationToken())
+        .thenReturn(changeStreamContinuationToken);
+
+    final Optional<DoFn.ProcessContinuation> result =
+        action.run(partitionRecord, mockHeartBeat, tracker, receiver, watermarkEstimator, false);
+
+    assertFalse(result.isPresent());
+    verify(metrics).incHeartbeatCount();
+    verify(watermarkEstimator).setWatermark(eq(TimestampConverter.toInstant(lowWatermark)));
+    StreamProgress streamProgress = new StreamProgress(changeStreamContinuationToken, lowWatermark);
+    verify(tracker).tryClaim(eq(streamProgress));
+  }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDaoTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDaoTest.java
new file mode 100644
index 00000000000..8bb238162e7
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDaoTest.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao;
+
+import static org.junit.Assert.assertEquals;
+
+import com.google.cloud.bigtable.admin.v2.BigtableTableAdminClient;
+import com.google.cloud.bigtable.admin.v2.BigtableTableAdminSettings;
+import com.google.cloud.bigtable.data.v2.BigtableDataClient;
+import com.google.cloud.bigtable.data.v2.BigtableDataSettings;
+import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
+import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange;
+import com.google.cloud.bigtable.data.v2.models.Row;
+import com.google.cloud.bigtable.emulator.v2.BigtableEmulatorRule;
+import java.io.IOException;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.UniqueIdGenerator;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.encoder.MetadataTableEncoder;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.ClassRule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class MetadataTableDaoTest {
+
+  @ClassRule
+  public static final BigtableEmulatorRule BIGTABLE_EMULATOR_RULE = BigtableEmulatorRule.create();
+
+  private static MetadataTableDao metadataTableDao;
+  private static MetadataTableAdminDao metadataTableAdminDao;
+  private static BigtableDataClient dataClient;
+  private static BigtableTableAdminClient adminClient;
+
+  @BeforeClass
+  public static void beforeClass() throws IOException {
+    BigtableTableAdminSettings adminSettings =
+        BigtableTableAdminSettings.newBuilderForEmulator(BIGTABLE_EMULATOR_RULE.getPort())
+            .setProjectId("fake-project")
+            .setInstanceId("fake-instance")
+            .build();
+    adminClient = BigtableTableAdminClient.create(adminSettings);
+    BigtableDataSettings dataSettingsBuilder =
+        BigtableDataSettings.newBuilderForEmulator(BIGTABLE_EMULATOR_RULE.getPort())
+            .setProjectId("fake-project")
+            .setInstanceId("fake-instance")
+            .build();
+    dataClient = BigtableDataClient.create(dataSettingsBuilder);
+  }
+
+  @Before
+  public void before() {
+    String changeStreamId = UniqueIdGenerator.generateRowKeyPrefix();
+    metadataTableAdminDao =
+        new MetadataTableAdminDao(
+            adminClient, null, changeStreamId, MetadataTableAdminDao.DEFAULT_METADATA_TABLE_NAME);
+    metadataTableAdminDao.createMetadataTable();
+
+    metadataTableDao =
+        new MetadataTableDao(
+            dataClient,
+            metadataTableAdminDao.getTableId(),
+            metadataTableAdminDao.getChangeStreamNamePrefix());
+  }
+
+  @Test
+  public void testUpdateWatermark() {
+    ByteStringRange partition = ByteStringRange.create("a", "b");
+    Instant watermark = Instant.now();
+    ChangeStreamContinuationToken token = new ChangeStreamContinuationToken(partition, "1234");
+    metadataTableDao.updateWatermark(partition, watermark, token);
+    Row row =
+        dataClient.readRow(
+            metadataTableAdminDao.getTableId(),
+            metadataTableDao.convertPartitionToStreamPartitionRowKey(partition));
+    assertEquals(token.getToken(), MetadataTableEncoder.getTokenFromRow(row));
+    assertEquals(watermark, MetadataTableEncoder.parseWatermarkFromRow(row));
+  }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/MetadataTableDaoTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/MetadataTableDaoTest.java
new file mode 100644
index 00000000000..05ff7bcec7e
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/MetadataTableDaoTest.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigtable.changestreams.dofn;
+
+import static org.junit.Assert.assertEquals;
+
+import com.google.cloud.bigtable.admin.v2.BigtableTableAdminClient;
+import com.google.cloud.bigtable.admin.v2.BigtableTableAdminSettings;
+import com.google.cloud.bigtable.data.v2.BigtableDataClient;
+import com.google.cloud.bigtable.data.v2.BigtableDataSettings;
+import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
+import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange;
+import com.google.cloud.bigtable.data.v2.models.Row;
+import com.google.cloud.bigtable.emulator.v2.BigtableEmulatorRule;
+import java.io.IOException;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.UniqueIdGenerator;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.MetadataTableAdminDao;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.MetadataTableDao;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.encoder.MetadataTableEncoder;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.ClassRule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class MetadataTableDaoTest {
+
+  @ClassRule
+  public static final BigtableEmulatorRule BIGTABLE_EMULATOR_RULE = BigtableEmulatorRule.create();
+
+  private static MetadataTableDao metadataTableDao;
+  private static MetadataTableAdminDao metadataTableAdminDao;
+  private static BigtableDataClient dataClient;
+  private static BigtableTableAdminClient adminClient;
+
+  @BeforeClass
+  public static void beforeClass() throws IOException {
+    BigtableTableAdminSettings adminSettings =
+        BigtableTableAdminSettings.newBuilderForEmulator(BIGTABLE_EMULATOR_RULE.getPort())
+            .setProjectId("fake-project")
+            .setInstanceId("fake-instance")
+            .build();
+    adminClient = BigtableTableAdminClient.create(adminSettings);
+    BigtableDataSettings dataSettingsBuilder =
+        BigtableDataSettings.newBuilderForEmulator(BIGTABLE_EMULATOR_RULE.getPort())
+            .setProjectId("fake-project")
+            .setInstanceId("fake-instance")
+            .build();
+    dataClient = BigtableDataClient.create(dataSettingsBuilder);
+  }
+
+  @Before
+  public void before() {
+    String changeStreamId = UniqueIdGenerator.generateRowKeyPrefix();
+    metadataTableAdminDao =
+        new MetadataTableAdminDao(
+            adminClient, null, changeStreamId, MetadataTableAdminDao.DEFAULT_METADATA_TABLE_NAME);
+    metadataTableAdminDao.createMetadataTable();
+
+    metadataTableDao =
+        new MetadataTableDao(
+            dataClient,
+            metadataTableAdminDao.getTableId(),
+            metadataTableAdminDao.getChangeStreamNamePrefix());
+  }
+
+  @Test
+  public void testUpdateWatermark() {
+    ByteStringRange partition = ByteStringRange.create("a", "b");
+    Instant watermark = Instant.now();
+    ChangeStreamContinuationToken token = new ChangeStreamContinuationToken(partition, "1234");
+    metadataTableDao.updateWatermark(partition, watermark, token);
+    Row row =
+        dataClient.readRow(
+            metadataTableAdminDao.getTableId(),
+            metadataTableDao.convertPartitionToStreamPartitionRowKey(partition));
+    assertEquals(token.getToken(), MetadataTableEncoder.getTokenFromRow(row));
+    assertEquals(watermark, MetadataTableEncoder.parseWatermarkFromRow(row));
+  }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/ReadChangeStreamPartitionProgressTrackerTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/ReadChangeStreamPartitionProgressTrackerTest.java
new file mode 100644
index 00000000000..40e6777705c
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/ReadChangeStreamPartitionProgressTrackerTest.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigtable.changestreams.restriction;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+import com.google.cloud.Timestamp;
+import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
+import com.google.cloud.bigtable.data.v2.models.Range;
+import org.apache.beam.sdk.transforms.splittabledofn.SplitResult;
+import org.junit.Test;
+
+public class ReadChangeStreamPartitionProgressTrackerTest {
+  @Test
+  public void testTryClaim() {
+    final StreamProgress streamProgress = new StreamProgress();
+    final ReadChangeStreamPartitionProgressTracker tracker =
+        new ReadChangeStreamPartitionProgressTracker(streamProgress);
+    assertEquals(streamProgress, tracker.currentRestriction());
+
+    ChangeStreamContinuationToken changeStreamContinuationToken =
+        new ChangeStreamContinuationToken(Range.ByteStringRange.create("a", "b"), "1234");
+    final StreamProgress streamProgress2 =
+        new StreamProgress(changeStreamContinuationToken, Timestamp.now().toProto());
+    assertTrue(tracker.tryClaim(streamProgress2));
+    assertEquals(streamProgress2, tracker.currentRestriction());
+    assertEquals(streamProgress2.getLowWatermark(), tracker.currentRestriction().getLowWatermark());
+
+    assertNull(tracker.trySplit(0.5));
+    assertEquals(streamProgress2, tracker.currentRestriction());
+    assertEquals(streamProgress2.getLowWatermark(), tracker.currentRestriction().getLowWatermark());
+    try {
+      tracker.checkDone();
+      assertTrue("Should not reach here because checkDone should have thrown an exception", false);
+    } catch (IllegalStateException e) {
+      // Expected: there's more work to be done, so checkDone threw an exception.
+    }
+
+    final SplitResult<StreamProgress> splitResult = SplitResult.of(null, streamProgress2);
+    assertEquals(splitResult, tracker.trySplit(0));
+
+    assertFalse(tracker.tryClaim(streamProgress2));
+    // No exception thrown, it is done.
+    tracker.checkDone();
+  }
+
+  @Test
+  public void testTrySplitMultipleTimes() {
+    final StreamProgress streamProgress = new StreamProgress();
+    final ReadChangeStreamPartitionProgressTracker tracker =
+        new ReadChangeStreamPartitionProgressTracker(streamProgress);
+    assertEquals(streamProgress, tracker.currentRestriction());
+
+    final SplitResult<StreamProgress> splitResult = SplitResult.of(null, streamProgress);
+    assertEquals(splitResult, tracker.trySplit(0));
+
+    // Call trySplit again
+    assertNull(tracker.trySplit(0));
+    assertNull(tracker.trySplit(0));
+    tracker.checkDone();
+  }
+}