Posted to commits@kafka.apache.org by vv...@apache.org on 2020/03/14 03:57:30 UTC

[kafka] branch trunk updated: KAFKA-6145: Pt 2. Include offset sums in subscription (#8246)

This is an automated email from the ASF dual-hosted git repository.

vvcephei pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 542853d  KAFKA-6145: Pt 2. Include offset sums in subscription (#8246)
542853d is described below

commit 542853d99b9e0d660a9cf9317be8a3f8fce4c765
Author: A. Sophie Blee-Goldman <so...@confluent.io>
AuthorDate: Fri Mar 13 20:56:59 2020 -0700

    KAFKA-6145: Pt 2. Include offset sums in subscription (#8246)
    
    KIP-441 Pt. 2: Compute the sum of offsets across all stores/changelogs in a task and include these sums in the subscription.
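
    For reference, the offset sum for a task condenses to a saturating sum over its changelog offsets. The following is a simplified sketch of TaskManager#sumOfChangelogOffsets from the diff below (the warning log on overflow is elided):

        long offsetSum = 0L;
        for (final long offset : changelogOffsets.values()) {
            offsetSum += offset;
            if (offsetSum < 0) {
                return Long.MAX_VALUE; // pin the sum on overflow
            }
        }
        return offsetSum;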
    
    Previously, each thread would simply encode every task found on disk, but we now need to read the checkpoint file, which is unsafe to do without a lock on the task directory. So each thread now encodes only its assigned active and standby tasks, and ignores any already-locked tasks.
    
    In some cases there may be unowned and unlocked tasks on disk that were reassigned to another instance and have not yet been cleaned up by the background thread. Each StreamThread makes a weak effort to lock any such task directories it finds; if successful, it is then responsible for computing and reporting that task's offset sum (based on reading the checkpoint file).
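
    A condensed sketch of this path, drawn from TaskManager#tryToLockAllTaskDirectories and #getTaskOffsetSums in the diff below (exception handling is elided here but present in the real code):

        // at rebalance start: best-effort lock on every parseable task directory
        for (final File dir : stateDirectory.listTaskDirectories()) {
            final TaskId id = TaskId.parse(dir.getName());
            if (stateDirectory.lock(id)) {
                lockedTaskDirectories.add(id);
            }
        }

        // when building the subscription: for a locked but unassigned task,
        // fall back to the on-disk checkpoint to compute its offset sum
        final File checkpointFile = stateDirectory.checkpointFileFor(id);
        if (checkpointFile.exists()) {
            taskOffsetSums.put(id, sumOfChangelogOffsets(id, new OffsetCheckpoint(checkpointFile).read()));
        }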
    
    This PR also addresses two orthogonal issues:
    
    1. Prevent the background cleaner thread from deleting unowned stores during a rebalance (see the sketch after this list).
    2. Deduplicate standby tasks in subscription: each thread used to include every (non-active) task found on disk in its "standby task" set, which meant every active, standby, and unowned task was encoded by every thread.
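
    On point 1: StateDirectory#cleanRemovedTasks (see the diff below) only considers directories whose task id is absent from the lock map, so the locks taken at rebalance start shield unowned stores from the cleaner:

        for (final File taskDir : listTaskDirectories()) {
            final TaskId id = TaskId.parse(taskDir.getName());
            if (!locks.containsKey(id)) {
                // only unlocked task directories are candidates for cleanup
            }
        }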
    
    Reviewers: Bruno Cadonna <br...@confluent.io>, John Roesler <vv...@apache.org>
---
 .../processor/internals/ProcessorStateManager.java |   2 +-
 .../processor/internals/StateDirectory.java        |  24 +-
 .../streams/processor/internals/StreamTask.java    |   8 +-
 .../kafka/streams/processor/internals/Task.java    |   8 +-
 .../streams/processor/internals/TaskManager.java   | 132 +++++++--
 .../internals/assignment/SubscriptionInfo.java     |  37 ++-
 .../processor/internals/ActiveTaskCreatorTest.java |   4 +-
 .../processor/internals/StreamTaskTest.java        |   2 +-
 .../processor/internals/TaskManagerTest.java       | 307 +++++++++++++++++++--
 .../internals/assignment/SubscriptionInfoTest.java |  38 ++-
 10 files changed, 475 insertions(+), 87 deletions(-)

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
index 1317aa9..6553df4 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
@@ -180,7 +180,7 @@ public class ProcessorStateManager implements StateManager {
         this.storeToChangelogTopic = storeToChangelogTopic;
 
         this.baseDir = stateDirectory.directoryForTask(taskId);
-        this.checkpointFile = new OffsetCheckpoint(new File(baseDir, CHECKPOINT_FILE_NAME));
+        this.checkpointFile = new OffsetCheckpoint(stateDirectory.checkpointFileFor(taskId));
 
         log.debug("Created state store manager for task {}", taskId);
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateDirectory.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateDirectory.java
index 2068678..71b8b95 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateDirectory.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateDirectory.java
@@ -108,6 +108,13 @@ public class StateDirectory {
     }
 
     /**
+     * @return The File handle for the checkpoint in the given task's directory
+     */
+    File checkpointFileFor(final TaskId taskId) {
+        return new File(directoryForTask(taskId), StateManagerUtil.CHECKPOINT_FILE_NAME);
+    }
+
+    /**
      * Decide if the directory of the task is empty or not
      */
     boolean directoryForTaskIsEmpty(final TaskId taskId) {
@@ -285,12 +292,7 @@ public class StateDirectory {
 
     private synchronized void cleanRemovedTasks(final long cleanupDelayMs,
                                                 final boolean manualUserCall) throws Exception {
-        final File[] taskDirs = listTaskDirectories();
-        if (taskDirs == null || taskDirs.length == 0) {
-            return; // nothing to do
-        }
-
-        for (final File taskDir : taskDirs) {
+        for (final File taskDir : listTaskDirectories()) {
             final String dirName = taskDir.getName();
             final TaskId id = TaskId.parse(dirName);
             if (!locks.containsKey(id)) {
@@ -347,8 +349,14 @@ public class StateDirectory {
      * @return The list of all the existing local directories for stream tasks
      */
     File[] listTaskDirectories() {
-        return !stateDir.exists() ? new File[0] :
-                stateDir.listFiles(pathname -> pathname.isDirectory() && PATH_NAME.matcher(pathname.getName()).matches());
+        final File[] taskDirectories =
+            stateDir.listFiles(pathname -> pathname.isDirectory() && PATH_NAME.matcher(pathname.getName()).matches());
+
+        if (!stateDir.exists() || taskDirectories == null) {
+            return new File[0];
+        } else {
+            return taskDirectories;
+        }
     }
 
     private FileChannel getOrCreateFileChannel(final TaskId taskId,
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
index ba98bec..64adb6a 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
@@ -671,16 +671,16 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
     }
 
     @Override
-    public Map<TopicPartition, Long> purgableOffsets() {
-        final Map<TopicPartition, Long> purgableConsumedOffsets = new HashMap<>();
+    public Map<TopicPartition, Long> purgeableOffsets() {
+        final Map<TopicPartition, Long> purgeableConsumedOffsets = new HashMap<>();
         for (final Map.Entry<TopicPartition, Long> entry : consumedOffsets.entrySet()) {
             final TopicPartition tp = entry.getKey();
             if (topology.isRepartitionTopic(tp.topic())) {
-                purgableConsumedOffsets.put(tp, entry.getValue() + 1);
+                purgeableConsumedOffsets.put(tp, entry.getValue() + 1);
             }
         }
 
-        return purgableConsumedOffsets;
+        return purgeableConsumedOffsets;
     }
 
     private void initializeTopology() {
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
index 02f9926..2bdce69 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
@@ -32,7 +32,9 @@ import java.util.Map;
 import java.util.Set;
 
 public interface Task {
-    // this must be negative to distinguish a running active task from other kinds tasks which may be caught up to the same offsets
+
+    // this must be negative to distinguish a running active task from other kinds of tasks
+    // which may be caught up to the same offsets
     long LATEST_OFFSET = -2L;
 
     /*
@@ -176,7 +178,7 @@ public interface Task {
 
     void markChangelogAsCorrupted(final Collection<TopicPartition> partitions);
 
-    default Map<TopicPartition, Long> purgableOffsets() {
+    default Map<TopicPartition, Long> purgeableOffsets() {
         return Collections.emptyMap();
     }
 
@@ -195,6 +197,4 @@ public interface Task {
     default boolean maybePunctuateSystemTime() {
         return false;
     }
-
-
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index ba75f86..2c80c34 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -16,6 +16,8 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import java.io.IOException;
+import java.util.Collections;
 import org.apache.kafka.clients.admin.Admin;
 import org.apache.kafka.clients.admin.DeleteRecordsResult;
 import org.apache.kafka.clients.admin.RecordsToDelete;
@@ -32,6 +34,7 @@ import org.apache.kafka.streams.errors.TaskIdFormatException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
+import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
 import org.slf4j.Logger;
 
 import java.io.File;
@@ -79,6 +82,9 @@ public class TaskManager {
 
     private boolean rebalanceInProgress = false;  // if we are in the middle of a rebalance, it is not safe to commit
 
+    // includes assigned & initialized tasks and unassigned tasks we locked temporarily during rebalance
+    private Set<TaskId> lockedTaskDirectories = new HashSet<>();
+
     TaskManager(final ChangelogReader changelogReader,
                 final UUID processId,
                 final String logPrefix,
@@ -121,6 +127,8 @@ public class TaskManager {
     void handleRebalanceStart(final Set<String> subscribedTopics) {
         builder.addSubscribedTopicsFromMetadata(subscribedTopics, logPrefix);
 
+        tryToLockAllTaskDirectories();
+
         rebalanceInProgress = true;
     }
 
@@ -129,6 +137,8 @@ public class TaskManager {
         // before then the assignment has not been updated yet.
         mainConsumer.pause(mainConsumer.assignment());
 
+        releaseLockedUnassignedTaskDirectories();
+
         rebalanceInProgress = false;
     }
 
@@ -368,50 +378,105 @@ public class TaskManager {
     }
 
     /**
+     * Compute the offset total summed across all stores in a task. Includes offset sum for any tasks we own the
+     * lock for, which includes assigned and unassigned tasks we locked in {@link #tryToLockAllTaskDirectories()}
+     *
      * @return Map from task id to its total offset summed across all state stores
      */
     public Map<TaskId, Long> getTaskOffsetSums() {
         final Map<TaskId, Long> taskOffsetSums = new HashMap<>();
 
-        for (final TaskId id : tasksOnLocalStorage()) {
-            if (isRunning(id)) {
-                taskOffsetSums.put(id, Task.LATEST_OFFSET);
+        for (final TaskId id : lockedTaskDirectories) {
+            final Task task = tasks.get(id);
+            if (task != null) {
+                if (task.isActive() && task.state() == RUNNING) {
+                    taskOffsetSums.put(id, Task.LATEST_OFFSET);
+                } else {
+                    taskOffsetSums.put(id, sumOfChangelogOffsets(id, task.changelogOffsets()));
+                }
             } else {
-                taskOffsetSums.put(id, 0L);
+                final File checkpointFile = stateDirectory.checkpointFileFor(id);
+                try {
+                    if (checkpointFile.exists()) {
+                        taskOffsetSums.put(id, sumOfChangelogOffsets(id, new OffsetCheckpoint(checkpointFile).read()));
+                    }
+                } catch (final IOException e) {
+                    log.warn(String.format("Exception caught while trying to read checkpoint for task %s:", id), e);
+                }
             }
         }
+
         return taskOffsetSums;
     }
 
     /**
-     * Returns ids of tasks whose states are kept on the local storage. This includes active, standby, and previously
-     * assigned but not yet cleaned up tasks
+     * Makes a weak attempt to lock all task directories in the state dir. We are responsible for computing and
+     * reporting the offset sum for any unassigned tasks we obtain the lock for in the upcoming rebalance. Tasks
+     * that we locked but didn't own will be released at the end of the rebalance (unless of course we were
+     * assigned the task as a result of the rebalance). This method should be idempotent.
      */
-    private Set<TaskId> tasksOnLocalStorage() {
-        // A client could contain some inactive tasks whose states are still kept on the local storage in the following scenarios:
-        // 1) the client is actively maintaining standby tasks by maintaining their states from the change log.
-        // 2) the client has just got some tasks migrated out of itself to other clients while these task states
-        //    have not been cleaned up yet (this can happen in a rolling bounce upgrade, for example).
+    private void tryToLockAllTaskDirectories() {
+        for (final File dir : stateDirectory.listTaskDirectories()) {
+            try {
+                final TaskId id = TaskId.parse(dir.getName());
+                try {
+                    if (stateDirectory.lock(id)) {
+                        lockedTaskDirectories.add(id);
+                        if (!tasks.containsKey(id)) {
+                            log.debug("Temporarily locked unassigned task {} for the upcoming rebalance", id);
+                        }
+                    }
+                } catch (final IOException e) {
+                    // if for any reason we can't lock this task dir, just move on
+                    log.warn(String.format("Exception caught while attempting to lock task %s:", id), e);
+                }
+            } catch (final TaskIdFormatException e) {
+                // ignore any unknown files that sit in the same directory
+            }
+        }
+    }
 
-        final Set<TaskId> locallyStoredTasks = new HashSet<>();
+    /**
+     * We must release the lock for any unassigned tasks that we temporarily locked in preparation for a
+     * rebalance in {@link #tryToLockAllTaskDirectories()}.
+     */
+    private void releaseLockedUnassignedTaskDirectories() {
+        final AtomicReference<RuntimeException> firstException = new AtomicReference<>(null);
 
-        final File[] stateDirs = stateDirectory.listTaskDirectories();
-        if (stateDirs != null) {
-            for (final File dir : stateDirs) {
+        final Iterator<TaskId> taskIdIterator = lockedTaskDirectories.iterator();
+        while (taskIdIterator.hasNext()) {
+            final TaskId id = taskIdIterator.next();
+            if (!tasks.containsKey(id)) {
                 try {
-                    final TaskId id = TaskId.parse(dir.getName());
-                    // if the checkpoint file exists, the state is valid.
-                    if (new File(dir, StateManagerUtil.CHECKPOINT_FILE_NAME).exists()) {
-                        locallyStoredTasks.add(id);
-                    }
-                } catch (final TaskIdFormatException e) {
-                    // there may be some unknown files that sits in the same directory,
-                    // we should ignore these files instead trying to delete them as well
+                    stateDirectory.unlock(id);
+                    taskIdIterator.remove();
+                } catch (final IOException e) {
+                    log.error(String.format("Caught the following exception while trying to unlock task %s", id), e);
+                    firstException.compareAndSet(null,
+                        new StreamsException(String.format("Failed to unlock task directory %s", id), e));
                 }
             }
         }
 
-        return locallyStoredTasks;
+        final RuntimeException fatalException = firstException.get();
+        if (fatalException != null) {
+            throw fatalException;
+        }
+    }
+
+    private long sumOfChangelogOffsets(final TaskId id, final Map<TopicPartition, Long> changelogOffsets) {
+        long offsetSum = 0L;
+        for (final Map.Entry<TopicPartition, Long> changelogEntry : changelogOffsets.entrySet()) {
+            final long offset = changelogEntry.getValue();
+
+            offsetSum += offset;
+            if (offsetSum < 0) {
+                log.warn("Sum of changelog offsets for task {} overflowed, pinning to Long.MAX_VALUE", id);
+                return Long.MAX_VALUE;
+            }
+        }
+
+        return offsetSum;
     }
 
     private void cleanupTask(final Task task) {
@@ -474,6 +539,14 @@ public class TaskManager {
             }
         }
 
+        try {
+            // this should be called after closing all tasks, to make sure we unlock the task dir for tasks that may
+            // have still been in CREATED at the time of shutdown, since Task#close will not do so
+            releaseLockedUnassignedTaskDirectories();
+        } catch (final RuntimeException e) {
+            firstException.compareAndSet(null, e);
+        }
+
         final RuntimeException fatalException = firstException.get();
         if (fatalException != null) {
             throw new RuntimeException("Unexpected exception while closing task", fatalException);
@@ -522,11 +595,6 @@ public class TaskManager {
         return tasks.values().stream().filter(t -> !t.isActive());
     }
 
-    private boolean isRunning(final TaskId id) {
-        final Task task = tasks.get(id);
-        return task != null && task.isActive() && task.state() == RUNNING;
-    }
-
     /**
      * @throws TaskMigratedException if committing offsets failed (non-EOS)
      *                               or if the task producer got fenced (EOS)
@@ -629,7 +697,7 @@ public class TaskManager {
 
             final Map<TopicPartition, RecordsToDelete> recordsToDelete = new HashMap<>();
             for (final Task task : activeTaskIterable()) {
-                for (final Map.Entry<TopicPartition, Long> entry : task.purgableOffsets().entrySet()) {
+                for (final Map.Entry<TopicPartition, Long> entry : task.purgeableOffsets().entrySet()) {
                     recordsToDelete.put(entry.getKey(), RecordsToDelete.beforeOffset(entry.getValue()));
                 }
             }
@@ -676,4 +744,8 @@ public class TaskManager {
     Set<String> producerClientIds() {
         return activeTaskCreator.producerClientIds();
     }
+
+    Set<TaskId> lockedTaskDirectories() {
+        return Collections.unmodifiableSet(lockedTaskDirectories);
+    }
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java
index e718b1c..411ff04 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java
@@ -45,6 +45,8 @@ public class SubscriptionInfo {
     private static final Logger LOG = LoggerFactory.getLogger(SubscriptionInfo.class);
 
     static final int UNKNOWN = -1;
+    static final int MIN_VERSION_OFFSET_SUM_SUBSCRIPTION = 7;
+    static final long UNKNOWN_OFFSET_SUM = -3L;
 
     private final SubscriptionInfoData data;
     private Set<TaskId> prevTasksCache = null;
@@ -96,7 +98,7 @@ public class SubscriptionInfo {
 
         this.data = data;
 
-        if (version >= 7) {
+        if (version >= MIN_VERSION_OFFSET_SUM_SUBSCRIPTION) {
             setTaskOffsetSumDataFromTaskOffsetSumMap(taskOffsetSums);
         } else {
             setPrevAndStandbySetsFromParsedTaskOffsetSumMap(taskOffsetSums);
@@ -112,11 +114,10 @@ public class SubscriptionInfo {
         final Map<Integer, List<SubscriptionInfoData.PartitionToOffsetSum>> topicGroupIdToPartitionOffsetSum = new HashMap<>();
         for (final Map.Entry<TaskId, Long> taskEntry : taskOffsetSums.entrySet()) {
             final TaskId task = taskEntry.getKey();
-            topicGroupIdToPartitionOffsetSum.putIfAbsent(task.topicGroupId, new ArrayList<>());
-            topicGroupIdToPartitionOffsetSum.get(task.topicGroupId).add(
+            topicGroupIdToPartitionOffsetSum.computeIfAbsent(task.topicGroupId, t -> new ArrayList<>()).add(
                 new SubscriptionInfoData.PartitionToOffsetSum()
-                         .setPartition(task.partition)
-                         .setOffsetSum(taskEntry.getValue()));
+                    .setPartition(task.partition)
+                    .setOffsetSum(taskEntry.getValue()));
         }
 
         data.setTaskOffsetSums(topicGroupIdToPartitionOffsetSum.entrySet().stream().map(t -> {
@@ -167,7 +168,7 @@ public class SubscriptionInfo {
 
     public Set<TaskId> prevTasks() {
         if (prevTasksCache == null) {
-            if (data.version() >= 7) {
+            if (data.version() >= MIN_VERSION_OFFSET_SUM_SUBSCRIPTION) {
                 prevTasksCache = getActiveTasksFromTaskOffsetSumMap(taskOffsetSums());
             } else {
                 prevTasksCache = Collections.unmodifiableSet(
@@ -183,7 +184,7 @@ public class SubscriptionInfo {
 
     public Set<TaskId> standbyTasks() {
         if (standbyTasksCache == null) {
-            if (data.version() >= 7) {
+            if (data.version() >= MIN_VERSION_OFFSET_SUM_SUBSCRIPTION) {
                 standbyTasksCache = getStandbyTasksFromTaskOffsetSumMap(taskOffsetSums());
             } else {
                 standbyTasksCache = Collections.unmodifiableSet(
@@ -197,14 +198,22 @@ public class SubscriptionInfo {
         return standbyTasksCache;
     }
 
-    public Map<TaskId, Long> taskOffsetSums() {
+    Map<TaskId, Long> taskOffsetSums() {
         if (taskOffsetSumsCache == null) {
             taskOffsetSumsCache = new HashMap<>();
-            for (final TaskOffsetSum topicGroup : data.taskOffsetSums()) {
-                for (final PartitionToOffsetSum partitionOffsetSum : topicGroup.partitionToOffsetSum()) {
-                    taskOffsetSumsCache.put(new TaskId(topicGroup.topicGroupId(), partitionOffsetSum.partition()),
-                                            partitionOffsetSum.offsetSum());
+            if (data.version() >= MIN_VERSION_OFFSET_SUM_SUBSCRIPTION) {
+                for (final TaskOffsetSum taskOffsetSum : data.taskOffsetSums()) {
+                    for (final PartitionToOffsetSum partitionOffsetSum : taskOffsetSum.partitionToOffsetSum()) {
+                        taskOffsetSumsCache.put(
+                            new TaskId(taskOffsetSum.topicGroupId(),
+                                       partitionOffsetSum.partition()),
+                            partitionOffsetSum.offsetSum()
+                        );
+                    }
                 }
+            } else {
+                prevTasks().forEach(taskId -> taskOffsetSumsCache.put(taskId, Task.LATEST_OFFSET));
+                standbyTasks().forEach(taskId -> taskOffsetSumsCache.put(taskId, UNKNOWN_OFFSET_SUM));
             }
         }
         return taskOffsetSumsCache;
@@ -266,7 +275,9 @@ public class SubscriptionInfo {
             subscriptionInfoData.setVersion(version);
             subscriptionInfoData.setLatestSupportedVersion(latestSupportedVersion);
             LOG.info("Unable to decode subscription data: used version: {}; latest supported version: {}",
-                     version, latestSupportedVersion);
+                version,
+                latestSupportedVersion
+            );
             return new SubscriptionInfo(subscriptionInfoData);
         } else {
             data.rewind();
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreatorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreatorTest.java
index 7fc351b..0f9ecdf 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreatorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreatorTest.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import java.io.File;
 import org.apache.kafka.clients.admin.Admin;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.MockConsumer;
@@ -116,10 +117,11 @@ public class ActiveTaskCreatorTest {
         expect(config.getInt(anyString())).andReturn(0);
         expect(config.getProducerConfigs(anyString())).andReturn(new HashMap<>());
         expect(builder.buildSubtopology(taskId.topicGroupId)).andReturn(topology);
+        expect(stateDirectory.directoryForTask(taskId)).andReturn(new File(taskId.toString()));
         expect(topology.storeToChangelogTopic()).andReturn(Collections.emptyMap());
         expect(topology.source("topic")).andReturn(mock(SourceNode.class));
         expect(topology.globalStateStores()).andReturn(Collections.emptyList());
-        replay(config, builder, topology);
+        replay(config, builder, stateDirectory, topology);
 
         mockClientSupplier.setApplicationIdForProducer("appId");
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
index c2b5e77..dc0404b 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
@@ -1264,7 +1264,7 @@ public class StreamTaskTest {
 
         task.commit();
 
-        final Map<TopicPartition, Long> map = task.purgableOffsets();
+        final Map<TopicPartition, Long> map = task.purgeableOffsets();
 
         assertThat(map, equalTo(Collections.singletonMap(repartition, 11L)));
     }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 1a8db34..62043e3 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -16,6 +16,8 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import java.util.HashSet;
+import java.util.stream.Collectors;
 import org.apache.kafka.clients.admin.Admin;
 import org.apache.kafka.clients.admin.DeleteRecordsResult;
 import org.apache.kafka.clients.admin.DeletedRecords;
@@ -39,6 +41,7 @@ import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
+import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
 import org.easymock.EasyMock;
 import org.easymock.EasyMockRunner;
 import org.easymock.Mock;
@@ -77,13 +80,16 @@ import static org.easymock.EasyMock.eq;
 import static org.easymock.EasyMock.expect;
 import static org.easymock.EasyMock.expectLastCall;
 import static org.easymock.EasyMock.replay;
+import static org.easymock.EasyMock.reset;
 import static org.easymock.EasyMock.resetToStrict;
 import static org.easymock.EasyMock.verify;
 import static org.hamcrest.CoreMatchers.hasItem;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
 import static org.hamcrest.core.IsEqual.equalTo;
 import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
 
 @RunWith(EasyMockRunner.class)
 public class TaskManagerTest {
@@ -104,9 +110,11 @@ public class TaskManagerTest {
     private final TopicPartition t1p2 = new TopicPartition(topic1, 2);
     private final Set<TopicPartition> taskId02Partitions = mkSet(t1p2);
 
+    private final TaskId taskId10 = new TaskId(1, 0);
+
     @Mock(type = MockType.STRICT)
     private InternalTopologyBuilder topologyBuilder;
-    @Mock(type = MockType.NICE)
+    @Mock(type = MockType.DEFAULT)
     private StateDirectory stateDirectory;
     @Mock(type = MockType.NICE)
     private ChangelogReader changeLogReader;
@@ -159,26 +167,176 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldReturnOffsetsForAllCachedTaskIdsFromDirectory() throws IOException {
-        final File[] taskFolders = asList(testFolder.newFolder("0_1"),
-                                          testFolder.newFolder("0_2"),
-                                          testFolder.newFolder("0_3"),
-                                          testFolder.newFolder("1_1"),
-                                          testFolder.newFolder("dummy")).toArray(new File[0]);
+    public void shouldNotLockAnythingIfStateDirIsEmpty() {
+        expect(stateDirectory.listTaskDirectories()).andReturn(new File[0]).once();
 
-        assertThat((new File(taskFolders[0], StateManagerUtil.CHECKPOINT_FILE_NAME)).createNewFile(), is(true));
-        assertThat((new File(taskFolders[1], StateManagerUtil.CHECKPOINT_FILE_NAME)).createNewFile(), is(true));
-        assertThat((new File(taskFolders[3], StateManagerUtil.CHECKPOINT_FILE_NAME)).createNewFile(), is(true));
+        replay(stateDirectory);
+        taskManager.handleRebalanceStart(singleton("topic"));
 
-        expect(stateDirectory.listTaskDirectories()).andReturn(taskFolders).once();
+        verify(stateDirectory);
+        assertTrue(taskManager.lockedTaskDirectories().isEmpty());
+    }
+
+    @Test
+    public void shouldTryToLockValidTaskDirsAtRebalanceStart() throws IOException {
+        expectLockObtainedFor(taskId01);
+        expectLockFailedFor(taskId10);
+
+        makeTaskFolders(
+            taskId01.toString(),
+            taskId10.toString(),
+            "dummy"
+        );
+        replay(stateDirectory);
+        taskManager.handleRebalanceStart(singleton("topic"));
+
+        verify(stateDirectory);
+        assertThat(taskManager.lockedTaskDirectories(), is(singleton(taskId01)));
+    }
+
+    @Test
+    public void shouldReleaseLockForUnassignedTasksAfterRebalance() throws IOException {
+        expectLockObtainedFor(taskId00, taskId01, taskId02);
+        expectUnlockFor(taskId02);
+
+        makeTaskFolders(
+            taskId00.toString(),  // active task
+            taskId01.toString(),  // standby task
+            taskId02.toString()   // unassigned but able to lock
+        );
+        replay(stateDirectory);
+        taskManager.handleRebalanceStart(singleton("topic"));
 
-        replay(activeTaskCreator, stateDirectory);
+        assertThat(taskManager.lockedTaskDirectories(), is(mkSet(taskId00, taskId01, taskId02)));
 
-        final Map<TaskId, Long> taskOffsetSums = taskManager.getTaskOffsetSums();
+        handleAssignment(taskId00Assignment, taskId01Assignment, emptyMap());
+        reset(consumer);
+        expectConsumerAssignmentPaused(consumer);
+        replay(consumer);
+
+        taskManager.handleRebalanceComplete();
+        assertThat(taskManager.lockedTaskDirectories(), is(mkSet(taskId00, taskId01)));
+        verify(stateDirectory);
+    }
+
+    @Test
+    public void shouldReportLatestOffsetAsOffsetSumForRunningTask() throws IOException {
+        final Map<TaskId, Long> expectedOffsetSums = mkMap(mkEntry(taskId00, Task.LATEST_OFFSET));
 
-        verify(activeTaskCreator, stateDirectory);
+        expectLockObtainedFor(taskId00);
+        makeTaskFolders(taskId00.toString());
+        replay(stateDirectory);
 
-        assertThat(taskOffsetSums.keySet(), equalTo(mkSet(taskId01, taskId02, new TaskId(1, 1))));
+        taskManager.handleRebalanceStart(singleton("topic"));
+        handleAssignment(taskId00Assignment, emptyMap(), emptyMap());
+
+        assertThat(taskManager.getTaskOffsetSums(), is(expectedOffsetSums));
+    }
+
+    @Test
+    public void shouldComputeOffsetSumForNonRunningActiveTask() throws IOException {
+        final Map<TopicPartition, Long> changelogOffsets = mkMap(
+            mkEntry(new TopicPartition("changelog", 0), 5L),
+            mkEntry(new TopicPartition("changelog", 1), 10L)
+        );
+        final Map<TaskId, Long> expectedOffsetSums = mkMap(mkEntry(taskId00, 15L));
+
+        expectLockObtainedFor(taskId00);
+        makeTaskFolders(taskId00.toString());
+        replay(stateDirectory);
+
+        taskManager.handleRebalanceStart(singleton("topic"));
+        final StateMachineTask restoringTask = handleAssignment(
+            emptyMap(),
+            emptyMap(),
+            taskId00Assignment
+        ).get(taskId00);
+        restoringTask.setChangelogOffsets(changelogOffsets);
+
+        assertThat(taskManager.getTaskOffsetSums(), is(expectedOffsetSums));
+    }
+
+    @Test
+    public void shouldComputeOffsetSumForStandbyTask() throws IOException {
+        final Map<TopicPartition, Long> changelogOffsets = mkMap(
+            mkEntry(new TopicPartition("changelog", 0), 5L),
+            mkEntry(new TopicPartition("changelog", 1), 10L)
+        );
+        final Map<TaskId, Long> expectedOffsetSums = mkMap(mkEntry(taskId00, 15L));
+
+        expectLockObtainedFor(taskId00);
+        makeTaskFolders(taskId00.toString());
+        replay(stateDirectory);
+
+        taskManager.handleRebalanceStart(singleton("topic"));
+        final StateMachineTask restoringTask = handleAssignment(
+            emptyMap(),
+            taskId00Assignment,
+            emptyMap()
+        ).get(taskId00);
+        restoringTask.setChangelogOffsets(changelogOffsets);
+
+        assertThat(taskManager.getTaskOffsetSums(), is(expectedOffsetSums));
+    }
+
+    @Test
+    public void shouldComputeOffsetSumForUnassignedTaskWeCanLock() throws IOException {
+        final Map<TopicPartition, Long> changelogOffsets = mkMap(
+            mkEntry(new TopicPartition("changelog", 0), 5L),
+            mkEntry(new TopicPartition("changelog", 1), 10L)
+        );
+        final Map<TaskId, Long> expectedOffsetSums = mkMap(mkEntry(taskId00, 15L));
+
+        expectLockObtainedFor(taskId00);
+        makeTaskFolders(taskId00.toString());
+        writeCheckpointFile(taskId00, changelogOffsets);
+
+        replay(stateDirectory);
+        taskManager.handleRebalanceStart(singleton("topic"));
+
+        assertThat(taskManager.getTaskOffsetSums(), is(expectedOffsetSums));
+    }
+
+    @Test
+    public void shouldNotReportOffsetSumsForTaskWeCantLock() throws IOException {
+        expectLockFailedFor(taskId00);
+        makeTaskFolders(taskId00.toString());
+        replay(stateDirectory);
+        taskManager.handleRebalanceStart(singleton("topic"));
+        assertTrue(taskManager.lockedTaskDirectories().isEmpty());
+
+        assertTrue(taskManager.getTaskOffsetSums().isEmpty());
+    }
+
+    @Test
+    public void shouldNotReportOffsetSumsAndReleaseLockForUnassignedTaskWithoutCheckpoint() throws IOException {
+        expectLockObtainedFor(taskId00);
+        makeTaskFolders(taskId00.toString());
+        expect(stateDirectory.checkpointFileFor(taskId00)).andReturn(getCheckpointFile(taskId00));
+        replay(stateDirectory);
+        taskManager.handleRebalanceStart(singleton("topic"));
+
+        assertTrue(taskManager.getTaskOffsetSums().isEmpty());
+        verify(stateDirectory);
+    }
+
+    @Test
+    public void shouldPinOffsetSumToLongMaxValueInCaseOfOverflow() throws IOException {
+        final long largeOffset = Long.MAX_VALUE / 2;
+        final Map<TopicPartition, Long> changelogOffsets = mkMap(
+            mkEntry(new TopicPartition("changelog", 1), largeOffset),
+            mkEntry(new TopicPartition("changelog", 2), largeOffset),
+            mkEntry(new TopicPartition("changelog", 3), largeOffset)
+        );
+        final Map<TaskId, Long> expectedOffsetSums = mkMap(mkEntry(taskId00, Long.MAX_VALUE));
+
+        expectLockObtainedFor(taskId00);
+        makeTaskFolders(taskId00.toString());
+        writeCheckpointFile(taskId00, changelogOffsets);
+        replay(stateDirectory);
+        taskManager.handleRebalanceStart(singleton("topic"));
+
+        assertThat(taskManager.getTaskOffsetSums(), is(expectedOffsetSums));
     }
 
     @Test
@@ -346,7 +504,7 @@ public class TaskManagerTest {
 
     @Test
     public void shouldAddNewActiveTasks() {
-        final Map<TaskId, Set<TopicPartition>> assignment = singletonMap(taskId00, taskId00Partitions);
+        final Map<TaskId, Set<TopicPartition>> assignment = taskId00Assignment;
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
 
         expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
@@ -492,7 +650,7 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldCloseActiveTasksAndPropogateExceptionsOnCleanShutdown() {
+    public void shouldCloseActiveTasksAndPropagateExceptionsOnCleanShutdown() {
         final TopicPartition changelog = new TopicPartition("changelog", 0);
         final Map<TaskId, Set<TopicPartition>> assignment = mkMap(
             mkEntry(taskId00, taskId00Partitions),
@@ -841,7 +999,8 @@ public class TaskManagerTest {
         expect(consumer.assignment()).andReturn(assignment);
         consumer.pause(assignment);
         expectLastCall();
-        replay(consumer);
+        expect(stateDirectory.listTaskDirectories()).andReturn(new File[0]);
+        replay(consumer, stateDirectory);
         assertThat(taskManager.isRebalanceInProgress(), is(false));
         taskManager.handleRebalanceStart(emptySet());
         assertThat(taskManager.isRebalanceInProgress(), is(true));
@@ -875,17 +1034,19 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldNotCommitActiveAndStandbyTasksWhileRebalanceInProgress() {
+    public void shouldNotCommitActiveAndStandbyTasksWhileRebalanceInProgress() throws IOException {
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
         final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, false);
 
+        makeTaskFolders(taskId00.toString(), task01.toString());
+        expectLockObtainedFor(taskId00, taskId01);
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
             .andReturn(singletonList(task00)).anyTimes();
         expect(standbyTaskCreator.createTasks(eq(taskId01Assignment)))
             .andReturn(singletonList(task01)).anyTimes();
 
-        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
+        replay(activeTaskCreator, standbyTaskCreator, stateDirectory, consumer, changeLogReader);
 
         taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
         assertThat(taskManager.tryToCompleteRestoration(), is(true));
@@ -975,7 +1136,7 @@ public class TaskManagerTest {
         final Map<TopicPartition, Long> purgableOffsets = new HashMap<>();
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
             @Override
-            public Map<TopicPartition, Long> purgableOffsets() {
+            public Map<TopicPartition, Long> purgeableOffsets() {
                 return purgableOffsets;
             }
         };
@@ -1011,7 +1172,7 @@ public class TaskManagerTest {
         final Map<TopicPartition, Long> purgableOffsets = new HashMap<>();
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
             @Override
-            public Map<TopicPartition, Long> purgableOffsets() {
+            public Map<TopicPartition, Long> purgeableOffsets() {
                 return purgableOffsets;
             }
         };
@@ -1410,6 +1571,80 @@ public class TaskManagerTest {
         assertThat(taskManager.producerMetrics(), is(dummyProducerMetrics));
     }
 
+    private Map<TaskId, StateMachineTask> handleAssignment(final Map<TaskId, Set<TopicPartition>> runningActiveAssignment,
+                                                           final Map<TaskId, Set<TopicPartition>> standbyAssignment,
+                                                           final Map<TaskId, Set<TopicPartition>> restoringActiveAssignment) {
+        final Set<Task> runningTasks = runningActiveAssignment.entrySet().stream()
+                                           .map(t -> new StateMachineTask(t.getKey(), t.getValue(), true))
+                                           .collect(Collectors.toSet());
+        final Set<Task> standbyTasks = standbyAssignment.entrySet().stream()
+                                           .map(t -> new StateMachineTask(t.getKey(), t.getValue(), false))
+                                           .collect(Collectors.toSet());
+        final Set<Task> restoringTasks = restoringActiveAssignment.entrySet().stream()
+                                             .map(t -> new StateMachineTask(t.getKey(), t.getValue(), true))
+                                             .collect(Collectors.toSet());
+
+        // Initially assign only the active tasks we want to complete restoration
+        final Map<TaskId, Set<TopicPartition>> allActiveTasksAssignment = new HashMap<>(runningActiveAssignment);
+        allActiveTasksAssignment.putAll(restoringActiveAssignment);
+        final Set<Task> allActiveTasks = new HashSet<>(runningTasks);
+        allActiveTasks.addAll(restoringTasks);
+
+        expect(activeTaskCreator.createTasks(anyObject(), eq(runningActiveAssignment))).andStubReturn(runningTasks);
+        expect(standbyTaskCreator.createTasks(eq(standbyAssignment))).andStubReturn(standbyTasks);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(allActiveTasksAssignment))).andStubReturn(allActiveTasks);
+
+        expectRestoreToBeCompleted(consumer, changeLogReader);
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
+
+        taskManager.handleAssignment(runningActiveAssignment, standbyAssignment);
+        assertThat(taskManager.tryToCompleteRestoration(), is(true));
+
+        taskManager.handleAssignment(allActiveTasksAssignment, standbyAssignment);
+
+        final Map<TaskId, StateMachineTask> allTasks = new HashMap<>();
+
+        // Just make sure all tasks ended up in the expected state
+        for (final Task task : runningTasks) {
+            assertThat(task.state(), is(Task.State.RUNNING));
+            allTasks.put(task.id(), (StateMachineTask) task);
+        }
+        for (final Task task : restoringTasks) {
+            assertThat(task.state(), not(Task.State.RUNNING));
+            allTasks.put(task.id(), (StateMachineTask) task);
+        }
+        for (final Task task : standbyTasks) {
+            assertThat(task.state(), is(Task.State.RUNNING));
+            allTasks.put(task.id(), (StateMachineTask) task);
+        }
+        return allTasks;
+    }
+
+    private void expectLockObtainedFor(final TaskId... tasks) throws IOException {
+        for (final TaskId task : tasks) {
+            expect(stateDirectory.lock(task)).andReturn(true).once();
+        }
+    }
+
+    private void expectLockFailedFor(final TaskId... tasks) throws IOException {
+        for (final TaskId task : tasks) {
+            expect(stateDirectory.lock(task)).andReturn(false).once();
+        }
+    }
+
+    private void expectUnlockFor(final TaskId... tasks) throws IOException {
+        for (final TaskId task : tasks) {
+            stateDirectory.unlock(task);
+            expectLastCall();
+        }
+    }
+
+    private static void expectConsumerAssignmentPaused(final Consumer<byte[], byte[]> consumer) {
+        final Set<TopicPartition> assignment = singleton(new TopicPartition("assignment", 0));
+        expect(consumer.assignment()).andReturn(assignment);
+        consumer.pause(assignment);
+    }
+
     private static void expectRestoreToBeCompleted(final Consumer<byte[], byte[]> consumer,
                                                    final ChangelogReader changeLogReader) {
         final Set<TopicPartition> assignment = singleton(new TopicPartition("assignment", 0));
@@ -1425,11 +1660,31 @@ public class TaskManagerTest {
         return futureDeletedRecords;
     }
 
+    private void makeTaskFolders(final String... names) throws IOException {
+        final File[] taskFolders = new File[names.length];
+        for (int i = 0; i < names.length; ++i) {
+            taskFolders[i] = testFolder.newFolder(names[i]);
+        }
+        expect(stateDirectory.listTaskDirectories()).andReturn(taskFolders).once();
+    }
+
+    private void writeCheckpointFile(final TaskId task, final Map<TopicPartition, Long> offsets) throws IOException {
+        final File checkpointFile = getCheckpointFile(task);
+        assertThat(checkpointFile.createNewFile(), is(true));
+        new OffsetCheckpoint(checkpointFile).write(offsets);
+        expect(stateDirectory.checkpointFileFor(task)).andReturn(checkpointFile);
+    }
+
+    private File getCheckpointFile(final TaskId task) {
+        return new File(new File(testFolder.getRoot(), task.toString()), StateManagerUtil.CHECKPOINT_FILE_NAME);
+    }
+
     private static class StateMachineTask extends AbstractTask implements Task {
         private final boolean active;
         private boolean commitNeeded = false;
         private boolean commitRequested = false;
         private Map<TopicPartition, Long> purgeableOffsets;
+        private Map<TopicPartition, Long> changelogOffsets;
         private Map<TopicPartition, LinkedList<ConsumerRecord<byte[], byte[]>>> queue = new HashMap<>();
 
         StateMachineTask(final TaskId id,
@@ -1522,13 +1777,17 @@ public class TaskManagerTest {
         }
 
         @Override
-        public Map<TopicPartition, Long> purgableOffsets() {
+        public Map<TopicPartition, Long> purgeableOffsets() {
             return purgeableOffsets;
         }
 
+        void setChangelogOffsets(final Map<TopicPartition, Long> changelogOffsets) {
+            this.changelogOffsets = changelogOffsets;
+        }
+
         @Override
         public Map<TopicPartition, Long> changelogOffsets() {
-            return null;
+            return changelogOffsets;
         }
 
         @Override
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfoTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfoTest.java
index cf79a5f..07bb085 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfoTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfoTest.java
@@ -30,6 +30,8 @@ import java.util.UUID;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION;
+import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.MIN_VERSION_OFFSET_SUM_SUBSCRIPTION;
+import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.UNKNOWN_OFFSET_SUM;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
@@ -294,13 +296,47 @@ public class SubscriptionInfoTest {
     }
 
     @Test
-    public void shouldConvertTaskOffsetSumMapToTaskSetsForOlderVersion() {
+    public void shouldConvertTaskOffsetSumMapToTaskSets() {
         final SubscriptionInfo info =
             new SubscriptionInfo(7, LATEST_SUPPORTED_VERSION, processId, "localhost:80", TASK_OFFSET_SUMS);
         assertThat(info.prevTasks(), is(ACTIVE_TASKS));
         assertThat(info.standbyTasks(), is(STANDBY_TASKS));
     }
 
+    @Test
+    public void shouldReturnTaskOffsetSumsMapForDecodedSubscription() {
+        final SubscriptionInfo info = SubscriptionInfo.decode(
+            new SubscriptionInfo(MIN_VERSION_OFFSET_SUM_SUBSCRIPTION,
+                                 LATEST_SUPPORTED_VERSION, processId,
+                                 "localhost:80",
+                                 TASK_OFFSET_SUMS)
+                .encode());
+        assertThat(info.taskOffsetSums(), is(TASK_OFFSET_SUMS));
+    }
+
+    @Test
+    public void shouldConvertTaskSetsToTaskOffsetSumMapWithOlderSubscription() {
+        final Map<TaskId, Long> expectedOffsetSumsMap = mkMap(
+            mkEntry(new TaskId(0, 0), Task.LATEST_OFFSET),
+            mkEntry(new TaskId(0, 1), Task.LATEST_OFFSET),
+            mkEntry(new TaskId(1, 0), Task.LATEST_OFFSET),
+            mkEntry(new TaskId(1, 1), UNKNOWN_OFFSET_SUM),
+            mkEntry(new TaskId(2, 0), UNKNOWN_OFFSET_SUM)
+        );
+
+        final SubscriptionInfo info = SubscriptionInfo.decode(
+            new LegacySubscriptionInfoSerde(
+                SubscriptionInfo.MIN_VERSION_OFFSET_SUM_SUBSCRIPTION - 1,
+                LATEST_SUPPORTED_VERSION,
+                processId,
+                ACTIVE_TASKS,
+                STANDBY_TASKS,
+                "localhost:80")
+            .encode());
+
+        assertThat(info.taskOffsetSums(), is(expectedOffsetSumsMap));
+    }
+
     private static ByteBuffer encodeFutureVersion() {
         final ByteBuffer buf = ByteBuffer.allocate(4 /* used version */
                                                        + 4 /* supported version */);