You are viewing a plain text version of this content. The canonical (HTML) version is available at the original mailing-list archive location.
Posted to commits@kafka.apache.org by gu...@apache.org on 2018/09/30 23:04:40 UTC

[kafka] branch trunk updated: KAFKA-7223: internally provide full consumer record during restore (#5710)

This is an automated email from the ASF dual-hosted git repository.

guozhang pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new eb61df6  KAFKA-7223: internally provide full consumer record during restore (#5710)
eb61df6 is described below

commit eb61df642d8227fb7eaa1099ec055145b02e3dca
Author: John Roesler <vv...@users.noreply.github.com>
AuthorDate: Sun Sep 30 18:04:28 2018 -0500

    KAFKA-7223: internally provide full consumer record during restore (#5710)
    
    The Suppression buffer stores the full record context, not just the key and value,
    so its changelog/restore loop will also need to preserve this information.
    
    This change is a precondition to that, creating an option to register a
    state restore callback to receive the full consumer record.
    
    Reviewers: Bill Bejeck <bi...@confluent.io>, Matthias J. Sax <ma...@confluent.io>, Guozhang Wang <wa...@gmail.com>
---
 .../internals/CompositeRestoreListener.java        |  33 ++---
 .../internals/GlobalStateManagerImpl.java          |  17 +--
 .../processor/internals/ProcessorStateManager.java |  23 +---
 ...ava => RecordBatchingStateRestoreCallback.java} |  23 +---
 .../streams/processor/internals/StandbyTask.java   |   5 +-
 .../internals/StateRestoreCallbackAdapter.java     |  52 ++++++++
 .../streams/processor/internals/StateRestorer.java |   6 +-
 .../processor/internals/StoreChangelogReader.java  |   5 +-
 .../internals/CompositeRestoreListenerTest.java    |  20 ++-
 .../internals/ProcessorStateManagerTest.java       |   6 +-
 .../processor/internals/StandbyTaskTest.java       |  16 ++-
 .../internals/StateRestoreCallbackAdapterTest.java | 147 +++++++++++++++++++++
 .../processor/internals/StateRestorerTest.java     |   4 +-
 .../WrappedBatchingStateRestoreCallbackTest.java   |  51 -------
 .../kafka/test/InternalMockProcessorContext.java   |  25 ++--
 15 files changed, 279 insertions(+), 154 deletions(-)

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/CompositeRestoreListener.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/CompositeRestoreListener.java
index 4783734..7cccad6 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/CompositeRestoreListener.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/CompositeRestoreListener.java
@@ -18,20 +18,20 @@
 package org.apache.kafka.streams.processor.internals;
 
 
+import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.processor.AbstractNotifyingBatchingRestoreCallback;
-import org.apache.kafka.streams.processor.BatchingStateRestoreCallback;
 import org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.processor.StateRestoreListener;
 
 import java.util.Collection;
 
-public class CompositeRestoreListener implements BatchingStateRestoreCallback, StateRestoreListener {
+public class CompositeRestoreListener implements RecordBatchingStateRestoreCallback, StateRestoreListener {
 
     public static final NoOpStateRestoreListener NO_OP_STATE_RESTORE_LISTENER = new NoOpStateRestoreListener();
-    private final BatchingStateRestoreCallback internalBatchingRestoreCallback;
+    private final RecordBatchingStateRestoreCallback internalBatchingRestoreCallback;
     private final StateRestoreListener storeRestoreListener;
     private StateRestoreListener userRestoreListener = NO_OP_STATE_RESTORE_LISTENER;
 
@@ -43,7 +43,7 @@ public class CompositeRestoreListener implements BatchingStateRestoreCallback, S
             storeRestoreListener = NO_OP_STATE_RESTORE_LISTENER;
         }
 
-        internalBatchingRestoreCallback = getBatchingRestoreCallback(stateRestoreCallback);
+        internalBatchingRestoreCallback = StateRestoreCallbackAdapter.adapt(stateRestoreCallback);
     }
 
     /**
@@ -85,8 +85,8 @@ public class CompositeRestoreListener implements BatchingStateRestoreCallback, S
     }
 
     @Override
-    public void restoreAll(final Collection<KeyValue<byte[], byte[]>> records) {
-        internalBatchingRestoreCallback.restoreAll(records);
+    public void restoreBatch(final Collection<ConsumerRecord<byte[], byte[]>> records) {
+        internalBatchingRestoreCallback.restoreBatch(records);
     }
 
     void setUserRestoreListener(final StateRestoreListener userRestoreListener) {
@@ -96,25 +96,20 @@ public class CompositeRestoreListener implements BatchingStateRestoreCallback, S
     }
 
     @Override
+    public void restoreAll(final Collection<KeyValue<byte[], byte[]>> records) {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
     public void restore(final byte[] key,
                         final byte[] value) {
         throw new UnsupportedOperationException("Single restore functionality shouldn't be called directly but "
-                                                + "through the delegated StateRestoreCallback instance");
+                                                    + "through the delegated StateRestoreCallback instance");
     }
 
-    private BatchingStateRestoreCallback getBatchingRestoreCallback(final StateRestoreCallback restoreCallback) {
-        if (restoreCallback instanceof  BatchingStateRestoreCallback) {
-            return (BatchingStateRestoreCallback) restoreCallback;
-        }
-
-        return new WrappedBatchingStateRestoreCallback(restoreCallback);
-    }
-
-
-    private static final class NoOpStateRestoreListener extends AbstractNotifyingBatchingRestoreCallback {
-
+    private static final class NoOpStateRestoreListener extends AbstractNotifyingBatchingRestoreCallback implements RecordBatchingStateRestoreCallback {
         @Override
-        public void restoreAll(final Collection<KeyValue<byte[], byte[]>> records) {
+        public void restoreBatch(final Collection<ConsumerRecord<byte[], byte[]>> records) {
 
         }
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
index a4ec23d..a20f3b0 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
@@ -25,12 +25,10 @@ import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Utils;
-import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.LockException;
 import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.errors.StreamsException;
-import org.apache.kafka.streams.processor.BatchingStateRestoreCallback;
 import org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.processor.StateRestoreListener;
 import org.apache.kafka.streams.processor.StateStore;
@@ -263,11 +261,8 @@ public class GlobalStateManagerImpl extends AbstractStateManager implements Glob
 
             long offset = globalConsumer.position(topicPartition);
             final Long highWatermark = highWatermarks.get(topicPartition);
-            final BatchingStateRestoreCallback stateRestoreAdapter =
-                (BatchingStateRestoreCallback) ((stateRestoreCallback instanceof
-                                                     BatchingStateRestoreCallback)
-                                                ? stateRestoreCallback
-                                                : new WrappedBatchingStateRestoreCallback(stateRestoreCallback));
+            final RecordBatchingStateRestoreCallback stateRestoreAdapter =
+                StateRestoreCallbackAdapter.adapt(stateRestoreCallback);
 
             stateRestoreListener.onRestoreStart(topicPartition, storeName, offset, highWatermark);
             long restoreCount = 0L;
@@ -275,14 +270,14 @@ public class GlobalStateManagerImpl extends AbstractStateManager implements Glob
             while (offset < highWatermark) {
                 try {
                     final ConsumerRecords<byte[], byte[]> records = globalConsumer.poll(pollTime);
-                    final List<KeyValue<byte[], byte[]>> restoreRecords = new ArrayList<>();
-                    for (final ConsumerRecord<byte[], byte[]> record : records) {
+                    final List<ConsumerRecord<byte[], byte[]>> restoreRecords = new ArrayList<>();
+                    for (final ConsumerRecord<byte[], byte[]> record : records.records(topicPartition)) {
                         if (record.key() != null) {
-                            restoreRecords.add(KeyValue.pair(record.key(), record.value()));
+                            restoreRecords.add(record);
                         }
                     }
                     offset = globalConsumer.position(topicPartition);
-                    stateRestoreAdapter.restoreAll(restoreRecords);
+                    stateRestoreAdapter.restoreBatch(restoreRecords);
                     stateRestoreListener.onBatchRestored(topicPartition, storeName, offset, restoreRecords.size());
                     restoreCount += restoreRecords.size();
                 } catch (final InvalidOffsetException recoverableException) {
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
index 15a5c21..3d0c664 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
@@ -16,11 +16,10 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.utils.LogContext;
-import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.errors.ProcessorStateException;
-import org.apache.kafka.streams.processor.BatchingStateRestoreCallback;
 import org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
@@ -35,6 +34,8 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
+import static org.apache.kafka.streams.processor.internals.StateRestoreCallbackAdapter.adapt;
+
 
 public class ProcessorStateManager extends AbstractStateManager {
     private static final String STATE_CHANGELOG_TOPIC_SUFFIX = "-changelog";
@@ -130,7 +131,6 @@ public class ProcessorStateManager extends AbstractStateManager {
         if (isStandby) {
             log.trace("Preparing standby replica of persistent state store {} with changelog topic {}", storeName, topic);
             restoreCallbacks.put(topic, stateRestoreCallback);
-
         } else {
             log.trace("Restoring state store {} from changelog topic {}", storeName, topic);
             final StateRestorer restorer = new StateRestorer(storePartition,
@@ -138,7 +138,7 @@ public class ProcessorStateManager extends AbstractStateManager {
                                                              checkpointableOffsets.get(storePartition),
                                                              offsetLimit(storePartition),
                                                              store.persistent(),
-                storeName);
+                                                             storeName);
 
             changelogReader.register(restorer);
         }
@@ -173,14 +173,14 @@ public class ProcessorStateManager extends AbstractStateManager {
     }
 
     void updateStandbyStates(final TopicPartition storePartition,
-                             final List<KeyValue<byte[], byte[]>> restoreRecords,
+                             final List<ConsumerRecord<byte[], byte[]>> restoreRecords,
                              final long lastOffset) {
         // restore states from changelog records
-        final BatchingStateRestoreCallback restoreCallback = getBatchingRestoreCallback(restoreCallbacks.get(storePartition.topic()));
+        final RecordBatchingStateRestoreCallback restoreCallback = adapt(restoreCallbacks.get(storePartition.topic()));
 
         if (!restoreRecords.isEmpty()) {
             try {
-                restoreCallback.restoreAll(restoreRecords);
+                restoreCallback.restoreBatch(restoreRecords);
             } catch (final Exception e) {
                 throw new ProcessorStateException(String.format("%sException caught while trying to restore state from %s", logPrefix, storePartition), e);
             }
@@ -313,15 +313,6 @@ public class ProcessorStateManager extends AbstractStateManager {
         return globalStores.get(name);
     }
 
-    private BatchingStateRestoreCallback getBatchingRestoreCallback(final StateRestoreCallback callback) {
-        if (callback instanceof BatchingStateRestoreCallback) {
-            return (BatchingStateRestoreCallback) callback;
-        }
-
-        // TODO: avoid creating a new object for each update call?
-        return new WrappedBatchingStateRestoreCallback(callback);
-    }
-
     Collection<TopicPartition> changelogPartitions() {
         return changelogPartitions;
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/WrappedBatchingStateRestoreCallback.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordBatchingStateRestoreCallback.java
similarity index 60%
rename from streams/src/main/java/org/apache/kafka/streams/processor/internals/WrappedBatchingStateRestoreCallback.java
rename to streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordBatchingStateRestoreCallback.java
index b469b38..78a885d 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/WrappedBatchingStateRestoreCallback.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordBatchingStateRestoreCallback.java
@@ -14,33 +14,24 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.kafka.streams.processor.internals;
 
+import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.processor.BatchingStateRestoreCallback;
-import org.apache.kafka.streams.processor.StateRestoreCallback;
 
 import java.util.Collection;
 
-public class WrappedBatchingStateRestoreCallback implements BatchingStateRestoreCallback {
-
-    private final StateRestoreCallback stateRestoreCallback;
-
-    public WrappedBatchingStateRestoreCallback(final StateRestoreCallback stateRestoreCallback) {
-        this.stateRestoreCallback = stateRestoreCallback;
-    }
+public interface RecordBatchingStateRestoreCallback extends BatchingStateRestoreCallback {
+    void restoreBatch(final Collection<ConsumerRecord<byte[], byte[]>> records);
 
     @Override
-    public void restoreAll(final Collection<KeyValue<byte[], byte[]>> records) {
-        for (final KeyValue<byte[], byte[]> record : records) {
-            restore(record.key, record.value);
-        }
+    default void restoreAll(final Collection<KeyValue<byte[], byte[]>> records) {
+        throw new UnsupportedOperationException();
     }
 
     @Override
-    public void restore(final byte[] key,
-                        final byte[] value) {
-        stateRestoreCallback.restore(key, value);
+    default void restore(final byte[] key, final byte[] value) {
+        throw new UnsupportedOperationException();
     }
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
index 6f4e617..45f06b2 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
@@ -19,7 +19,6 @@ package org.apache.kafka.streams.processor.internals;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.TopicPartition;
-import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.StreamsMetrics;
 import org.apache.kafka.streams.processor.TaskId;
@@ -171,12 +170,12 @@ public class StandbyTask extends AbstractTask {
         final long limit = stateMgr.offsetLimit(partition);
 
         long lastOffset = -1L;
-        final List<KeyValue<byte[], byte[]>> restoreRecords = new ArrayList<>(records.size());
+        final List<ConsumerRecord<byte[], byte[]>> restoreRecords = new ArrayList<>(records.size());
         final List<ConsumerRecord<byte[], byte[]>> remainingRecords = new ArrayList<>();
 
         for (final ConsumerRecord<byte[], byte[]> record : records) {
             if (record.offset() < limit) {
-                restoreRecords.add(KeyValue.pair(record.key(), record.value()));
+                restoreRecords.add(record);
                 lastOffset = record.offset();
                 // ideally, we'd use the stream time at the time of the change logging, but we'll settle for
                 // record timestamp for now.
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestoreCallbackAdapter.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestoreCallbackAdapter.java
new file mode 100644
index 0000000..fce3f80
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestoreCallbackAdapter.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.processor.BatchingStateRestoreCallback;
+import org.apache.kafka.streams.processor.StateRestoreCallback;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+
+public final class StateRestoreCallbackAdapter {
+    private StateRestoreCallbackAdapter() {}
+
+    public static RecordBatchingStateRestoreCallback adapt(final StateRestoreCallback restoreCallback) {
+        Objects.requireNonNull(restoreCallback, "stateRestoreCallback must not be null");
+        if (restoreCallback instanceof RecordBatchingStateRestoreCallback) {
+            return (RecordBatchingStateRestoreCallback) restoreCallback;
+        } else if (restoreCallback instanceof BatchingStateRestoreCallback) {
+            return records -> {
+                final List<KeyValue<byte[], byte[]>> keyValues = new ArrayList<>();
+                for (final ConsumerRecord<byte[], byte[]> record : records) {
+                    keyValues.add(new KeyValue<>(record.key(), record.value()));
+                }
+                ((BatchingStateRestoreCallback) restoreCallback).restoreAll(keyValues);
+            };
+        } else {
+            return records -> {
+                for (final ConsumerRecord<byte[], byte[]> record : records) {
+                    restoreCallback.restore(record.key(), record.value());
+                }
+            };
+        }
+    }
+}
\ No newline at end of file
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestorer.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestorer.java
index 096ed9d..6a2076e 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestorer.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestorer.java
@@ -16,8 +16,8 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.TopicPartition;
-import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.processor.StateRestoreListener;
 
 import java.util.Collection;
@@ -79,8 +79,8 @@ public class StateRestorer {
         compositeRestoreListener.onBatchRestored(partition, storeName, currentRestoredOffset, numRestored);
     }
 
-    void restore(final Collection<KeyValue<byte[], byte[]>> records) {
-        compositeRestoreListener.restoreAll(records);
+    void restore(final Collection<ConsumerRecord<byte[], byte[]>> records) {
+        compositeRestoreListener.restoreBatch(records);
     }
 
     boolean isPersistent() {
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
index 9185920..34e6e5c 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
@@ -24,7 +24,6 @@ import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.utils.LogContext;
-import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.processor.StateRestoreListener;
 import org.slf4j.Logger;
@@ -281,7 +280,7 @@ public class StoreChangelogReader implements ChangelogReader {
     private long processNext(final List<ConsumerRecord<byte[], byte[]>> records,
                              final StateRestorer restorer,
                              final Long endOffset) {
-        final List<KeyValue<byte[], byte[]>> restoreRecords = new ArrayList<>();
+        final List<ConsumerRecord<byte[], byte[]>> restoreRecords = new ArrayList<>();
         long nextPosition = -1;
         final int numberRecords = records.size();
         int numberRestored = 0;
@@ -295,7 +294,7 @@ public class StoreChangelogReader implements ChangelogReader {
             lastRestoredOffset = offset;
             numberRestored++;
             if (record.key() != null) {
-                restoreRecords.add(KeyValue.pair(record.key(), record.value()));
+                restoreRecords.add(record);
             }
         }
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/CompositeRestoreListenerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/CompositeRestoreListenerTest.java
index ef2e6f7..5bfa4a6 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/CompositeRestoreListenerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/CompositeRestoreListenerTest.java
@@ -17,6 +17,7 @@
 
 package org.apache.kafka.streams.processor.internals;
 
+import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.processor.BatchingStateRestoreCallback;
@@ -48,6 +49,9 @@ public class CompositeRestoreListenerTest {
     private final byte[] key = "key".getBytes(Charset.forName("UTF-8"));
     private final byte[] value = "value".getBytes(Charset.forName("UTF-8"));
     private final Collection<KeyValue<byte[], byte[]>> records = Collections.singletonList(KeyValue.pair(key, value));
+    private final Collection<ConsumerRecord<byte[], byte[]>> consumerRecords = Collections.singletonList(
+        new ConsumerRecord<>("", 0, 0L, key, value)
+    );
     private final String storeName = "test_store";
     private final long startOffset = 0L;
     private final long endOffset = 1L;
@@ -61,7 +65,7 @@ public class CompositeRestoreListenerTest {
     @Test
     public void shouldRestoreInNonBatchMode() {
         setUpCompositeRestoreListener(stateRestoreCallback);
-        compositeRestoreListener.restoreAll(records);
+        compositeRestoreListener.restoreBatch(consumerRecords);
         assertThat(stateRestoreCallback.restoredKey, is(key));
         assertThat(stateRestoreCallback.restoredValue, is(value));
     }
@@ -69,7 +73,7 @@ public class CompositeRestoreListenerTest {
     @Test
     public void shouldRestoreInBatchMode() {
         setUpCompositeRestoreListener(batchingStateRestoreCallback);
-        compositeRestoreListener.restoreAll(records);
+        compositeRestoreListener.restoreBatch(consumerRecords);
         assertThat(batchingStateRestoreCallback.getRestoredRecords(), is(records));
     }
 
@@ -126,7 +130,7 @@ public class CompositeRestoreListenerTest {
         compositeRestoreListener = new CompositeRestoreListener(batchingStateRestoreCallback);
         compositeRestoreListener.setUserRestoreListener(null);
 
-        compositeRestoreListener.restoreAll(records);
+        compositeRestoreListener.restoreBatch(consumerRecords);
         compositeRestoreListener.onRestoreStart(topicPartition, storeName, startOffset, endOffset);
         compositeRestoreListener.onBatchRestored(topicPartition, storeName, batchOffset, numberRestored);
         compositeRestoreListener.onRestoreEnd(topicPartition, storeName, numberRestored);
@@ -140,7 +144,7 @@ public class CompositeRestoreListenerTest {
         compositeRestoreListener = new CompositeRestoreListener(noListenBatchingStateRestoreCallback);
         compositeRestoreListener.setUserRestoreListener(null);
 
-        compositeRestoreListener.restoreAll(records);
+        compositeRestoreListener.restoreBatch(consumerRecords);
         compositeRestoreListener.onRestoreStart(topicPartition, storeName, startOffset, endOffset);
         compositeRestoreListener.onBatchRestored(topicPartition, storeName, batchOffset, numberRestored);
         compositeRestoreListener.onRestoreEnd(topicPartition, storeName, numberRestored);
@@ -151,11 +155,15 @@ public class CompositeRestoreListenerTest {
     @Test(expected = UnsupportedOperationException.class)
     public void shouldThrowExceptionWhenSinglePutDirectlyCalled() {
         compositeRestoreListener = new CompositeRestoreListener(noListenBatchingStateRestoreCallback);
-        compositeRestoreListener.setUserRestoreListener(null);
-
         compositeRestoreListener.restore(key, value);
     }
 
+    @Test(expected = UnsupportedOperationException.class)
+    public void shouldThrowExceptionWhenRestoreAllDirectlyCalled() {
+        compositeRestoreListener = new CompositeRestoreListener(noListenBatchingStateRestoreCallback);
+        compositeRestoreListener.restoreAll(Collections.emptyList());
+    }
+
     private void assertStateRestoreListenerOnStartNotification(final MockStateRestoreListener restoreListener) {
         assertTrue(restoreListener.storeNameCalledStates.containsKey(RESTORE_START));
         assertThat(restoreListener.restoreTopicPartition, is(topicPartition));
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
index cd95a68..fbcb2c8 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
@@ -122,7 +122,7 @@ public class ProcessorStateManagerTest {
             stateMgr.register(persistentStore, batchingRestoreCallback);
             stateMgr.updateStandbyStates(
                 persistentStorePartition,
-                singletonList(KeyValue.pair(consumerRecord.key(), consumerRecord.value())),
+                singletonList(consumerRecord),
                 consumerRecord.offset()
             );
             assertThat(batchingRestoreCallback.getRestoredRecords().size(), is(1));
@@ -144,7 +144,7 @@ public class ProcessorStateManagerTest {
             stateMgr.register(persistentStore, persistentStore.stateRestoreCallback);
             stateMgr.updateStandbyStates(
                 persistentStorePartition,
-                singletonList(KeyValue.pair(consumerRecord.key(), consumerRecord.value())),
+                singletonList(consumerRecord),
                 consumerRecord.offset()
             );
             assertThat(persistentStore.keys.size(), is(1));
@@ -411,7 +411,7 @@ public class ProcessorStateManagerTest {
         final byte[] bytes = Serdes.Integer().serializer().serialize("", 10);
         stateMgr.updateStandbyStates(
             persistentStorePartition,
-            singletonList(KeyValue.pair(bytes, bytes)),
+            singletonList(new ConsumerRecord<>("", 0, 0L, bytes, bytes)),
             888L
         );
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
index 820191d..e669879 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
@@ -79,6 +79,7 @@ import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkList;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkProperties;
+import static org.hamcrest.CoreMatchers.containsString;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
@@ -176,16 +177,21 @@ public class StandbyTaskTest {
     }
 
     @SuppressWarnings("unchecked")
-    @Test(expected = ProcessorStateException.class)
-    public void testUpdateNonPersistentStore() throws IOException {
+    @Test
+    public void testUpdateNonInitializedStore() throws IOException {
         final StreamsConfig config = createConfig(baseDir);
         final StandbyTask task = new StandbyTask(taskId, topicPartitions, topology, consumer, changelogReader, config, null, stateDirectory);
 
         restoreStateConsumer.assign(new ArrayList<>(task.checkpointedOffsets().keySet()));
 
-        task.update(partition1,
-            singletonList(new ConsumerRecord<>(partition1.topic(), partition1.partition(), 10, 0L, TimestampType.CREATE_TIME, 0L, 0, 0, recordKey, recordValue))
-        );
+        try {
+            task.update(partition1,
+                        singletonList(new ConsumerRecord<>(partition1.topic(), partition1.partition(), 10, 0L, TimestampType.CREATE_TIME, 0L, 0, 0, recordKey, recordValue))
+            );
+            fail("expected an exception");
+        } catch (final NullPointerException npe) {
+            assertThat(npe.getMessage(), containsString("stateRestoreCallback must not be null"));
+        }
 
     }
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateRestoreCallbackAdapterTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateRestoreCallbackAdapterTest.java
new file mode 100644
index 0000000..60b928a
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateRestoreCallbackAdapterTest.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals;
+
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.processor.BatchingStateRestoreCallback;
+import org.apache.kafka.streams.processor.StateRestoreCallback;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import static java.util.Arrays.asList;
+import static org.apache.kafka.streams.processor.internals.StateRestoreCallbackAdapter.adapt;
+import static org.easymock.EasyMock.mock;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.core.Is.is;
+
+public class StateRestoreCallbackAdapterTest {
+    @Test(expected = UnsupportedOperationException.class)
+    public void shouldThrowOnRestoreAll() {
+        adapt(mock(StateRestoreCallback.class)).restoreAll(null);
+    }
+
+    @Test(expected = UnsupportedOperationException.class)
+    public void shouldThrowOnRestore() {
+        adapt(mock(StateRestoreCallback.class)).restore(null, null);
+    }
+
+    @Test
+    public void shouldPassRecordsThrough() {
+        final ArrayList<ConsumerRecord<byte[], byte[]>> actual = new ArrayList<>();
+        final RecordBatchingStateRestoreCallback callback = actual::addAll;
+
+        final RecordBatchingStateRestoreCallback adapted = adapt(callback);
+
+        final byte[] key1 = {1};
+        final byte[] value1 = {2};
+        final byte[] key2 = {3};
+        final byte[] value2 = {4};
+
+        final List<ConsumerRecord<byte[], byte[]>> recordList = asList(
+            new ConsumerRecord<>("topic1", 0, 0L, key1, value1),
+            new ConsumerRecord<>("topic2", 1, 1L, key2, value2)
+        );
+
+        adapted.restoreBatch(recordList);
+
+        validate(actual, recordList);
+    }
+
+    @Test
+    public void shouldConvertToKeyValueBatches() {
+        final ArrayList<KeyValue<byte[], byte[]>> actual = new ArrayList<>();
+        final BatchingStateRestoreCallback callback = new BatchingStateRestoreCallback() {
+            @Override
+            public void restoreAll(final Collection<KeyValue<byte[], byte[]>> records) {
+                actual.addAll(records);
+            }
+
+            @Override
+            public void restore(final byte[] key, final byte[] value) {
+                // unreachable
+            }
+        };
+
+        final RecordBatchingStateRestoreCallback adapted = adapt(callback);
+
+        final byte[] key1 = {1};
+        final byte[] value1 = {2};
+        final byte[] key2 = {3};
+        final byte[] value2 = {4};
+        adapted.restoreBatch(asList(
+            new ConsumerRecord<>("topic1", 0, 0L, key1, value1),
+            new ConsumerRecord<>("topic2", 1, 1L, key2, value2)
+        ));
+
+        assertThat(
+            actual,
+            is(asList(
+                new KeyValue<>(key1, value1),
+                new KeyValue<>(key2, value2)
+            ))
+        );
+    }
+
+    @Test
+    public void shouldConvertToKeyValue() {
+        final ArrayList<KeyValue<byte[], byte[]>> actual = new ArrayList<>();
+        final StateRestoreCallback callback = (key, value) -> actual.add(new KeyValue<>(key, value));
+
+        final RecordBatchingStateRestoreCallback adapted = adapt(callback);
+
+        final byte[] key1 = {1};
+        final byte[] value1 = {2};
+        final byte[] key2 = {3};
+        final byte[] value2 = {4};
+        adapted.restoreBatch(asList(
+            new ConsumerRecord<>("topic1", 0, 0L, key1, value1),
+            new ConsumerRecord<>("topic2", 1, 1L, key2, value2)
+        ));
+
+        assertThat(
+            actual,
+            is(asList(
+                new KeyValue<>(key1, value1),
+                new KeyValue<>(key2, value2)
+            ))
+        );
+    }
+
+    private void validate(final List<ConsumerRecord<byte[], byte[]>> actual,
+                          final List<ConsumerRecord<byte[], byte[]>> expected) {
+        assertThat(actual.size(), is(expected.size()));
+        for (int i = 0; i < actual.size(); i++) {
+            final ConsumerRecord<byte[], byte[]> actual1 = actual.get(i);
+            final ConsumerRecord<byte[], byte[]> expected1 = expected.get(i);
+            assertThat(actual1.topic(), is(expected1.topic()));
+            assertThat(actual1.partition(), is(expected1.partition()));
+            assertThat(actual1.offset(), is(expected1.offset()));
+            assertThat(actual1.key(), is(expected1.key()));
+            assertThat(actual1.value(), is(expected1.value()));
+            assertThat(actual1.timestamp(), is(expected1.timestamp()));
+            assertThat(actual1.headers(), is(expected1.headers()));
+        }
+    }
+
+
+}
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateRestorerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateRestorerTest.java
index 62da23b..dc22bb4 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateRestorerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateRestorerTest.java
@@ -16,8 +16,8 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.TopicPartition;
-import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.test.MockRestoreCallback;
 import org.apache.kafka.test.MockStateRestoreListener;
 import org.junit.Before;
@@ -45,7 +45,7 @@ public class StateRestorerTest {
 
     @Test
     public void shouldCallRestoreOnRestoreCallback() {
-        restorer.restore(Collections.singletonList(KeyValue.pair(new byte[0], new byte[0])));
+        restorer.restore(Collections.singletonList(new ConsumerRecord<>("", 0, 0L, new byte[0], new byte[0])));
         assertThat(callback.restored.size(), equalTo(1));
     }
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/WrappedBatchingStateRestoreCallbackTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/WrappedBatchingStateRestoreCallbackTest.java
deleted file mode 100644
index c602ee1..0000000
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/WrappedBatchingStateRestoreCallbackTest.java
+++ /dev/null
@@ -1,51 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.kafka.streams.processor.internals;
-
-
-import org.apache.kafka.streams.KeyValue;
-import org.apache.kafka.streams.processor.BatchingStateRestoreCallback;
-import org.apache.kafka.test.MockRestoreCallback;
-import org.junit.Test;
-
-import java.nio.charset.Charset;
-import java.util.Collection;
-import java.util.Collections;
-
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.core.Is.is;
-
-public class WrappedBatchingStateRestoreCallbackTest {
-
-    private final MockRestoreCallback mockRestoreCallback = new MockRestoreCallback();
-    private final byte[] key = "key".getBytes(Charset.forName("UTF-8"));
-    private final byte[] value = "value".getBytes(Charset.forName("UTF-8"));
-    private final Collection<KeyValue<byte[], byte[]>> records = Collections.singletonList(KeyValue.pair(key, value));
-    private final BatchingStateRestoreCallback wrappedBatchingStateRestoreCallback = new WrappedBatchingStateRestoreCallback(mockRestoreCallback);
-
-    @Test
-    public void shouldRestoreSinglePutsFromArray() {
-        wrappedBatchingStateRestoreCallback.restoreAll(records);
-        assertThat(mockRestoreCallback.restored, is(records));
-        final KeyValue<byte[], byte[]> record = mockRestoreCallback.restored.get(0);
-        assertThat(record.key, is(key));
-        assertThat(record.value, is(value));
-    }
-
-
-}
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java b/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java
index bedf8eb..3b5a915 100644
--- a/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java
+++ b/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java
@@ -16,13 +16,13 @@
  */
 package org.apache.kafka.test;
 
+import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.header.Headers;
 import org.apache.kafka.common.header.internals.RecordHeaders;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.serialization.Serde;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.StreamsConfig;
-import org.apache.kafka.streams.processor.BatchingStateRestoreCallback;
 import org.apache.kafka.streams.processor.Cancellable;
 import org.apache.kafka.streams.processor.PunctuationType;
 import org.apache.kafka.streams.processor.Punctuator;
@@ -32,12 +32,12 @@ import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.To;
 import org.apache.kafka.streams.processor.internals.AbstractProcessorContext;
+import org.apache.kafka.streams.processor.internals.RecordBatchingStateRestoreCallback;
 import org.apache.kafka.streams.processor.internals.CompositeRestoreListener;
 import org.apache.kafka.streams.processor.internals.ProcessorNode;
 import org.apache.kafka.streams.processor.internals.ProcessorRecordContext;
 import org.apache.kafka.streams.processor.internals.RecordCollector;
 import org.apache.kafka.streams.processor.internals.ToInternal;
-import org.apache.kafka.streams.processor.internals.WrappedBatchingStateRestoreCallback;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.state.StateSerdes;
 import org.apache.kafka.streams.state.internals.ThreadCache;
@@ -50,6 +50,8 @@ import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 
+import static org.apache.kafka.streams.processor.internals.StateRestoreCallbackAdapter.adapt;
+
 public class InternalMockProcessorContext extends AbstractProcessorContext implements RecordCollector.Supplier {
 
     private final File stateDir;
@@ -303,22 +305,21 @@ public class InternalMockProcessorContext extends AbstractProcessorContext imple
     }
 
     public StateRestoreListener getRestoreListener(final String storeName) {
-        final BatchingStateRestoreCallback restoreCallback = getBatchingRestoreCallback(restoreFuncs.get(storeName));
-        return getStateRestoreListener(restoreCallback);
+        return getStateRestoreListener(restoreFuncs.get(storeName));
     }
 
     public void restore(final String storeName, final Iterable<KeyValue<byte[], byte[]>> changeLog) {
-        final BatchingStateRestoreCallback restoreCallback = getBatchingRestoreCallback(restoreFuncs.get(storeName));
+        final RecordBatchingStateRestoreCallback restoreCallback = adapt(restoreFuncs.get(storeName));
         final StateRestoreListener restoreListener = getRestoreListener(storeName);
 
         restoreListener.onRestoreStart(null, storeName, 0L, 0L);
 
-        final List<KeyValue<byte[], byte[]>> records = new ArrayList<>();
+        final List<ConsumerRecord<byte[], byte[]>> records = new ArrayList<>();
         for (final KeyValue<byte[], byte[]> keyValue : changeLog) {
-            records.add(keyValue);
+            records.add(new ConsumerRecord<>("", 0, 0L, keyValue.key, keyValue.value));
         }
 
-        restoreCallback.restoreAll(records);
+        restoreCallback.restoreBatch(records);
 
         restoreListener.onRestoreEnd(null, storeName, 0L);
     }
@@ -330,12 +331,4 @@ public class InternalMockProcessorContext extends AbstractProcessorContext imple
 
         return CompositeRestoreListener.NO_OP_STATE_RESTORE_LISTENER;
     }
-
-    private BatchingStateRestoreCallback getBatchingRestoreCallback(final StateRestoreCallback restoreCallback) {
-        if (restoreCallback instanceof BatchingStateRestoreCallback) {
-            return (BatchingStateRestoreCallback) restoreCallback;
-        }
-
-        return new WrappedBatchingStateRestoreCallback(restoreCallback);
-    }
 }