You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@samza.apache.org by dc...@apache.org on 2022/02/04 22:37:43 UTC

[samza] branch master updated: Fix - filter only configured stores in backup/restore init (#1582)

This is an automated email from the ASF dual-hosted git repository.

dchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/samza.git


The following commit(s) were added to refs/heads/master by this push:
     new 2e7c2fe  Fix - filter only configured stores in backup/restore init (#1582)
2e7c2fe is described below

commit 2e7c2fe6c095d05b1b75fc0c2768ad0e9a81a085
Author: shekhars-li <72...@users.noreply.github.com>
AuthorDate: Fri Feb 4 14:37:34 2022 -0800

    Fix - filter only configured stores in backup/restore init (#1582)
    
    * Fix - filter only configured stores in init for downloading SnapshotIndex from the blob store
    
    * Fix style issues
    
    * Fix condition to check in-memory stores in ContainerStorageManager
---
 .../samza/serializers/model/SamzaObjectMapper.java |  3 +-
 .../storage/blobstore/BlobStoreBackupManager.java  |  3 +-
 .../storage/blobstore/BlobStoreRestoreManager.java |  2 +-
 .../blobstore/index/serde/SnapshotIndexSerde.java  | 13 +++--
 .../storage/blobstore/util/BlobStoreUtil.java      | 45 +++++++++------
 .../samza/storage/ContainerStorageManager.java     |  2 +-
 .../blobstore/TestBlobStoreBackupManager.java      |  4 +-
 .../storage/blobstore/util/TestBlobStoreUtil.java  | 65 +++++++++++++++++-----
 8 files changed, 95 insertions(+), 42 deletions(-)

diff --git a/samza-core/src/main/java/org/apache/samza/serializers/model/SamzaObjectMapper.java b/samza-core/src/main/java/org/apache/samza/serializers/model/SamzaObjectMapper.java
index b2f16d0..bcc3ae4 100644
--- a/samza-core/src/main/java/org/apache/samza/serializers/model/SamzaObjectMapper.java
+++ b/samza-core/src/main/java/org/apache/samza/serializers/model/SamzaObjectMapper.java
@@ -39,6 +39,7 @@ import com.fasterxml.jackson.databind.cfg.MapperConfig;
 import com.fasterxml.jackson.databind.introspect.AnnotatedField;
 import com.fasterxml.jackson.databind.introspect.AnnotatedMethod;
 import com.fasterxml.jackson.databind.module.SimpleModule;
+import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
 import org.apache.samza.Partition;
 import org.apache.samza.SamzaException;
 import org.apache.samza.checkpoint.CheckpointId;
@@ -149,7 +150,7 @@ public class SamzaObjectMapper {
 
     // Convert camel case to hyphenated field names, and register the module.
     mapper.setPropertyNamingStrategy(new CamelCaseToDashesStrategy());
-    mapper.registerModule(module);
+    mapper.registerModules(module, new Jdk8Module());
 
     return mapper;
   }
diff --git a/samza-core/src/main/java/org/apache/samza/storage/blobstore/BlobStoreBackupManager.java b/samza-core/src/main/java/org/apache/samza/storage/blobstore/BlobStoreBackupManager.java
index fb04ce5..c4eb56b 100644
--- a/samza-core/src/main/java/org/apache/samza/storage/blobstore/BlobStoreBackupManager.java
+++ b/samza-core/src/main/java/org/apache/samza/storage/blobstore/BlobStoreBackupManager.java
@@ -25,6 +25,7 @@ import java.io.File;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -133,7 +134,7 @@ public class BlobStoreBackupManager implements TaskBackupManager {
     // Note: blocks the caller thread.
     // TODO LOW shesharma exclude stores that are no longer configured during init
     Map<String, Pair<String, SnapshotIndex>> prevStoreSnapshotIndexes =
-        blobStoreUtil.getStoreSnapshotIndexes(jobName, jobId, taskName, checkpoint);
+        blobStoreUtil.getStoreSnapshotIndexes(jobName, jobId, taskName, checkpoint, new HashSet<>(storesToBackup));
     this.prevStoreSnapshotIndexesFuture =
         CompletableFuture.completedFuture(ImmutableMap.copyOf(prevStoreSnapshotIndexes));
     metrics.initNs.set(System.nanoTime() - startTime);
diff --git a/samza-core/src/main/java/org/apache/samza/storage/blobstore/BlobStoreRestoreManager.java b/samza-core/src/main/java/org/apache/samza/storage/blobstore/BlobStoreRestoreManager.java
index 4db891c..22aaecb 100644
--- a/samza-core/src/main/java/org/apache/samza/storage/blobstore/BlobStoreRestoreManager.java
+++ b/samza-core/src/main/java/org/apache/samza/storage/blobstore/BlobStoreRestoreManager.java
@@ -125,7 +125,7 @@ public class BlobStoreRestoreManager implements TaskRestoreManager {
     blobStoreManager.init();
 
     // get previous SCMs from checkpoint
-    prevStoreSnapshotIndexes = blobStoreUtil.getStoreSnapshotIndexes(jobName, jobId, taskName, checkpoint);
+    prevStoreSnapshotIndexes = blobStoreUtil.getStoreSnapshotIndexes(jobName, jobId, taskName, checkpoint, storesToRestore);
     metrics.getSnapshotIndexNs.set(System.nanoTime() - startTime);
     LOG.trace("Found previous snapshot index during blob store restore manager init for task: {} to be: {}",
         taskName, prevStoreSnapshotIndexes);
diff --git a/samza-core/src/main/java/org/apache/samza/storage/blobstore/index/serde/SnapshotIndexSerde.java b/samza-core/src/main/java/org/apache/samza/storage/blobstore/index/serde/SnapshotIndexSerde.java
index 9e316ad..f445b8c 100644
--- a/samza-core/src/main/java/org/apache/samza/storage/blobstore/index/serde/SnapshotIndexSerde.java
+++ b/samza-core/src/main/java/org/apache/samza/storage/blobstore/index/serde/SnapshotIndexSerde.java
@@ -21,8 +21,8 @@ package org.apache.samza.storage.blobstore.index.serde;
 
 import com.fasterxml.jackson.core.type.TypeReference;
 import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.ObjectReader;
 import com.fasterxml.jackson.databind.ObjectWriter;
-import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
 import org.apache.samza.SamzaException;
 import org.apache.samza.serializers.Serde;
 import org.apache.samza.serializers.model.SamzaObjectMapper;
@@ -32,16 +32,18 @@ import org.apache.samza.storage.blobstore.index.FileIndex;
 import org.apache.samza.storage.blobstore.index.FileMetadata;
 import org.apache.samza.storage.blobstore.index.SnapshotIndex;
 import org.apache.samza.storage.blobstore.index.SnapshotMetadata;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 
 public class SnapshotIndexSerde implements Serde<SnapshotIndex> {
-
+  private static final Logger LOG = LoggerFactory.getLogger(SnapshotIndexSerde.class);
   private final static ObjectMapper MAPPER = SamzaObjectMapper.getObjectMapper();
   private TypeReference<SnapshotIndex> typeReference;
   private final ObjectWriter objectWriter;
+  private final ObjectReader objectReader;
 
   public SnapshotIndexSerde() {
-    MAPPER.registerModule(new Jdk8Module());
     MAPPER.addMixIn(SnapshotIndex.class, JsonSnapshotIndexMixin.class)
         .addMixIn(SnapshotMetadata.class, JsonSnapshotMetadataMixin.class)
         .addMixIn(DirIndex.class, JsonDirIndexMixin.class)
@@ -51,12 +53,14 @@ public class SnapshotIndexSerde implements Serde<SnapshotIndex> {
 
     this.typeReference = new TypeReference<SnapshotIndex>() { };
     this.objectWriter = MAPPER.writerFor(typeReference);
+    this.objectReader = MAPPER.readerFor(typeReference);
   }
 
   @Override
   public SnapshotIndex fromBytes(byte[] bytes) {
     try {
-      return MAPPER.readerFor(typeReference).readValue(bytes);
+      LOG.debug("Modules loaded: {}", MAPPER.getRegisteredModuleIds());
+      return objectReader.readValue(bytes);
     } catch (Exception exception) {
       throw new SamzaException(String.format("Exception in deserializing SnapshotIndex bytes %s",
           new String(bytes)), exception);
@@ -66,6 +70,7 @@ public class SnapshotIndexSerde implements Serde<SnapshotIndex> {
   @Override
   public byte[] toBytes(SnapshotIndex snapshotIndex) {
     try {
+      LOG.debug("Modules loaded: {}", MAPPER.getRegisteredModuleIds());
       return objectWriter.writeValueAsBytes(snapshotIndex);
     } catch (Exception exception) {
       throw new SamzaException(String.format("Exception in serializing SnapshotIndex bytes %s", snapshotIndex), exception);
diff --git a/samza-core/src/main/java/org/apache/samza/storage/blobstore/util/BlobStoreUtil.java b/samza-core/src/main/java/org/apache/samza/storage/blobstore/util/BlobStoreUtil.java
index dc7e709..7e9553c 100644
--- a/samza-core/src/main/java/org/apache/samza/storage/blobstore/util/BlobStoreUtil.java
+++ b/samza-core/src/main/java/org/apache/samza/storage/blobstore/util/BlobStoreUtil.java
@@ -38,6 +38,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionException;
 import java.util.concurrent.CompletionStage;
@@ -83,6 +84,7 @@ public class BlobStoreUtil {
   private final ExecutorService executor;
   private final BlobStoreBackupManagerMetrics backupMetrics;
   private final BlobStoreRestoreManagerMetrics restoreMetrics;
+  private final SnapshotIndexSerde snapshotIndexSerde;
 
   public BlobStoreUtil(BlobStoreManager blobStoreManager, ExecutorService executor,
       BlobStoreBackupManagerMetrics backupMetrics, BlobStoreRestoreManagerMetrics restoreMetrics) {
@@ -90,6 +92,7 @@ public class BlobStoreUtil {
     this.executor = executor;
     this.backupMetrics = backupMetrics;
     this.restoreMetrics = restoreMetrics;
+    this.snapshotIndexSerde = new SnapshotIndexSerde();
   }
 
   /**
@@ -100,10 +103,11 @@ public class BlobStoreUtil {
    * @param taskName task name to get the store state checkpoint markers and snapshot indexes for
    * @param checkpoint {@link Checkpoint} instance to get the store state checkpoint markers from. Only
    *                   {@link CheckpointV2} and newer are supported for blob stores.
+   * @param storesToBackupOrRestore set of store names to be backed up or restored
    * @return Map of store name to its blob id of snapshot indices and their corresponding snapshot indices for the task.
    */
   public Map<String, Pair<String, SnapshotIndex>> getStoreSnapshotIndexes(
-      String jobName, String jobId, String taskName, Checkpoint checkpoint) {
+      String jobName, String jobId, String taskName, Checkpoint checkpoint, Set<String> storesToBackupOrRestore) {
     //TODO MED shesharma document error handling (checkpoint ver, blob not found, getBlob)
     if (checkpoint == null) {
       LOG.debug("No previous checkpoint found for taskName: {}", taskName);
@@ -124,21 +128,26 @@ public class BlobStoreUtil {
 
     if (storeSnapshotIndexBlobIds != null) {
       storeSnapshotIndexBlobIds.forEach((storeName, snapshotIndexBlobId) -> {
-        try {
-          LOG.debug("Getting snapshot index for taskName: {} store: {} blobId: {}", taskName, storeName, snapshotIndexBlobId);
-          Metadata requestMetadata =
-              new Metadata(Metadata.SNAPSHOT_INDEX_PAYLOAD_PATH, Optional.empty(), jobName, jobId, taskName, storeName);
-          CompletableFuture<SnapshotIndex> snapshotIndexFuture =
-              getSnapshotIndex(snapshotIndexBlobId, requestMetadata).toCompletableFuture();
-          Pair<CompletableFuture<String>, CompletableFuture<SnapshotIndex>> pairOfFutures =
-              Pair.of(CompletableFuture.completedFuture(snapshotIndexBlobId), snapshotIndexFuture);
-
-          // save the future and block once in the end instead of blocking for each request.
-          storeSnapshotIndexFutures.put(storeName, FutureUtil.toFutureOfPair(pairOfFutures));
-        } catch (Exception e) {
-          throw new SamzaException(
-              String.format("Error getting SnapshotIndex for blobId: %s for taskName: %s store: %s",
-                  snapshotIndexBlobId, taskName, storeName), e);
+        if (storesToBackupOrRestore.contains(storeName)) {
+          try {
+            LOG.debug("Getting snapshot index for taskName: {} store: {} blobId: {}", taskName, storeName, snapshotIndexBlobId);
+            Metadata requestMetadata =
+                new Metadata(Metadata.SNAPSHOT_INDEX_PAYLOAD_PATH, Optional.empty(), jobName, jobId, taskName, storeName);
+            CompletableFuture<SnapshotIndex> snapshotIndexFuture =
+                getSnapshotIndex(snapshotIndexBlobId, requestMetadata).toCompletableFuture();
+            Pair<CompletableFuture<String>, CompletableFuture<SnapshotIndex>> pairOfFutures =
+                Pair.of(CompletableFuture.completedFuture(snapshotIndexBlobId), snapshotIndexFuture);
+
+            // save the future and block once in the end instead of blocking for each request.
+            storeSnapshotIndexFutures.put(storeName, FutureUtil.toFutureOfPair(pairOfFutures));
+          } catch (Exception e) {
+            throw new SamzaException(
+                String.format("Error getting SnapshotIndex for blobId: %s for taskName: %s store: %s",
+                    snapshotIndexBlobId, taskName, storeName), e);
+          }
+        } else {
+          LOG.debug("SnapshotIndex blob id {} for store {} is not present in the set of stores to be backed up/restores: {}",
+              snapshotIndexBlobId, storeName, storesToBackupOrRestore);
         }
       });
     } else {
@@ -173,7 +182,7 @@ public class BlobStoreUtil {
     return FutureUtil.executeAsyncWithRetries(opName, () -> {
       ByteArrayOutputStream indexBlobStream = new ByteArrayOutputStream(); // no need to close ByteArrayOutputStream
       return blobStoreManager.get(blobId, indexBlobStream, metadata).toCompletableFuture()
-          .thenApplyAsync(f -> new SnapshotIndexSerde().fromBytes(indexBlobStream.toByteArray()), executor);
+          .thenApplyAsync(f -> snapshotIndexSerde.fromBytes(indexBlobStream.toByteArray()), executor);
     }, isCauseNonRetriable(), executor);
   }
 
@@ -183,7 +192,7 @@ public class BlobStoreUtil {
    * @return a Future containing the blob ID of the {@link SnapshotIndex}.
    */
   public CompletableFuture<String> putSnapshotIndex(SnapshotIndex snapshotIndex) {
-    byte[] bytes = new SnapshotIndexSerde().toBytes(snapshotIndex);
+    byte[] bytes = snapshotIndexSerde.toBytes(snapshotIndex);
     String opName = "putSnapshotIndex for checkpointId: " + snapshotIndex.getSnapshotMetadata().getCheckpointId();
     return FutureUtil.executeAsyncWithRetries(opName, () -> {
       InputStream inputStream = new ByteArrayInputStream(bytes); // no need to close ByteArrayInputStream
diff --git a/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java b/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
index 5b33764..67bc5f9 100644
--- a/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
+++ b/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
@@ -244,7 +244,7 @@ public class ContainerStorageManager {
     Set<String> inMemoryStoreNames = storageEngineFactories.keySet().stream()
         .filter(storeName -> {
           Optional<String> storeFactory = storageConfig.getStorageFactoryClassName(storeName);
-          return storeFactory.isPresent() && !storeFactory.get()
+          return storeFactory.isPresent() && storeFactory.get()
               .equals(StorageConfig.INMEMORY_KV_STORAGE_ENGINE_FACTORY);
         })
         .collect(Collectors.toSet());
diff --git a/samza-core/src/test/java/org/apache/samza/storage/blobstore/TestBlobStoreBackupManager.java b/samza-core/src/test/java/org/apache/samza/storage/blobstore/TestBlobStoreBackupManager.java
index 87164e8..ac64b43 100644
--- a/samza-core/src/test/java/org/apache/samza/storage/blobstore/TestBlobStoreBackupManager.java
+++ b/samza-core/src/test/java/org/apache/samza/storage/blobstore/TestBlobStoreBackupManager.java
@@ -161,7 +161,7 @@ public class TestBlobStoreBackupManager {
     // verify delete snapshot index blob called from init 0 times because prevSnapshotMap returned from init is empty
     // in case of null checkpoint.
     verify(blobStoreUtil, times(0)).deleteSnapshotIndexBlob(anyString(), any(Metadata.class));
-    when(blobStoreUtil.getStoreSnapshotIndexes(anyString(), anyString(), anyString(), any(Checkpoint.class))).thenCallRealMethod();
+    when(blobStoreUtil.getStoreSnapshotIndexes(anyString(), anyString(), anyString(), any(Checkpoint.class), anySetOf(String.class))).thenCallRealMethod();
 
     // init called with Checkpoint V1 -> unsupported
     Checkpoint checkpoint = new CheckpointV1(new HashMap<>());
@@ -288,7 +288,7 @@ public class TestBlobStoreBackupManager {
     Checkpoint checkpoint =
         new CheckpointV2(checkpointId, new HashMap<>(),
             ImmutableMap.of(BlobStoreStateBackendFactory.class.getName(), previousCheckpoints));
-    when(blobStoreUtil.getStoreSnapshotIndexes(anyString(), anyString(), anyString(), any(Checkpoint.class))).thenCallRealMethod();
+    when(blobStoreUtil.getStoreSnapshotIndexes(anyString(), anyString(), anyString(), any(Checkpoint.class), anySetOf(String.class))).thenCallRealMethod();
     blobStoreBackupManager.init(checkpoint);
 
     // mock: set task store dir to return corresponding test local store and create checkpoint dir
diff --git a/samza-core/src/test/java/org/apache/samza/storage/blobstore/util/TestBlobStoreUtil.java b/samza-core/src/test/java/org/apache/samza/storage/blobstore/util/TestBlobStoreUtil.java
index 0c00a1d..a85adef 100644
--- a/samza-core/src/test/java/org/apache/samza/storage/blobstore/util/TestBlobStoreUtil.java
+++ b/samza-core/src/test/java/org/apache/samza/storage/blobstore/util/TestBlobStoreUtil.java
@@ -21,6 +21,7 @@ package org.apache.samza.storage.blobstore.util;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
 import com.google.common.util.concurrent.MoreExecutors;
 import java.io.File;
 import java.io.FileOutputStream;
@@ -36,10 +37,12 @@ import java.nio.file.attribute.PosixFilePermissions;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Random;
+import java.util.Set;
 import java.util.SortedSet;
 import java.util.TreeSet;
 import java.util.concurrent.CompletableFuture;
@@ -794,7 +797,7 @@ public class TestBlobStoreUtil {
     BlobStoreUtil blobStoreUtil =
         new BlobStoreUtil(mock(BlobStoreManager.class), MoreExecutors.newDirectExecutorService(), null, null);
     Map<String, Pair<String, SnapshotIndex>> snapshotIndexes =
-        blobStoreUtil.getStoreSnapshotIndexes("testJobName", "testJobId", "taskName", null);
+        blobStoreUtil.getStoreSnapshotIndexes("testJobName", "testJobId", "taskName", null, new HashSet<>());
     assertTrue(snapshotIndexes.isEmpty());
   }
 
@@ -804,7 +807,7 @@ public class TestBlobStoreUtil {
     BlobStoreUtil blobStoreUtil =
         new BlobStoreUtil(mock(BlobStoreManager.class), MoreExecutors.newDirectExecutorService(), null, null);
     Map<String, Pair<String, SnapshotIndex>> prevSnapshotIndexes =
-        blobStoreUtil.getStoreSnapshotIndexes("testJobName", "testJobId", "taskName", mockCheckpoint);
+        blobStoreUtil.getStoreSnapshotIndexes("testJobName", "testJobId", "taskName", mockCheckpoint, new HashSet<>());
     assertEquals(prevSnapshotIndexes.size(), 0);
   }
 
@@ -818,7 +821,7 @@ public class TestBlobStoreUtil {
     BlobStoreUtil blobStoreUtil =
         new BlobStoreUtil(mock(BlobStoreManager.class), MoreExecutors.newDirectExecutorService(), null, null);
     Map<String, Pair<String, SnapshotIndex>> snapshotIndexes =
-        blobStoreUtil.getStoreSnapshotIndexes("testJobName", "testJobId", "taskName", mockCheckpoint);
+        blobStoreUtil.getStoreSnapshotIndexes("testJobName", "testJobId", "taskName", mockCheckpoint, new HashSet<>());
     assertTrue(snapshotIndexes.isEmpty());
   }
 
@@ -832,7 +835,7 @@ public class TestBlobStoreUtil {
     BlobStoreUtil blobStoreUtil =
         new BlobStoreUtil(mock(BlobStoreManager.class), MoreExecutors.newDirectExecutorService(), null, null);
     Map<String, Pair<String, SnapshotIndex>> snapshotIndexes =
-        blobStoreUtil.getStoreSnapshotIndexes("testJobName", "testJobId", "taskName", mockCheckpoint);
+        blobStoreUtil.getStoreSnapshotIndexes("testJobName", "testJobId", "taskName", mockCheckpoint, new HashSet<>());
     assertTrue(snapshotIndexes.isEmpty());
   }
 
@@ -840,17 +843,24 @@ public class TestBlobStoreUtil {
   public void testGetSSIThrowsExceptionOnSyncBlobStoreErrors() {
     Checkpoint checkpoint = createCheckpointV2(BlobStoreStateBackendFactory.class.getName(),
         ImmutableMap.of("storeName", "snapshotIndexBlobId"));
+    Set<String> storesToBackupOrRestore = new HashSet<>();
+    storesToBackupOrRestore.add("storeName");
     BlobStoreUtil mockBlobStoreUtil = mock(BlobStoreUtil.class);
     when(mockBlobStoreUtil.getSnapshotIndex(anyString(), any(Metadata.class))).thenThrow(new RuntimeException());
     when(mockBlobStoreUtil.getStoreSnapshotIndexes(anyString(), anyString(), anyString(),
-        any(Checkpoint.class))).thenCallRealMethod();
-    mockBlobStoreUtil.getStoreSnapshotIndexes("testJobName", "testJobId", "taskName", checkpoint);
+        any(Checkpoint.class), anySetOf(String.class))).thenCallRealMethod();
+    mockBlobStoreUtil.getStoreSnapshotIndexes("testJobName", "testJobId", "taskName", checkpoint, storesToBackupOrRestore);
   }
 
   @Test
   public void testGetSSISkipsStoresWithSnapshotIndexAlreadyDeleted() {
+    String store = "storeName1";
+    String otherStore = "storeName2";
     Checkpoint checkpoint = createCheckpointV2(BlobStoreStateBackendFactory.class.getName(),
-        ImmutableMap.of("storeName1", "snapshotIndexBlobId1", "storeName2", "snapshotIndexBlobId2"));
+        ImmutableMap.of(store, "snapshotIndexBlobId1", otherStore, "snapshotIndexBlobId2"));
+    Set<String> storesToBackupOrRestore = new HashSet<>();
+    storesToBackupOrRestore.add(store);
+    storesToBackupOrRestore.add(otherStore);
     SnapshotIndex store1SnapshotIndex = mock(SnapshotIndex.class);
     BlobStoreUtil mockBlobStoreUtil = mock(BlobStoreUtil.class);
 
@@ -859,10 +869,10 @@ public class TestBlobStoreUtil {
         CompletableFuture.completedFuture(store1SnapshotIndex));
     when(mockBlobStoreUtil.getSnapshotIndex(eq("snapshotIndexBlobId2"), any(Metadata.class))).thenReturn(failedFuture);
     when(mockBlobStoreUtil.getStoreSnapshotIndexes(anyString(), anyString(), anyString(),
-        any(Checkpoint.class))).thenCallRealMethod();
+        any(Checkpoint.class), anySetOf(String.class))).thenCallRealMethod();
 
     Map<String, Pair<String, SnapshotIndex>> snapshotIndexes =
-        mockBlobStoreUtil.getStoreSnapshotIndexes("testJobName", "testJobId", "taskName", checkpoint);
+        mockBlobStoreUtil.getStoreSnapshotIndexes("testJobName", "testJobId", "taskName", checkpoint, storesToBackupOrRestore);
     assertEquals(1, snapshotIndexes.size());
     assertEquals("snapshotIndexBlobId1", snapshotIndexes.get("storeName1").getLeft());
     assertEquals(store1SnapshotIndex, snapshotIndexes.get("storeName1").getRight());
@@ -870,12 +880,17 @@ public class TestBlobStoreUtil {
 
   @Test
   public void testGetSSIThrowsExceptionIfAnyNonIgnoredAsyncBlobStoreErrors() {
+    String store = "storeName1";
+    String otherStore = "storeName2";
+    Set<String> storesToBackupOrRestore = new HashSet<>();
+    storesToBackupOrRestore.add(store);
+    storesToBackupOrRestore.add(otherStore);
     Checkpoint checkpoint = createCheckpointV2(BlobStoreStateBackendFactory.class.getName(),
-        ImmutableMap.of("storeName1", "snapshotIndexBlobId1", "storeName2", "snapshotIndexBlobId2"));
+        ImmutableMap.of(store, "snapshotIndexBlobId1", otherStore, "snapshotIndexBlobId2"));
     SnapshotIndex store1SnapshotIndex = mock(SnapshotIndex.class);
     BlobStoreUtil mockBlobStoreUtil = mock(BlobStoreUtil.class);
     when(mockBlobStoreUtil.getStoreSnapshotIndexes(anyString(), anyString(), anyString(),
-        any(Checkpoint.class))).thenCallRealMethod();
+        any(Checkpoint.class), anySetOf(String.class))).thenCallRealMethod();
     RuntimeException nonIgnoredException = new RuntimeException();
     CompletableFuture<SnapshotIndex> failedFuture = FutureUtil.failedFuture(nonIgnoredException);
     when(mockBlobStoreUtil.getSnapshotIndex(eq("snapshotIndexBlobId1"), any(Metadata.class))).thenReturn(
@@ -883,7 +898,7 @@ public class TestBlobStoreUtil {
     when(mockBlobStoreUtil.getSnapshotIndex(eq("snapshotIndexBlobId2"), any(Metadata.class))).thenReturn(failedFuture);
 
     try {
-      mockBlobStoreUtil.getStoreSnapshotIndexes("testJobName", "testJobId", "taskName", checkpoint);
+      mockBlobStoreUtil.getStoreSnapshotIndexes("testJobName", "testJobId", "taskName", checkpoint, storesToBackupOrRestore);
       fail("Should have thrown an exception");
     } catch (Exception e) {
       Throwable cause =
@@ -896,6 +911,7 @@ public class TestBlobStoreUtil {
   public void testGetSSIReturnsCorrectSCMSnapshotIndexPair() {
     String storeName = "storeName";
     String otherStoreName = "otherStoreName";
+    Set<String> storesToBackupOrRestore = ImmutableSet.of(storeName, otherStoreName);
     String storeSnapshotIndexBlobId = "snapshotIndexBlobId";
     String otherStoreSnapshotIndexBlobId = "otherSnapshotIndexBlobId";
     SnapshotIndex mockStoreSnapshotIndex = mock(SnapshotIndex.class);
@@ -911,10 +927,10 @@ public class TestBlobStoreUtil {
     when(mockBlobStoreUtil.getSnapshotIndex(eq(otherStoreSnapshotIndexBlobId), any(Metadata.class))).thenReturn(
         CompletableFuture.completedFuture(mockOtherStooreSnapshotIndex));
     when(mockBlobStoreUtil.getStoreSnapshotIndexes(anyString(), anyString(), anyString(),
-        any(Checkpoint.class))).thenCallRealMethod();
+        any(Checkpoint.class), anySetOf(String.class))).thenCallRealMethod();
 
     Map<String, Pair<String, SnapshotIndex>> snapshotIndexes =
-        mockBlobStoreUtil.getStoreSnapshotIndexes("testJobName", "testJobId", "taskName", checkpoint);
+        mockBlobStoreUtil.getStoreSnapshotIndexes("testJobName", "testJobId", "taskName", checkpoint, storesToBackupOrRestore);
 
     assertEquals(storeSnapshotIndexBlobId, snapshotIndexes.get(storeName).getKey());
     assertEquals(mockStoreSnapshotIndex, snapshotIndexes.get(storeName).getValue());
@@ -923,6 +939,27 @@ public class TestBlobStoreUtil {
     verify(mockBlobStoreUtil, times(2)).getSnapshotIndex(anyString(), any(Metadata.class));
   }
 
+  @Test
+  public void testGetCheckpointIndexIgnoresStoresNotInStoresToBackupRestoreSet() {
+    String store = "storeName1";
+    String anotherStore = "storeName2";
+    String oneMoreStore = "storeName3";
+    SnapshotIndex mockStoreSnapshotIndex = mock(SnapshotIndex.class);
+    Set<String> storesToBackupOrRestore = ImmutableSet.of(store, anotherStore);
+    CheckpointV2 checkpoint = createCheckpointV2(BlobStoreStateBackendFactory.class.getName(),
+        ImmutableMap.of(store, "1", anotherStore, "2", oneMoreStore, "3"));
+    BlobStoreUtil mockBlobStoreUtil = mock(BlobStoreUtil.class);
+    when(mockBlobStoreUtil.getSnapshotIndex(any(String.class), any(Metadata.class)))
+        .thenReturn(CompletableFuture.completedFuture(mockStoreSnapshotIndex));
+    when(mockBlobStoreUtil.getStoreSnapshotIndexes(anyString(), anyString(), anyString(),
+        any(Checkpoint.class), anySetOf(String.class))).thenCallRealMethod();
+
+    Map<String, Pair<String, SnapshotIndex>> snapshotIndexes =
+        mockBlobStoreUtil.getStoreSnapshotIndexes("testJobName", "testJobId", "taskName", checkpoint, storesToBackupOrRestore);
+
+    verify(mockBlobStoreUtil, times(storesToBackupOrRestore.size())).getSnapshotIndex(anyString(), any(Metadata.class));
+  }
+
   private CheckpointV2 createCheckpointV2(String stateBackendFactory, Map<String, String> storeSnapshotIndexBlobIds) {
     CheckpointId checkpointId = CheckpointId.create();
     Map<String, Map<String, String>> factoryStoreSCMs = new HashMap<>();