You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text copy).
Posted to commits@iceberg.apache.org by bl...@apache.org on 2022/06/14 14:42:18 UTC

[iceberg] branch master updated: Core: Update ExpireSnapshots impl for branching and tagging (#4578)

This is an automated email from the ASF dual-hosted git repository.

blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new 98dc9fe872 Core: Update ExpireSnapshots impl for branching and tagging (#4578)
98dc9fe872 is described below

commit 98dc9fe8724375e8b932e7e8afdb925c897b4695
Author: Amogh Jahagirdar <ja...@amazon.com>
AuthorDate: Tue Jun 14 07:42:12 2022 -0700

    Core: Update ExpireSnapshots impl for branching and tagging (#4578)
---
 .../java/org/apache/iceberg/RemoveSnapshots.java   | 157 +++++++++++--
 .../org/apache/iceberg/TableMetadataParser.java    |   7 +-
 .../java/org/apache/iceberg/TableProperties.java   |   3 +
 .../org/apache/iceberg/TestRemoveSnapshots.java    | 249 +++++++++++++++++++++
 4 files changed, 393 insertions(+), 23 deletions(-)

diff --git a/core/src/main/java/org/apache/iceberg/RemoveSnapshots.java b/core/src/main/java/org/apache/iceberg/RemoveSnapshots.java
index d920c5348f..2f9fa023d6 100644
--- a/core/src/main/java/org/apache/iceberg/RemoveSnapshots.java
+++ b/core/src/main/java/org/apache/iceberg/RemoveSnapshots.java
@@ -20,7 +20,9 @@
 package org.apache.iceberg;
 
 import java.io.IOException;
+import java.util.Collection;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutorService;
@@ -34,6 +36,7 @@ import org.apache.iceberg.io.CloseableIterable;
 import org.apache.iceberg.relocated.com.google.common.base.Joiner;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.iceberg.relocated.com.google.common.collect.Sets;
 import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors;
 import org.apache.iceberg.util.DateTimeUtil;
@@ -54,6 +57,8 @@ import static org.apache.iceberg.TableProperties.COMMIT_TOTAL_RETRY_TIME_MS;
 import static org.apache.iceberg.TableProperties.COMMIT_TOTAL_RETRY_TIME_MS_DEFAULT;
 import static org.apache.iceberg.TableProperties.GC_ENABLED;
 import static org.apache.iceberg.TableProperties.GC_ENABLED_DEFAULT;
+import static org.apache.iceberg.TableProperties.MAX_REF_AGE_MS;
+import static org.apache.iceberg.TableProperties.MAX_REF_AGE_MS_DEFAULT;
 import static org.apache.iceberg.TableProperties.MAX_SNAPSHOT_AGE_MS;
 import static org.apache.iceberg.TableProperties.MAX_SNAPSHOT_AGE_MS_DEFAULT;
 import static org.apache.iceberg.TableProperties.MIN_SNAPSHOTS_TO_KEEP;
@@ -75,10 +80,12 @@ class RemoveSnapshots implements ExpireSnapshots {
 
   private final TableOperations ops;
   private final Set<Long> idsToRemove = Sets.newHashSet();
+  private final long now;
+  private final long defaultMaxRefAgeMs;
   private boolean cleanExpiredFiles = true;
   private TableMetadata base;
-  private long expireOlderThan;
-  private int minNumSnapshots;
+  private long defaultExpireOlderThan;
+  private int defaultMinNumSnapshots;
   private Consumer<String> deleteFunc = defaultDelete;
   private ExecutorService deleteExecutorService = DEFAULT_DELETE_EXECUTOR_SERVICE;
   private ExecutorService planExecutorService = ThreadPools.getWorkerPool();
@@ -86,21 +93,27 @@ class RemoveSnapshots implements ExpireSnapshots {
   RemoveSnapshots(TableOperations ops) {
     this.ops = ops;
     this.base = ops.current();
+    ValidationException.check(
+        PropertyUtil.propertyAsBoolean(base.properties(), GC_ENABLED, GC_ENABLED_DEFAULT),
+        "Cannot expire snapshots: GC is disabled (deleting files may corrupt other tables)");
 
-    long maxSnapshotAgeMs = PropertyUtil.propertyAsLong(
+    long defaultMaxSnapshotAgeMs = PropertyUtil.propertyAsLong(
         base.properties(),
         MAX_SNAPSHOT_AGE_MS,
         MAX_SNAPSHOT_AGE_MS_DEFAULT);
-    this.expireOlderThan = System.currentTimeMillis() - maxSnapshotAgeMs;
 
-    this.minNumSnapshots = PropertyUtil.propertyAsInt(
+    this.now = System.currentTimeMillis();
+    this.defaultExpireOlderThan = now - defaultMaxSnapshotAgeMs;
+    this.defaultMinNumSnapshots = PropertyUtil.propertyAsInt(
         base.properties(),
         MIN_SNAPSHOTS_TO_KEEP,
         MIN_SNAPSHOTS_TO_KEEP_DEFAULT);
 
-    ValidationException.check(
-        PropertyUtil.propertyAsBoolean(base.properties(), GC_ENABLED, GC_ENABLED_DEFAULT),
-        "Cannot expire snapshots: GC is disabled (deleting files may corrupt other tables)");
+    this.defaultMaxRefAgeMs = PropertyUtil.propertyAsLong(
+        base.properties(),
+        MAX_REF_AGE_MS,
+        MAX_REF_AGE_MS_DEFAULT
+    );
   }
 
   @Override
@@ -120,7 +133,7 @@ class RemoveSnapshots implements ExpireSnapshots {
   public ExpireSnapshots expireOlderThan(long timestampMillis) {
     LOG.info("Expiring snapshots older than: {} ({})",
         DateTimeUtil.formatTimestampMillis(timestampMillis), timestampMillis);
-    this.expireOlderThan = timestampMillis;
+    this.defaultExpireOlderThan = timestampMillis;
     return this;
   }
 
@@ -128,7 +141,7 @@ class RemoveSnapshots implements ExpireSnapshots {
   public ExpireSnapshots retainLast(int numSnapshots) {
     Preconditions.checkArgument(1 <= numSnapshots,
             "Number of snapshots to retain must be at least 1, cannot be: %s", numSnapshots);
-    this.minNumSnapshots = numSnapshots;
+    this.defaultMinNumSnapshots = numSnapshots;
     return this;
   }
 
@@ -161,21 +174,119 @@ class RemoveSnapshots implements ExpireSnapshots {
 
   private TableMetadata internalApply() {
     this.base = ops.refresh();
+    if (base.snapshots().isEmpty()) {
+      return base;
+    }
 
     Set<Long> idsToRetain = Sets.newHashSet();
-    List<Long> ancestorIds = SnapshotUtil.ancestorIds(base.currentSnapshot(), base::snapshot);
-    if (minNumSnapshots >= ancestorIds.size()) {
-      idsToRetain.addAll(ancestorIds);
-    } else {
-      idsToRetain.addAll(ancestorIds.subList(0, minNumSnapshots));
+    // Identify refs that should be removed
+    Map<String, SnapshotRef> retainedRefs = computeRetainedRefs(base.refs());
+    Map<Long, List<String>> retainedIdToRefs = Maps.newHashMap();
+    for (Map.Entry<String, SnapshotRef> retainedRefEntry : retainedRefs.entrySet()) {
+      long snapshotId = retainedRefEntry.getValue().snapshotId();
+      retainedIdToRefs.putIfAbsent(snapshotId, Lists.newArrayList());
+      retainedIdToRefs.get(snapshotId).add(retainedRefEntry.getKey());
+      idsToRetain.add(snapshotId);
+    }
+
+    for (long idToRemove : idsToRemove) {
+      List<String> refsForId = retainedIdToRefs.get(idToRemove);
+      Preconditions.checkArgument(refsForId == null,
+          "Cannot expire %s. Still referenced by refs: %s", idToRemove, refsForId);
+    }
+
+    idsToRetain.addAll(computeAllBranchSnapshotsToRetain(retainedRefs.values()));
+    idsToRetain.addAll(unreferencedSnapshotsToRetain(retainedRefs.values()));
+
+    TableMetadata.Builder updatedMetaBuilder = TableMetadata.buildFrom(base);
+
+    base.refs().keySet().stream()
+        .filter(ref -> !retainedRefs.containsKey(ref))
+        .forEach(updatedMetaBuilder::removeRef);
+
+    base.snapshots().stream()
+        .map(Snapshot::snapshotId)
+        .filter(snapshot -> !idsToRetain.contains(snapshot))
+        .forEach(idsToRemove::add);
+    updatedMetaBuilder.removeSnapshots(idsToRemove);
+
+    return updatedMetaBuilder.build();
+  }
+
+  private Map<String, SnapshotRef> computeRetainedRefs(Map<String, SnapshotRef> refs) {
+    Map<String, SnapshotRef> retainedRefs = Maps.newHashMap();
+    for (Map.Entry<String, SnapshotRef> refEntry : refs.entrySet()) {
+      String name = refEntry.getKey();
+      SnapshotRef ref = refEntry.getValue();
+      if (name.equals(SnapshotRef.MAIN_BRANCH)) {
+        retainedRefs.put(name, ref);
+        continue;
+      }
+
+      Snapshot snapshot = base.snapshot(ref.snapshotId());
+      long maxRefAgeMs = ref.maxRefAgeMs() != null ? ref.maxRefAgeMs() : defaultMaxRefAgeMs;
+      if (snapshot != null) {
+        long refAgeMs = now - snapshot.timestampMillis();
+        if (refAgeMs <= maxRefAgeMs) {
+          retainedRefs.put(name, ref);
+        }
+      } else {
+        LOG.warn("Removing invalid ref {}: snapshot {} does not exist", name, ref.snapshotId());
+      }
     }
 
-    TableMetadata updateMeta = base.removeSnapshotsIf(snapshot ->
-        idsToRemove.contains(snapshot.snapshotId()) ||
-        (snapshot.timestampMillis() < expireOlderThan && !idsToRetain.contains(snapshot.snapshotId())));
-    List<Snapshot> updateSnapshots = updateMeta.snapshots();
-    List<Snapshot> baseSnapshots = base.snapshots();
-    return updateSnapshots.size() != baseSnapshots.size() ? updateMeta : base;
+    return retainedRefs;
+  }
+
+  private Set<Long> computeAllBranchSnapshotsToRetain(Collection<SnapshotRef> refs) {
+    Set<Long> branchSnapshotsToRetain = Sets.newHashSet();
+    for (SnapshotRef ref : refs) {
+      if (ref.isBranch()) {
+        long expireSnapshotsOlderThan = ref.maxSnapshotAgeMs() != null ? now - ref.maxSnapshotAgeMs() :
+            defaultExpireOlderThan;
+        int minSnapshotsToKeep = ref.minSnapshotsToKeep() != null ? ref.minSnapshotsToKeep() :
+            defaultMinNumSnapshots;
+        branchSnapshotsToRetain.addAll(
+            computeBranchSnapshotsToRetain(ref.snapshotId(), expireSnapshotsOlderThan, minSnapshotsToKeep));
+      }
+    }
+
+    return branchSnapshotsToRetain;
+  }
+
+  private Set<Long> computeBranchSnapshotsToRetain(
+      long snapshot,
+      long expireSnapshotsOlderThan,
+      int minSnapshotsToKeep) {
+    Set<Long> idsToRetain = Sets.newHashSet();
+    for (Snapshot ancestor : SnapshotUtil.ancestorsOf(snapshot, base::snapshot)) {
+      if (idsToRetain.size() < minSnapshotsToKeep || ancestor.timestampMillis() >= expireSnapshotsOlderThan) {
+        idsToRetain.add(ancestor.snapshotId());
+      } else {
+        return idsToRetain;
+      }
+    }
+
+    return idsToRetain;
+  }
+
+  private Set<Long> unreferencedSnapshotsToRetain(Collection<SnapshotRef> refs) {
+    Set<Long> referencedSnapshots = Sets.newHashSet();
+    for (SnapshotRef ref : refs) {
+      for (Snapshot snapshot : SnapshotUtil.ancestorsOf(ref.snapshotId(), base::snapshot)) {
+        referencedSnapshots.add(snapshot.snapshotId());
+      }
+    }
+
+    Set<Long> snapshotsToRetain = Sets.newHashSet();
+    for (Snapshot snapshot : base.snapshots()) {
+      if (!referencedSnapshots.contains(snapshot.snapshotId()) && // unreferenced
+          snapshot.timestampMillis() >= defaultExpireOlderThan) { // not old enough to expire
+        snapshotsToRetain.add(snapshot.snapshotId());
+      }
+    }
+
+    return snapshotsToRetain;
   }
 
   @Override
@@ -190,6 +301,10 @@ class RemoveSnapshots implements ExpireSnapshots {
         .onlyRetryOn(CommitFailedException.class)
         .run(item -> {
           TableMetadata updated = internalApply();
+          if (cleanExpiredFiles && updated.refs().size() > 1) {
+            throw new UnsupportedOperationException("Cannot incrementally clean files for tables with more than 1 ref");
+          }
+
           ops.commit(base, updated);
         });
     LOG.info("Committed snapshot changes");
diff --git a/core/src/main/java/org/apache/iceberg/TableMetadataParser.java b/core/src/main/java/org/apache/iceberg/TableMetadataParser.java
index a709fe5ee3..c167b08706 100644
--- a/core/src/main/java/org/apache/iceberg/TableMetadataParser.java
+++ b/core/src/main/java/org/apache/iceberg/TableMetadataParser.java
@@ -415,12 +415,15 @@ public class TableMetadataParser {
 
     // parse properties map
     Map<String, String> properties = JsonUtil.getStringMap(PROPERTIES, node);
-    long currentVersionId = JsonUtil.getLong(CURRENT_SNAPSHOT_ID, node);
+    long currentSnapshotId = JsonUtil.getLong(CURRENT_SNAPSHOT_ID, node);
     long lastUpdatedMillis = JsonUtil.getLong(LAST_UPDATED_MILLIS, node);
 
     Map<String, SnapshotRef> refs;
     if (node.has(REFS)) {
       refs = refsFromJson(node.get(REFS));
+    } else if (currentSnapshotId != -1) {
+      // initialize the main branch if there are no refs
+      refs = ImmutableMap.of(SnapshotRef.MAIN_BRANCH, SnapshotRef.branchBuilder(currentSnapshotId).build());
     } else {
       refs = ImmutableMap.of();
     }
@@ -457,7 +460,7 @@ public class TableMetadataParser {
 
     return new TableMetadata(metadataLocation, formatVersion, uuid, location,
         lastSequenceNumber, lastUpdatedMillis, lastAssignedColumnId, currentSchemaId, schemas, defaultSpecId, specs,
-        lastAssignedPartitionId, defaultSortOrderId, sortOrders, properties, currentVersionId,
+        lastAssignedPartitionId, defaultSortOrderId, sortOrders, properties, currentSnapshotId,
         snapshots, entries.build(), metadataEntries.build(), refs,
         ImmutableList.of() /* no changes from the file */);
   }
diff --git a/core/src/main/java/org/apache/iceberg/TableProperties.java b/core/src/main/java/org/apache/iceberg/TableProperties.java
index 86fafe99fc..e8fd8ce838 100644
--- a/core/src/main/java/org/apache/iceberg/TableProperties.java
+++ b/core/src/main/java/org/apache/iceberg/TableProperties.java
@@ -305,6 +305,9 @@ public class TableProperties {
   public static final String MIN_SNAPSHOTS_TO_KEEP = "history.expire.min-snapshots-to-keep";
   public static final int MIN_SNAPSHOTS_TO_KEEP_DEFAULT = 1;
 
+  public static final String MAX_REF_AGE_MS = "history.expire.max-ref-age-ms";
+  public static final long MAX_REF_AGE_MS_DEFAULT = Long.MAX_VALUE;
+
   public static final String DELETE_ISOLATION_LEVEL = "write.delete.isolation-level";
   public static final String DELETE_ISOLATION_LEVEL_DEFAULT = "serializable";
 
diff --git a/core/src/test/java/org/apache/iceberg/TestRemoveSnapshots.java b/core/src/test/java/org/apache/iceberg/TestRemoveSnapshots.java
index 7ed319acb5..99cba44d38 100644
--- a/core/src/test/java/org/apache/iceberg/TestRemoveSnapshots.java
+++ b/core/src/test/java/org/apache/iceberg/TestRemoveSnapshots.java
@@ -1204,4 +1204,253 @@ public class TestRemoveSnapshots extends TableTestBase {
             .build(),
         deletedFiles);
   }
+
+  @Test
+  public void testTagExpiration() {
+    table.newAppend()
+        .appendFile(FILE_A)
+        .commit();
+
+    long now = System.currentTimeMillis();
+    long maxAgeMs = 100;
+    long expirationTime = now + maxAgeMs;
+
+    table.manageSnapshots()
+        .createTag("tag", table.currentSnapshot().snapshotId())
+        .setMaxRefAgeMs("tag", maxAgeMs)
+        .commit();
+
+    table.newAppend()
+        .appendFile(FILE_B)
+        .commit();
+
+    table.manageSnapshots()
+        .createBranch("branch", table.currentSnapshot().snapshotId())
+        .commit();
+
+    waitUntilAfter(expirationTime);
+
+    table.expireSnapshots().cleanExpiredFiles(false).commit();
+
+    Assert.assertNull(table.ops().current().ref("tag"));
+    Assert.assertNotNull(table.ops().current().ref("branch"));
+    Assert.assertNotNull(table.ops().current().ref(SnapshotRef.MAIN_BRANCH));
+  }
+
+  @Test
+  public void testBranchExpiration() {
+    table.newAppend()
+        .appendFile(FILE_A)
+        .commit();
+
+    long now = System.currentTimeMillis();
+    long maxAgeMs = 100;
+    long expirationTime = now + maxAgeMs;
+
+    table.manageSnapshots()
+        .createBranch("branch", table.currentSnapshot().snapshotId())
+        .setMaxRefAgeMs("branch", maxAgeMs)
+        .commit();
+
+    table.newAppend()
+        .appendFile(FILE_B)
+        .commit();
+
+    table.manageSnapshots()
+        .createTag("tag", table.currentSnapshot().snapshotId())
+        .commit();
+
+    waitUntilAfter(expirationTime);
+
+    table.expireSnapshots().cleanExpiredFiles(false).commit();
+
+    Assert.assertNull(table.ops().current().ref("branch"));
+    Assert.assertNotNull(table.ops().current().ref("tag"));
+    Assert.assertNotNull(table.ops().current().ref(SnapshotRef.MAIN_BRANCH));
+  }
+
+  @Test
+  public void testMultipleRefsAndCleanExpiredFilesFails() {
+    table.newAppend()
+        .appendFile(FILE_A)
+        .commit();
+
+    table.manageSnapshots()
+        .createTag("TagA", table.currentSnapshot().snapshotId())
+        .commit();
+
+    AssertHelpers.assertThrows(
+        "Should fail removing snapshots and files when there is more than 1 ref",
+        UnsupportedOperationException.class,
+        "Cannot incrementally clean files for tables with more than 1 ref",
+        () -> table.expireSnapshots().cleanExpiredFiles(true).commit());
+  }
+
+  @Test
+  public void testFailRemovingSnapshotWhenStillReferencedByBranch() {
+    table.newAppend()
+        .appendFile(FILE_A)
+        .commit();
+
+    AppendFiles append = table.newAppend()
+        .appendFile(FILE_B)
+        .stageOnly();
+
+    long snapshotId = append.apply().snapshotId();
+
+    append.commit();
+
+    table.manageSnapshots()
+        .createBranch("branch", snapshotId)
+        .commit();
+
+    AssertHelpers.assertThrows(
+        "Should fail removing snapshot when it is still referenced",
+        IllegalArgumentException.class,
+        "Cannot expire 2. Still referenced by refs: [branch]",
+        () -> table.expireSnapshots().expireSnapshotId(snapshotId).commit());
+  }
+
+  @Test
+  public void testFailRemovingSnapshotWhenStillReferencedByTag() {
+    table.newAppend()
+        .appendFile(FILE_A)
+        .commit();
+
+    long snapshotId = table.currentSnapshot().snapshotId();
+
+    table.manageSnapshots()
+        .createTag("tag", snapshotId)
+        .commit();
+
+    // commit another snapshot so the first one isn't referenced by main
+    table.newAppend()
+        .appendFile(FILE_B)
+        .commit();
+
+    AssertHelpers.assertThrows(
+        "Should fail removing snapshot when it is still referenced",
+          IllegalArgumentException.class,
+        "Cannot expire 1. Still referenced by refs: [tag]",
+        () -> table.expireSnapshots().expireSnapshotId(snapshotId).commit());
+  }
+
+  @Test
+  public void testRetainUnreferencedSnapshotsWithinExpirationAge() {
+    table.newAppend()
+        .appendFile(FILE_A)
+        .commit();
+
+    long expireTimestampSnapshotA = waitUntilAfter(table.currentSnapshot().timestampMillis());
+    waitUntilAfter(expireTimestampSnapshotA);
+
+    table.newAppend()
+        .appendFile(FILE_B)
+        .stageOnly()
+        .commit();
+
+    table.newAppend()
+        .appendFile(FILE_C)
+        .commit();
+
+    table.expireSnapshots()
+        .expireOlderThan(expireTimestampSnapshotA)
+        .commit();
+
+    Assert.assertEquals(2, table.ops().current().snapshots().size());
+  }
+
+  // ToDo: Add tests which commit to branches once committing snapshots to a branch is supported
+
+  @Test
+  public void testMinSnapshotsToKeepMultipleBranches() {
+    table.newAppend().appendFile(FILE_A).commit();
+    long initialSnapshotId = table.currentSnapshot().snapshotId();
+    table.newAppend().appendFile(FILE_B).commit();
+
+    // stage a snapshot and get its id
+    AppendFiles append = table.newAppend().appendFile(FILE_C).stageOnly();
+    long branchSnapshotId = append.apply().snapshotId();
+    append.commit();
+
+    Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots()));
+
+    long maxSnapshotAgeMs = 1;
+    long expirationTime = System.currentTimeMillis() + maxSnapshotAgeMs;
+
+    // configure main so that the initial snapshot will expire
+    table.manageSnapshots()
+        .setMinSnapshotsToKeep(SnapshotRef.MAIN_BRANCH, 1)
+        .setMaxSnapshotAgeMs(SnapshotRef.MAIN_BRANCH, 1)
+        .commit();
+
+    // retain 3 snapshots on branch (including the initial snapshot)
+    table.manageSnapshots()
+        .createBranch("branch", branchSnapshotId)
+        .setMinSnapshotsToKeep("branch", 3)
+        .setMaxSnapshotAgeMs("branch", maxSnapshotAgeMs)
+        .commit();
+
+    waitUntilAfter(expirationTime);
+    table.expireSnapshots().cleanExpiredFiles(false).commit();
+
+    Assert.assertEquals("Should have 3 snapshots (none removed)", 3, Iterables.size(table.snapshots()));
+
+    // stop retaining snapshots from the branch
+    table.manageSnapshots()
+        .setMinSnapshotsToKeep("branch", 1)
+        .commit();
+
+    table.expireSnapshots().cleanExpiredFiles(false).commit();
+
+    Assert.assertEquals("Should have 2 snapshots (initial removed)", 2, Iterables.size(table.snapshots()));
+    Assert.assertNull(table.ops().current().snapshot(initialSnapshotId));
+  }
+
+  @Test
+  public void testMaxSnapshotAgeMultipleBranches() {
+    table.newAppend().appendFile(FILE_A).commit();
+    long initialSnapshotId = table.currentSnapshot().snapshotId();
+
+    long ageMs = 10;
+    long expirationTime = System.currentTimeMillis() + ageMs;
+
+    waitUntilAfter(expirationTime);
+
+    table.newAppend().appendFile(FILE_B).commit();
+
+    // configure main so that the initial snapshot will expire
+    table.manageSnapshots()
+        .setMaxSnapshotAgeMs(SnapshotRef.MAIN_BRANCH, ageMs)
+        .setMinSnapshotsToKeep(SnapshotRef.MAIN_BRANCH, 1)
+        .commit();
+
+    // stage a snapshot and get its id
+    AppendFiles append = table.newAppend().appendFile(FILE_C).stageOnly();
+    long branchSnapshotId = append.apply().snapshotId();
+    append.commit();
+
+    Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots()));
+
+    // retain all snapshots on branch (including the initial snapshot)
+    table.manageSnapshots()
+        .createBranch("branch", branchSnapshotId)
+        .setMinSnapshotsToKeep("branch", 1)
+        .setMaxSnapshotAgeMs("branch", Long.MAX_VALUE)
+        .commit();
+
+    table.expireSnapshots().cleanExpiredFiles(false).commit();
+
+    Assert.assertEquals("Should have 3 snapshots (none removed)", 3, Iterables.size(table.snapshots()));
+
+    // allow the initial snapshot to age off from branch
+    table.manageSnapshots()
+        .setMaxSnapshotAgeMs("branch", ageMs)
+        .commit();
+
+    table.expireSnapshots().cleanExpiredFiles(false).commit();
+
+    Assert.assertEquals("Should have 2 snapshots (initial removed)", 2, Iterables.size(table.snapshots()));
+    Assert.assertNull(table.ops().current().snapshot(initialSnapshotId));
+  }
 }