Posted to commits@samza.apache.org by ca...@apache.org on 2021/10/08 18:12:27 UTC

[samza] branch master updated: SAMZA-2695: Unit tests for KafkaCheckpointManager take too long to run (#1541)

This is an automated email from the ASF dual-hosted git repository.

cameronlee pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/samza.git


The following commit(s) were added to refs/heads/master by this push:
     new dcdd0ec  SAMZA-2695: Unit tests for KafkaCheckpointManager take too long to run (#1541)
dcdd0ec is described below

commit dcdd0ec9de223753a9146e8f7bc535575afb738c
Author: Cameron Lee <ca...@linkedin.com>
AuthorDate: Fri Oct 8 11:12:19 2021 -0700

    SAMZA-2695: Unit tests for KafkaCheckpointManager take too long to run (#1541)
---
 .../kafka/TestKafkaCheckpointManager.java          | 561 +++++++++++++++++++++
 .../kafka/TestKafkaCheckpointManagerJava.java      | 285 -----------
 .../kafka/TestKafkaCheckpointManager.scala         | 533 --------------------
 .../samza/test/harness/IntegrationTestHarness.java |   1 +
 .../KafkaCheckpointManagerIntegrationTest.java     | 206 ++++++++
 5 files changed, 768 insertions(+), 818 deletions(-)

diff --git a/samza-kafka/src/test/java/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.java b/samza-kafka/src/test/java/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.java
new file mode 100644
index 0000000..fe9bfb1
--- /dev/null
+++ b/samza-kafka/src/test/java/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.java
@@ -0,0 +1,561 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.checkpoint.kafka;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import kafka.common.TopicAlreadyMarkedForDeletionException;
+import org.apache.samza.Partition;
+import org.apache.samza.SamzaException;
+import org.apache.samza.checkpoint.Checkpoint;
+import org.apache.samza.checkpoint.CheckpointId;
+import org.apache.samza.checkpoint.CheckpointV1;
+import org.apache.samza.checkpoint.CheckpointV2;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.config.TaskConfig;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.container.grouper.stream.GroupByPartitionFactory;
+import org.apache.samza.metrics.MetricsRegistry;
+import org.apache.samza.serializers.CheckpointV1Serde;
+import org.apache.samza.serializers.CheckpointV2Serde;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.OutgoingMessageEnvelope;
+import org.apache.samza.system.StreamValidationException;
+import org.apache.samza.system.SystemAdmin;
+import org.apache.samza.system.SystemConsumer;
+import org.apache.samza.system.SystemFactory;
+import org.apache.samza.system.SystemProducer;
+import org.apache.samza.system.SystemStreamMetadata;
+import org.apache.samza.system.SystemStreamMetadata.SystemStreamPartitionMetadata;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.system.kafka.KafkaStreamSpec;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.stubbing.OngoingStubbing;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.when;
+
+
+public class TestKafkaCheckpointManager {
+  private static final TaskName TASK0 = new TaskName("Partition 0");
+  private static final TaskName TASK1 = new TaskName("Partition 1");
+  private static final String CHECKPOINT_TOPIC = "checkpointTopic";
+  private static final String CHECKPOINT_SYSTEM = "checkpointSystem";
+  private static final SystemStreamPartition CHECKPOINT_SSP =
+      new SystemStreamPartition(CHECKPOINT_SYSTEM, CHECKPOINT_TOPIC, new Partition(0));
+  private static final SystemStreamPartition INPUT_SSP0 =
+      new SystemStreamPartition("inputSystem", "inputTopic", new Partition(0));
+  private static final SystemStreamPartition INPUT_SSP1 =
+      new SystemStreamPartition("inputSystem", "inputTopic", new Partition(1));
+  private static final String GROUPER_FACTORY_CLASS = GroupByPartitionFactory.class.getCanonicalName();
+  private static final KafkaStreamSpec CHECKPOINT_SPEC =
+      new KafkaStreamSpec(CHECKPOINT_TOPIC, CHECKPOINT_TOPIC, CHECKPOINT_SYSTEM, 1);
+  private static final CheckpointV1Serde CHECKPOINT_V1_SERDE = new CheckpointV1Serde();
+  private static final CheckpointV2Serde CHECKPOINT_V2_SERDE = new CheckpointV2Serde();
+  private static final KafkaCheckpointLogKeySerde KAFKA_CHECKPOINT_LOG_KEY_SERDE = new KafkaCheckpointLogKeySerde();
+
+  @Mock
+  private SystemProducer systemProducer;
+  @Mock
+  private SystemConsumer systemConsumer;
+  @Mock
+  private SystemAdmin systemAdmin;
+  @Mock
+  private SystemAdmin createResourcesSystemAdmin;
+  @Mock
+  private SystemFactory systemFactory;
+  @Mock
+  private MetricsRegistry metricsRegistry;
+
+  @Before
+  public void setup() {
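+    // instantiate fresh mock instances for the @Mock-annotated fields before each test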
+    MockitoAnnotations.initMocks(this);
+  }
+
+  @Test(expected = TopicAlreadyMarkedForDeletionException.class)
+  public void testCreateResourcesTopicCreationError() {
+    setupSystemFactory(config());
+    // throw an exception during createStream
+    doThrow(new TopicAlreadyMarkedForDeletionException("invalid stream")).when(this.createResourcesSystemAdmin)
+        .createStream(CHECKPOINT_SPEC);
+    KafkaCheckpointManager checkpointManager = buildKafkaCheckpointManager(true, config());
+    // expect an exception during startup
+    checkpointManager.createResources();
+  }
+
+  @Test(expected = StreamValidationException.class)
+  public void testCreateResourcesTopicValidationError() {
+    setupSystemFactory(config());
+    // throw an exception during validateStream
+    doThrow(new StreamValidationException("invalid stream")).when(this.createResourcesSystemAdmin)
+        .validateStream(CHECKPOINT_SPEC);
+    KafkaCheckpointManager checkpointManager = buildKafkaCheckpointManager(true, config());
+    // expect an exception during startup
+    checkpointManager.createResources();
+  }
+
+  @Test(expected = SamzaException.class)
+  public void testReadFailsOnSerdeExceptions() throws InterruptedException {
+    setupSystemFactory(config());
+    List<IncomingMessageEnvelope> checkpointEnvelopes =
+        ImmutableList.of(newCheckpointV1Envelope(TASK0, buildCheckpointV1(INPUT_SSP0, "0"), "0"));
+    setupConsumer(checkpointEnvelopes);
+    // wire up an exception throwing serde with the checkpointManager
+    CheckpointV1Serde checkpointV1Serde = mock(CheckpointV1Serde.class);
+    doThrow(new RuntimeException("serde failed")).when(checkpointV1Serde).fromBytes(any());
+    KafkaCheckpointManager checkpointManager =
+        new KafkaCheckpointManager(CHECKPOINT_SPEC, this.systemFactory, true, config(), this.metricsRegistry,
+            checkpointV1Serde, CHECKPOINT_V2_SERDE, KAFKA_CHECKPOINT_LOG_KEY_SERDE);
+    checkpointManager.register(TASK0);
+
+    // expect an exception
+    checkpointManager.readLastCheckpoint(TASK0);
+  }
+
+  @Test
+  public void testReadSucceedsOnKeySerdeExceptionsWhenValidationIsDisabled() throws InterruptedException {
+    setupSystemFactory(config());
+    List<IncomingMessageEnvelope> checkpointEnvelopes =
+        ImmutableList.of(newCheckpointV1Envelope(TASK0, buildCheckpointV1(INPUT_SSP0, "0"), "0"));
+    setupConsumer(checkpointEnvelopes);
+    // wire up an exception throwing serde with the checkpointManager
+    CheckpointV1Serde checkpointV1Serde = mock(CheckpointV1Serde.class);
+    doThrow(new RuntimeException("serde failed")).when(checkpointV1Serde).fromBytes(any());
+    KafkaCheckpointManager checkpointManager =
+        new KafkaCheckpointManager(CHECKPOINT_SPEC, this.systemFactory, false, config(), this.metricsRegistry,
+            checkpointV1Serde, CHECKPOINT_V2_SERDE, KAFKA_CHECKPOINT_LOG_KEY_SERDE);
+    checkpointManager.register(TASK0);
+
+    // expect the read to succeed in spite of the exception from the mocked CheckpointV1Serde
+    assertNull(checkpointManager.readLastCheckpoint(TASK0));
+  }
+
+  @Test
+  public void testStart() {
+    setupSystemFactory(config());
+    String oldestOffset = "1";
+    String newestOffset = "2";
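+    // topic metadata lists (oldest, newest, upcoming) offsets for partition 0; start() should register the consumer at the oldest offset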
+    SystemStreamMetadata checkpointTopicMetadata = new SystemStreamMetadata(CHECKPOINT_TOPIC,
+        ImmutableMap.of(new Partition(0), new SystemStreamPartitionMetadata(oldestOffset, newestOffset,
+            Integer.toString(Integer.parseInt(newestOffset) + 1))));
+    when(this.systemAdmin.getSystemStreamMetadata(Collections.singleton(CHECKPOINT_TOPIC))).thenReturn(
+        ImmutableMap.of(CHECKPOINT_TOPIC, checkpointTopicMetadata));
+
+    KafkaCheckpointManager checkpointManager = buildKafkaCheckpointManager(true, config());
+
+    checkpointManager.start();
+
+    verify(this.systemProducer).start();
+    verify(this.systemAdmin).start();
+    verify(this.systemConsumer).register(CHECKPOINT_SSP, oldestOffset);
+    verify(this.systemConsumer).start();
+  }
+
+  @Test
+  public void testRegister() {
+    setupSystemFactory(config());
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config());
+    kafkaCheckpointManager.register(TASK0);
+    verify(this.systemProducer).register(TASK0.getTaskName());
+  }
+
+  @Test
+  public void testStop() {
+    setupSystemFactory(config());
+    KafkaCheckpointManager checkpointManager = buildKafkaCheckpointManager(true, config());
+    checkpointManager.stop();
+    verify(this.systemProducer).stop();
+    // default configuration for stopConsumerAfterFirstRead means that the consumer is not stopped here
+    verify(this.systemConsumer, never()).stop();
+    verify(this.systemAdmin).stop();
+  }
+
+  @Test
+  public void testWriteCheckpointShouldRecreateSystemProducerOnFailure() {
+    setupSystemFactory(config());
+    SystemProducer secondKafkaProducer = mock(SystemProducer.class);
+    // override default mock behavior to return a second producer on the second call to create a producer
+    when(this.systemFactory.getProducer(CHECKPOINT_SYSTEM, config(), this.metricsRegistry,
+        KafkaCheckpointManager.class.getSimpleName())).thenReturn(this.systemProducer, secondKafkaProducer);
+    // first producer throws an exception on flush
+    doThrow(new RuntimeException("flush failed")).when(this.systemProducer).flush(TASK0.getTaskName());
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config());
+    kafkaCheckpointManager.register(TASK0);
+
+    CheckpointV1 checkpointV1 = buildCheckpointV1(INPUT_SSP0, "0");
+    kafkaCheckpointManager.writeCheckpoint(TASK0, checkpointV1);
+
+    // first producer should be stopped
+    verify(this.systemProducer).stop();
+    // register and start the second producer
+    verify(secondKafkaProducer).register(TASK0.getTaskName());
+    verify(secondKafkaProducer).start();
+    // check that the second producer was given the message to send out
+    ArgumentCaptor<OutgoingMessageEnvelope> outgoingMessageEnvelopeArgumentCaptor =
+        ArgumentCaptor.forClass(OutgoingMessageEnvelope.class);
+    verify(secondKafkaProducer).send(eq(TASK0.getTaskName()), outgoingMessageEnvelopeArgumentCaptor.capture());
+    assertEquals(CHECKPOINT_SSP, outgoingMessageEnvelopeArgumentCaptor.getValue().getSystemStream());
+    assertEquals(new KafkaCheckpointLogKey(KafkaCheckpointLogKey.CHECKPOINT_V1_KEY_TYPE, TASK0, GROUPER_FACTORY_CLASS),
+        KAFKA_CHECKPOINT_LOG_KEY_SERDE.fromBytes((byte[]) outgoingMessageEnvelopeArgumentCaptor.getValue().getKey()));
+    assertEquals(checkpointV1,
+        CHECKPOINT_V1_SERDE.fromBytes((byte[]) outgoingMessageEnvelopeArgumentCaptor.getValue().getMessage()));
+    verify(secondKafkaProducer).flush(TASK0.getTaskName());
+  }
+
+  @Test
+  public void testCreateResources() {
+    setupSystemFactory(config());
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config());
+    kafkaCheckpointManager.createResources();
+
+    verify(this.createResourcesSystemAdmin).start();
+    verify(this.createResourcesSystemAdmin).createStream(CHECKPOINT_SPEC);
+    verify(this.createResourcesSystemAdmin).validateStream(CHECKPOINT_SPEC);
+    verify(this.createResourcesSystemAdmin).stop();
+  }
+
+  @Test
+  public void testCreateResourcesSkipValidation() {
+    setupSystemFactory(config());
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(false, config());
+    kafkaCheckpointManager.createResources();
+
+    verify(this.createResourcesSystemAdmin).start();
+    verify(this.createResourcesSystemAdmin).createStream(CHECKPOINT_SPEC);
+    verify(this.createResourcesSystemAdmin, never()).validateStream(CHECKPOINT_SPEC);
+    verify(this.createResourcesSystemAdmin).stop();
+  }
+
+  @Test
+  public void testReadEmpty() throws InterruptedException {
+    setupSystemFactory(config());
+    setupConsumer(ImmutableList.of());
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config());
+    kafkaCheckpointManager.register(TASK0);
+    assertNull(kafkaCheckpointManager.readLastCheckpoint(TASK0));
+  }
+
+  @Test
+  public void testReadCheckpointV1() throws InterruptedException {
+    setupSystemFactory(config());
+    CheckpointV1 checkpointV1 = buildCheckpointV1(INPUT_SSP0, "0");
+    List<IncomingMessageEnvelope> checkpointEnvelopes =
+        ImmutableList.of(newCheckpointV1Envelope(TASK0, checkpointV1, "0"));
+    setupConsumer(checkpointEnvelopes);
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config());
+    kafkaCheckpointManager.register(TASK0);
+    Checkpoint actualCheckpoint = kafkaCheckpointManager.readLastCheckpoint(TASK0);
+    assertEquals(checkpointV1, actualCheckpoint);
+  }
+
+  @Test
+  public void testReadIgnoreCheckpointV2WhenV1Enabled() throws InterruptedException {
+    setupSystemFactory(config());
+    CheckpointV1 checkpointV1 = buildCheckpointV1(INPUT_SSP0, "0");
+    List<IncomingMessageEnvelope> checkpointEnvelopes =
+        ImmutableList.of(newCheckpointV1Envelope(TASK0, checkpointV1, "0"),
+            newCheckpointV2Envelope(TASK0, buildCheckpointV2(INPUT_SSP0, "1"), "1"));
+    setupConsumer(checkpointEnvelopes);
+    // the default TaskConfig.CHECKPOINT_READ_VERSIONS only reads CheckpointV1
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config());
+    kafkaCheckpointManager.register(TASK0);
+    Checkpoint actualCheckpoint = kafkaCheckpointManager.readLastCheckpoint(TASK0);
+    assertEquals(checkpointV1, actualCheckpoint);
+  }
+
+  @Test
+  public void testReadCheckpointV2() throws InterruptedException {
+    Config config = config(ImmutableMap.of(TaskConfig.CHECKPOINT_READ_VERSIONS, "1,2"));
+    setupSystemFactory(config);
+    CheckpointV2 checkpointV2 = buildCheckpointV2(INPUT_SSP0, "0");
+    List<IncomingMessageEnvelope> checkpointEnvelopes =
+        ImmutableList.of(newCheckpointV2Envelope(TASK0, checkpointV2, "0"));
+    setupConsumer(checkpointEnvelopes);
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config);
+    kafkaCheckpointManager.register(TASK0);
+    Checkpoint actualCheckpoint = kafkaCheckpointManager.readLastCheckpoint(TASK0);
+    assertEquals(checkpointV2, actualCheckpoint);
+  }
+
+  @Test
+  public void testReadCheckpointPriority() throws InterruptedException {
+    Config config = config(ImmutableMap.of(TaskConfig.CHECKPOINT_READ_VERSIONS, "2,1"));
+    setupSystemFactory(config);
+    CheckpointV2 checkpointV2 = buildCheckpointV2(INPUT_SSP0, "1");
+    List<IncomingMessageEnvelope> checkpointEnvelopes =
+        ImmutableList.of(newCheckpointV1Envelope(TASK0, buildCheckpointV1(INPUT_SSP0, "0"), "0"),
+            newCheckpointV2Envelope(TASK0, checkpointV2, "1"));
+    setupConsumer(checkpointEnvelopes);
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config);
+    kafkaCheckpointManager.register(TASK0);
+    Checkpoint actualCheckpoint = kafkaCheckpointManager.readLastCheckpoint(TASK0);
+    assertEquals(checkpointV2, actualCheckpoint);
+  }
+
+  @Test
+  public void testReadMultipleCheckpointsMultipleSSP() throws InterruptedException {
+    setupSystemFactory(config());
+    KafkaCheckpointManager checkpointManager = buildKafkaCheckpointManager(true, config());
+    checkpointManager.register(TASK0);
+    checkpointManager.register(TASK1);
+
+    // mock out a consumer that returns 5 checkpoint IMEs for each SSP
+    int newestOffset = 5;
+    int checkpointOffsetCounter = 0;
+    List<List<IncomingMessageEnvelope>> pollOutputs = new ArrayList<>();
+    for (int offset = 1; offset <= newestOffset; offset++) {
+      pollOutputs.add(ImmutableList.of(
+          // use regular offset value for INPUT_SSP0
+          newCheckpointV1Envelope(TASK0, buildCheckpointV1(INPUT_SSP0, Integer.toString(offset)),
+              Integer.toString(checkpointOffsetCounter++)),
+          // use (offset * 2) value for INPUT_SSP1 so offsets are different from INPUT_SSP0
+          newCheckpointV1Envelope(TASK1, buildCheckpointV1(INPUT_SSP1, Integer.toString(offset * 2)),
+              Integer.toString(checkpointOffsetCounter++))));
+    }
+    setupConsumerMultiplePoll(pollOutputs);
+
+    assertEquals(buildCheckpointV1(INPUT_SSP0, Integer.toString(newestOffset)),
+        checkpointManager.readLastCheckpoint(TASK0));
+    assertEquals(buildCheckpointV1(INPUT_SSP1, Integer.toString(newestOffset * 2)),
+        checkpointManager.readLastCheckpoint(TASK1));
+    // check the expected number of polls (+1 for the final empty poll that signals the end of the log)
+    verify(this.systemConsumer, times(newestOffset + 1)).poll(ImmutableSet.of(CHECKPOINT_SSP),
+        SystemConsumer.BLOCK_ON_OUTSTANDING_MESSAGES);
+  }
+
+  @Test
+  public void testReadMultipleCheckpointsUpgradeCheckpointVersion() throws InterruptedException {
+    Config config = config(ImmutableMap.of(TaskConfig.CHECKPOINT_READ_VERSIONS, "2,1"));
+    setupSystemFactory(config);
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config);
+    kafkaCheckpointManager.register(TASK0);
+    kafkaCheckpointManager.register(TASK1);
+
+    List<IncomingMessageEnvelope> checkpointEnvelopesV1 =
+        ImmutableList.of(newCheckpointV1Envelope(TASK0, buildCheckpointV1(INPUT_SSP0, "0"), "0"),
+            newCheckpointV1Envelope(TASK1, buildCheckpointV1(INPUT_SSP1, "0"), "1"));
+    CheckpointV2 ssp0CheckpointV2 = buildCheckpointV2(INPUT_SSP0, "10");
+    CheckpointV2 ssp1CheckpointV2 = buildCheckpointV2(INPUT_SSP1, "11");
+    List<IncomingMessageEnvelope> checkpointEnvelopesV2 =
+        ImmutableList.of(newCheckpointV2Envelope(TASK0, ssp0CheckpointV2, "2"),
+            newCheckpointV2Envelope(TASK1, ssp1CheckpointV2, "3"));
+    setupConsumerMultiplePoll(ImmutableList.of(checkpointEnvelopesV1, checkpointEnvelopesV2));
+    assertEquals(ssp0CheckpointV2, kafkaCheckpointManager.readLastCheckpoint(TASK0));
+    assertEquals(ssp1CheckpointV2, kafkaCheckpointManager.readLastCheckpoint(TASK1));
+    // 2 polls for actual checkpoints, 1 final empty poll
+    verify(this.systemConsumer, times(3)).poll(ImmutableSet.of(CHECKPOINT_SSP),
+        SystemConsumer.BLOCK_ON_OUTSTANDING_MESSAGES);
+  }
+
+  @Test
+  public void testWriteCheckpointV1() {
+    setupSystemFactory(config());
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config());
+    kafkaCheckpointManager.register(TASK0);
+    CheckpointV1 checkpointV1 = buildCheckpointV1(INPUT_SSP0, "0");
+    kafkaCheckpointManager.writeCheckpoint(TASK0, checkpointV1);
+    ArgumentCaptor<OutgoingMessageEnvelope> outgoingMessageEnvelopeArgumentCaptor =
+        ArgumentCaptor.forClass(OutgoingMessageEnvelope.class);
+    verify(this.systemProducer).send(eq(TASK0.getTaskName()), outgoingMessageEnvelopeArgumentCaptor.capture());
+    assertEquals(CHECKPOINT_SSP, outgoingMessageEnvelopeArgumentCaptor.getValue().getSystemStream());
+    assertEquals(new KafkaCheckpointLogKey(KafkaCheckpointLogKey.CHECKPOINT_V1_KEY_TYPE, TASK0, GROUPER_FACTORY_CLASS),
+        KAFKA_CHECKPOINT_LOG_KEY_SERDE.fromBytes((byte[]) outgoingMessageEnvelopeArgumentCaptor.getValue().getKey()));
+    assertEquals(checkpointV1,
+        CHECKPOINT_V1_SERDE.fromBytes((byte[]) outgoingMessageEnvelopeArgumentCaptor.getValue().getMessage()));
+    verify(this.systemProducer).flush(TASK0.getTaskName());
+  }
+
+  @Test
+  public void testWriteCheckpointV2() {
+    setupSystemFactory(config());
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config());
+    kafkaCheckpointManager.register(TASK0);
+    CheckpointV2 checkpointV2 = buildCheckpointV2(INPUT_SSP0, "0");
+    kafkaCheckpointManager.writeCheckpoint(TASK0, checkpointV2);
+    ArgumentCaptor<OutgoingMessageEnvelope> outgoingMessageEnvelopeArgumentCaptor =
+        ArgumentCaptor.forClass(OutgoingMessageEnvelope.class);
+    verify(this.systemProducer).send(eq(TASK0.getTaskName()), outgoingMessageEnvelopeArgumentCaptor.capture());
+    assertEquals(CHECKPOINT_SSP, outgoingMessageEnvelopeArgumentCaptor.getValue().getSystemStream());
+    assertEquals(new KafkaCheckpointLogKey(KafkaCheckpointLogKey.CHECKPOINT_V2_KEY_TYPE, TASK0, GROUPER_FACTORY_CLASS),
+        KAFKA_CHECKPOINT_LOG_KEY_SERDE.fromBytes((byte[]) outgoingMessageEnvelopeArgumentCaptor.getValue().getKey()));
+    assertEquals(checkpointV2,
+        CHECKPOINT_V2_SERDE.fromBytes((byte[]) outgoingMessageEnvelopeArgumentCaptor.getValue().getMessage()));
+    verify(this.systemProducer).flush(TASK0.getTaskName());
+  }
+
+  @Test
+  public void testWriteCheckpointShouldRetryFiniteTimesOnFailure() {
+    setupSystemFactory(config());
+    doThrow(new RuntimeException("send failed")).when(this.systemProducer).send(any(), any());
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config());
+    kafkaCheckpointManager.register(TASK0);
+    kafkaCheckpointManager.MaxRetryDurationInMillis_$eq(100); // Java-side setter for the Scala var MaxRetryDurationInMillis; keep the retry window short so the test gives up quickly
+    CheckpointV2 checkpointV2 = buildCheckpointV2(INPUT_SSP0, "0");
+    try {
+      kafkaCheckpointManager.writeCheckpoint(TASK0, checkpointV2);
+      fail("Expected to throw SamzaException");
+    } catch (SamzaException e) {
+      // expected to get here
+    }
+    // one call to send which fails, then writeCheckpoint gives up
+    verify(this.systemProducer).send(any(), any());
+    verify(this.systemProducer, never()).flush(any());
+  }
+
+  @Test
+  public void testConsumerStopsAfterInitialRead() throws Exception {
+    setupSystemFactory(config());
+    CheckpointV1 checkpointV1 = buildCheckpointV1(INPUT_SSP0, "0");
+    setupConsumer(ImmutableList.of(newCheckpointV1Envelope(TASK0, checkpointV1, "0")));
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config());
+    kafkaCheckpointManager.register(TASK0);
+    assertEquals(checkpointV1, kafkaCheckpointManager.readLastCheckpoint(TASK0));
+    // 1 call to get actual checkpoints, 1 call for empty poll to signal done reading
+    verify(this.systemConsumer, times(2)).poll(ImmutableSet.of(CHECKPOINT_SSP), SystemConsumer.BLOCK_ON_OUTSTANDING_MESSAGES);
+    verify(this.systemConsumer).stop();
+    // reading checkpoint again should not read more messages from the consumer
+    assertEquals(checkpointV1, kafkaCheckpointManager.readLastCheckpoint(TASK0));
+    verifyNoMoreInteractions(this.systemConsumer);
+  }
+
+  @Test
+  public void testConsumerStopsAfterInitialReadDisabled() throws Exception {
+    Config config =
+        config(ImmutableMap.of(TaskConfig.INTERNAL_CHECKPOINT_MANAGER_CONSUMER_STOP_AFTER_FIRST_READ, "false"));
+    setupSystemFactory(config);
+    // 1) return checkpointV1 for INPUT_SSP0
+    CheckpointV1 ssp0FirstCheckpointV1 = buildCheckpointV1(INPUT_SSP0, "0");
+    List<IncomingMessageEnvelope> checkpointEnvelopes0 =
+        ImmutableList.of(newCheckpointV1Envelope(TASK0, ssp0FirstCheckpointV1, "0"));
+    setupConsumer(checkpointEnvelopes0);
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config);
+    kafkaCheckpointManager.register(TASK0);
+    assertEquals(ssp0FirstCheckpointV1, kafkaCheckpointManager.readLastCheckpoint(TASK0));
+
+    // 2) return a new checkpointV1 for just INPUT_SSP0
+    CheckpointV1 ssp0SecondCheckpointV1 = buildCheckpointV1(INPUT_SSP0, "10");
+    List<IncomingMessageEnvelope> checkpointEnvelopes1 =
+        ImmutableList.of(newCheckpointV1Envelope(TASK0, ssp0SecondCheckpointV1, "1"));
+    setupConsumer(checkpointEnvelopes1);
+    assertEquals(ssp0SecondCheckpointV1, kafkaCheckpointManager.readLastCheckpoint(TASK0));
+
+    verify(this.systemConsumer, never()).stop();
+  }
+
+  private KafkaCheckpointManager buildKafkaCheckpointManager(boolean validateCheckpoint, Config config) {
+    return new KafkaCheckpointManager(CHECKPOINT_SPEC, this.systemFactory, validateCheckpoint, config,
+        this.metricsRegistry, CHECKPOINT_V1_SERDE, CHECKPOINT_V2_SERDE, KAFKA_CHECKPOINT_LOG_KEY_SERDE);
+  }
+
+  private void setupConsumer(List<IncomingMessageEnvelope> pollOutput) throws InterruptedException {
+    setupConsumerMultiplePoll(ImmutableList.of(pollOutput));
+  }
+
+  /**
+   * Create a new {@link SystemConsumer} that returns a list of messages sequentially at each subsequent poll.
+   *
+   * @param pollOutputs a list of poll outputs to be returned at subsequent polls.
+   *                    The i'th call to consumer.poll() will return the list at pollOutputs[i]
+   */
+  private void setupConsumerMultiplePoll(List<List<IncomingMessageEnvelope>> pollOutputs) throws InterruptedException {
+    OngoingStubbing<Map<SystemStreamPartition, List<IncomingMessageEnvelope>>> when =
+        when(this.systemConsumer.poll(ImmutableSet.of(CHECKPOINT_SSP), SystemConsumer.BLOCK_ON_OUTSTANDING_MESSAGES));
+    for (List<IncomingMessageEnvelope> pollOutput : pollOutputs) {
+      when = when.thenReturn(ImmutableMap.of(CHECKPOINT_SSP, pollOutput));
+    }
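+    // terminate with an empty poll result so readLastCheckpoint knows it has reached the end of the log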
+    when.thenReturn(ImmutableMap.of());
+  }
+
+  private void setupSystemFactory(Config config) {
+    when(this.systemFactory.getProducer(CHECKPOINT_SYSTEM, config, this.metricsRegistry,
+        KafkaCheckpointManager.class.getSimpleName())).thenReturn(this.systemProducer);
+    when(this.systemFactory.getConsumer(CHECKPOINT_SYSTEM, config, this.metricsRegistry,
+        KafkaCheckpointManager.class.getSimpleName())).thenReturn(this.systemConsumer);
+    when(this.systemFactory.getAdmin(CHECKPOINT_SYSTEM, config,
+        KafkaCheckpointManager.class.getSimpleName())).thenReturn(this.systemAdmin);
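+    // the manager uses a separate admin (resource name suffixed with "createResource") exclusively for createResources()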
+    when(this.systemFactory.getAdmin(CHECKPOINT_SYSTEM, config,
+        KafkaCheckpointManager.class.getSimpleName() + "createResource")).thenReturn(this.createResourcesSystemAdmin);
+  }
+
+  private static CheckpointV1 buildCheckpointV1(SystemStreamPartition ssp, String offset) {
+    return new CheckpointV1(ImmutableMap.of(ssp, offset));
+  }
+
+  /**
+   * Creates a new checkpoint envelope for the provided task, ssp and offset
+   */
+  private IncomingMessageEnvelope newCheckpointV1Envelope(TaskName taskName, CheckpointV1 checkpointV1,
+      String checkpointMessageOffset) {
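+    // the literal "checkpoint" key type corresponds to KafkaCheckpointLogKey.CHECKPOINT_V1_KEY_TYPE, which the write-path tests assert against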
+    KafkaCheckpointLogKey checkpointKey = new KafkaCheckpointLogKey("checkpoint", taskName, GROUPER_FACTORY_CLASS);
+    KafkaCheckpointLogKeySerde checkpointKeySerde = new KafkaCheckpointLogKeySerde();
+    CheckpointV1Serde checkpointMsgSerde = new CheckpointV1Serde();
+    return new IncomingMessageEnvelope(CHECKPOINT_SSP, checkpointMessageOffset,
+        checkpointKeySerde.toBytes(checkpointKey), checkpointMsgSerde.toBytes(checkpointV1));
+  }
+
+  private static CheckpointV2 buildCheckpointV2(SystemStreamPartition ssp, String offset) {
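+    // unlike v1, a CheckpointV2 also carries state checkpoint markers: backend factory -> (store -> marker)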
+    return new CheckpointV2(CheckpointId.create(), ImmutableMap.of(ssp, offset),
+        ImmutableMap.of("backend", ImmutableMap.of("store", "10")));
+  }
+
+  private IncomingMessageEnvelope newCheckpointV2Envelope(TaskName taskName, CheckpointV2 checkpointV2,
+      String checkpointMessageOffset) {
+    KafkaCheckpointLogKey checkpointKey =
+        new KafkaCheckpointLogKey(KafkaCheckpointLogKey.CHECKPOINT_V2_KEY_TYPE, taskName, GROUPER_FACTORY_CLASS);
+    KafkaCheckpointLogKeySerde checkpointKeySerde = new KafkaCheckpointLogKeySerde();
+    CheckpointV2Serde checkpointMsgSerde = new CheckpointV2Serde();
+    return new IncomingMessageEnvelope(CHECKPOINT_SSP, checkpointMessageOffset,
+        checkpointKeySerde.toBytes(checkpointKey), checkpointMsgSerde.toBytes(checkpointV2));
+  }
+
+  /**
+   * Build base {@link Config} for tests.
+   */
+  private static Config config() {
+    return new MapConfig(ImmutableMap.of(JobConfig.SSP_GROUPER_FACTORY, GROUPER_FACTORY_CLASS));
+  }
+
+  private static Config config(Map<String, String> additional) {
+    Map<String, String> configMap = new HashMap<>(config());
+    configMap.putAll(additional);
+    return new MapConfig(configMap);
+  }
+}
\ No newline at end of file
diff --git a/samza-kafka/src/test/java/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManagerJava.java b/samza-kafka/src/test/java/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManagerJava.java
deleted file mode 100644
index d0e927f..0000000
--- a/samza-kafka/src/test/java/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManagerJava.java
+++ /dev/null
@@ -1,285 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.samza.checkpoint.kafka;
-
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import kafka.common.KafkaException;
-import kafka.common.TopicAlreadyMarkedForDeletionException;
-import org.apache.samza.Partition;
-import org.apache.samza.SamzaException;
-import org.apache.samza.checkpoint.CheckpointV1;
-import org.apache.samza.config.Config;
-import org.apache.samza.config.JobConfig;
-import org.apache.samza.container.TaskName;
-import org.apache.samza.container.grouper.stream.GroupByPartitionFactory;
-import org.apache.samza.metrics.MetricsRegistry;
-import org.apache.samza.serializers.CheckpointV1Serde;
-import org.apache.samza.serializers.CheckpointV2Serde;
-import org.apache.samza.system.IncomingMessageEnvelope;
-import org.apache.samza.system.StreamValidationException;
-import org.apache.samza.system.SystemAdmin;
-import org.apache.samza.system.SystemConsumer;
-import org.apache.samza.system.SystemFactory;
-import org.apache.samza.system.SystemProducer;
-import org.apache.samza.system.SystemStreamMetadata;
-import org.apache.samza.system.SystemStreamMetadata.SystemStreamPartitionMetadata;
-import org.apache.samza.system.SystemStreamPartition;
-import org.apache.samza.system.kafka.KafkaStreamSpec;
-import org.junit.Assert;
-import org.junit.Test;
-import org.mockito.stubbing.OngoingStubbing;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.Map;
-
-import static org.mockito.Mockito.*;
-
-public class TestKafkaCheckpointManagerJava {
-  private static final TaskName TASK1 = new TaskName("task1");
-  private static final String CHECKPOINT_TOPIC = "topic-1";
-  private static final String CHECKPOINT_SYSTEM = "system-1";
-  private static final Partition CHECKPOINT_PARTITION = new Partition(0);
-  private static final SystemStreamPartition CHECKPOINT_SSP =
-      new SystemStreamPartition(CHECKPOINT_SYSTEM, CHECKPOINT_TOPIC, CHECKPOINT_PARTITION);
-  private static final String GROUPER_FACTORY_CLASS = GroupByPartitionFactory.class.getCanonicalName();
-
-  @Test(expected = TopicAlreadyMarkedForDeletionException.class)
-  public void testStartFailsOnTopicCreationErrors() {
-
-    KafkaStreamSpec checkpointSpec = new KafkaStreamSpec(CHECKPOINT_TOPIC, CHECKPOINT_TOPIC,
-        CHECKPOINT_SYSTEM, 1);
-    // create an admin that throws an exception during createStream
-    SystemAdmin mockAdmin = newAdmin("0", "10");
-    doThrow(new TopicAlreadyMarkedForDeletionException("invalid stream")).when(mockAdmin).createStream(checkpointSpec);
-
-    SystemFactory factory = newFactory(mock(SystemProducer.class), mock(SystemConsumer.class), mockAdmin);
-    KafkaCheckpointManager checkpointManager = new KafkaCheckpointManager(checkpointSpec, factory,
-        true, mock(Config.class), mock(MetricsRegistry.class), null, null, new KafkaCheckpointLogKeySerde());
-
-    // expect an exception during startup
-    checkpointManager.createResources();
-    checkpointManager.start();
-  }
-
-  @Test(expected = StreamValidationException.class)
-  public void testStartFailsOnTopicValidationErrors() {
-
-    KafkaStreamSpec checkpointSpec = new KafkaStreamSpec(CHECKPOINT_TOPIC, CHECKPOINT_TOPIC,
-        CHECKPOINT_SYSTEM, 1);
-
-    // create an admin that throws an exception during validateStream
-    SystemAdmin mockAdmin = newAdmin("0", "10");
-    doThrow(new StreamValidationException("invalid stream")).when(mockAdmin).validateStream(checkpointSpec);
-
-    SystemFactory factory = newFactory(mock(SystemProducer.class), mock(SystemConsumer.class), mockAdmin);
-    KafkaCheckpointManager checkpointManager = new KafkaCheckpointManager(checkpointSpec, factory,
-        true, mock(Config.class), mock(MetricsRegistry.class), null, null, new KafkaCheckpointLogKeySerde());
-
-    // expect an exception during startup
-    checkpointManager.createResources();
-    checkpointManager.start();
-  }
-
-  @Test(expected = SamzaException.class)
-  public void testReadFailsOnSerdeExceptions() throws Exception {
-    KafkaStreamSpec checkpointSpec = new KafkaStreamSpec(CHECKPOINT_TOPIC, CHECKPOINT_TOPIC,
-        CHECKPOINT_SYSTEM, 1);
-    Config mockConfig = mock(Config.class);
-    when(mockConfig.get(JobConfig.SSP_GROUPER_FACTORY)).thenReturn(GROUPER_FACTORY_CLASS);
-
-    // mock out a consumer that returns a single checkpoint IME
-    SystemStreamPartition ssp = new SystemStreamPartition("system-1", "input-topic", new Partition(0));
-    List<List<IncomingMessageEnvelope>> checkpointEnvelopes = ImmutableList.of(
-        ImmutableList.of(newCheckpointEnvelope(TASK1, ssp, "0")));
-    SystemConsumer mockConsumer = newConsumer(checkpointEnvelopes);
-
-    SystemAdmin mockAdmin = newAdmin("0", "1");
-    SystemFactory factory = newFactory(mock(SystemProducer.class), mockConsumer, mockAdmin);
-
-    // wire up an exception throwing serde with the checkpointManager
-    KafkaCheckpointManager checkpointManager = new KafkaCheckpointManager(checkpointSpec, factory,
-        true, mockConfig, mock(MetricsRegistry.class), new ExceptionThrowingCheckpointV1Serde(), null, new KafkaCheckpointLogKeySerde());
-    checkpointManager.register(TASK1);
-    checkpointManager.start();
-
-    // expect an exception from ExceptionThrowingSerde
-    checkpointManager.readLastCheckpoint(TASK1);
-  }
-
-  @Test
-  public void testReadSucceedsOnKeySerdeExceptionsWhenValidationIsDisabled() throws Exception {
-    KafkaStreamSpec checkpointSpec = new KafkaStreamSpec(CHECKPOINT_TOPIC, CHECKPOINT_TOPIC,
-        CHECKPOINT_SYSTEM, 1);
-    Config mockConfig = mock(Config.class);
-    when(mockConfig.get(JobConfig.SSP_GROUPER_FACTORY)).thenReturn(GROUPER_FACTORY_CLASS);
-
-    // mock out a consumer that returns a single checkpoint IME
-    SystemStreamPartition ssp = new SystemStreamPartition("system-1", "input-topic", new Partition(0));
-    List<List<IncomingMessageEnvelope>> checkpointEnvelopes = ImmutableList.of(
-        ImmutableList.of(newCheckpointEnvelope(TASK1, ssp, "0")));
-    SystemConsumer mockConsumer = newConsumer(checkpointEnvelopes);
-
-    SystemAdmin mockAdmin = newAdmin("0", "1");
-    SystemFactory factory = newFactory(mock(SystemProducer.class), mockConsumer, mockAdmin);
-
-    // wire up an exception throwing serde with the checkpointManager
-    KafkaCheckpointManager checkpointManager = new KafkaCheckpointManager(checkpointSpec, factory,
-        false, mockConfig, mock(MetricsRegistry.class), new ExceptionThrowingCheckpointV1Serde(), null,
-        new ExceptionThrowingCheckpointKeySerde());
-    checkpointManager.register(TASK1);
-    checkpointManager.start();
-
-    // expect the read to succeed in spite of the exception from ExceptionThrowingSerde
-    checkpointManager.readLastCheckpoint(TASK1);
-  }
-
-  @Test
-  public void testCheckpointsAreReadFromOldestOffset() throws Exception {
-    KafkaStreamSpec checkpointSpec = new KafkaStreamSpec(CHECKPOINT_TOPIC, CHECKPOINT_TOPIC,
-        CHECKPOINT_SYSTEM, 1);
-    Config mockConfig = mock(Config.class);
-    when(mockConfig.get(JobConfig.SSP_GROUPER_FACTORY)).thenReturn(GROUPER_FACTORY_CLASS);
-
-    // mock out a consumer that returns a single checkpoint IME
-    SystemStreamPartition ssp = new SystemStreamPartition("system-1", "input-topic", new Partition(0));
-    SystemConsumer mockConsumer = newConsumer(ImmutableList.of(
-        ImmutableList.of(newCheckpointEnvelope(TASK1, ssp, "0"))));
-
-    String oldestOffset = "0";
-    SystemAdmin mockAdmin = newAdmin(oldestOffset, "1");
-    SystemFactory factory = newFactory(mock(SystemProducer.class), mockConsumer, mockAdmin);
-    KafkaCheckpointManager checkpointManager = new KafkaCheckpointManager(checkpointSpec, factory,
-        true, mockConfig, mock(MetricsRegistry.class), new CheckpointV1Serde(), new CheckpointV2Serde(),
-        new KafkaCheckpointLogKeySerde());
-    checkpointManager.register(TASK1);
-
-    // 1. verify that consumer.register is called only during checkpointManager.start.
-    // 2. verify that consumer.register is called with the oldest offset.
-    // 3. verify that no other operation on the CheckpointManager re-invokes register since start offsets are set during
-    // register
-    verify(mockConsumer, times(0)).register(CHECKPOINT_SSP, oldestOffset);
-    checkpointManager.start();
-    verify(mockConsumer, times(1)).register(CHECKPOINT_SSP, oldestOffset);
-
-    checkpointManager.readLastCheckpoint(TASK1);
-    verify(mockConsumer, times(1)).register(CHECKPOINT_SSP, oldestOffset);
-  }
-
-  @Test
-  public void testAllMessagesInTheLogAreRead() throws Exception {
-    KafkaStreamSpec checkpointSpec = new KafkaStreamSpec(CHECKPOINT_TOPIC, CHECKPOINT_TOPIC,
-        CHECKPOINT_SYSTEM, 1);
-    Config mockConfig = mock(Config.class);
-    when(mockConfig.get(JobConfig.SSP_GROUPER_FACTORY)).thenReturn(GROUPER_FACTORY_CLASS);
-
-    SystemStreamPartition ssp = new SystemStreamPartition("system-1", "input-topic", new Partition(0));
-
-    int oldestOffset = 0;
-    int newestOffset = 10;
-
-    // mock out a consumer that returns ten checkpoint IMEs for the same ssp
-    List<List<IncomingMessageEnvelope>> pollOutputs = new ArrayList<>();
-    for (int offset = oldestOffset; offset <= newestOffset; offset++) {
-      pollOutputs.add(ImmutableList.of(newCheckpointEnvelope(TASK1, ssp, Integer.toString(offset))));
-    }
-
-    // return one message at a time from each poll simulating a KafkaConsumer with max.poll.records = 1
-    SystemConsumer mockConsumer = newConsumer(pollOutputs);
-    SystemAdmin mockAdmin = newAdmin(Integer.toString(oldestOffset), Integer.toString(newestOffset));
-    SystemFactory factory = newFactory(mock(SystemProducer.class), mockConsumer, mockAdmin);
-
-    KafkaCheckpointManager checkpointManager = new KafkaCheckpointManager(checkpointSpec, factory,
-        true, mockConfig, mock(MetricsRegistry.class), new CheckpointV1Serde(), new CheckpointV2Serde(),
-        new KafkaCheckpointLogKeySerde());
-    checkpointManager.register(TASK1);
-    checkpointManager.start();
-
-    // check that all ten messages are read, and the checkpoint is the newest message
-    CheckpointV1 checkpoint = (CheckpointV1) checkpointManager.readLastCheckpoint(TASK1);
-    Assert.assertEquals(checkpoint.getOffsets(), ImmutableMap.of(ssp, Integer.toString(newestOffset)));
-  }
-
-  /**
-   * Create a new {@link SystemConsumer} that returns a list of messages sequentially at each subsequent poll.
-   *
-   * @param pollOutputs a list of poll outputs to be returned at subsequent polls.
-   *                    The i'th call to consumer.poll() will return the list at pollOutputs[i]
-   * @return the consumer
-   */
-  private SystemConsumer newConsumer(List<List<IncomingMessageEnvelope>> pollOutputs) throws Exception {
-    SystemConsumer mockConsumer = mock(SystemConsumer.class);
-    OngoingStubbing<Map> when = when(mockConsumer.poll(anySet(), anyLong()));
-    for (List<IncomingMessageEnvelope> pollOutput : pollOutputs) {
-      when = when.thenReturn(ImmutableMap.of(CHECKPOINT_SSP, pollOutput));
-    }
-    when.thenReturn(ImmutableMap.of());
-    return mockConsumer;
-  }
-
-  /**
-   * Create a new {@link SystemAdmin} that returns the provided oldest and newest offsets for its topics
-   */
-  private SystemAdmin newAdmin(String oldestOffset, String newestOffset) {
-    SystemStreamMetadata checkpointTopicMetadata = new SystemStreamMetadata(CHECKPOINT_TOPIC,
-        ImmutableMap.of(new Partition(0), new SystemStreamPartitionMetadata(oldestOffset,
-            newestOffset, Integer.toString(Integer.parseInt(newestOffset) + 1))));
-    SystemAdmin mockAdmin = mock(SystemAdmin.class);
-    when(mockAdmin.getSystemStreamMetadata(Collections.singleton(CHECKPOINT_TOPIC))).thenReturn(
-        ImmutableMap.of(CHECKPOINT_TOPIC, checkpointTopicMetadata));
-    return mockAdmin;
-  }
-
-  private SystemFactory newFactory(SystemProducer producer, SystemConsumer consumer, SystemAdmin admin) {
-    SystemFactory factory = mock(SystemFactory.class);
-    when(factory.getProducer(anyString(), any(Config.class), any(MetricsRegistry.class), anyString())).thenReturn(producer);
-    when(factory.getConsumer(anyString(), any(Config.class), any(MetricsRegistry.class), anyString())).thenReturn(consumer);
-    when(factory.getAdmin(anyString(), any(Config.class), anyString())).thenReturn(admin);
-    return factory;
-  }
-
-  /**
-   * Creates a new checkpoint envelope for the provided task, ssp and offset
-   */
-  private IncomingMessageEnvelope newCheckpointEnvelope(TaskName taskName, SystemStreamPartition ssp, String offset) {
-    KafkaCheckpointLogKey checkpointKey =
-        new KafkaCheckpointLogKey("checkpoint", taskName, GROUPER_FACTORY_CLASS);
-    KafkaCheckpointLogKeySerde checkpointKeySerde = new KafkaCheckpointLogKeySerde();
-
-    CheckpointV1 checkpointMsg = new CheckpointV1(ImmutableMap.of(ssp, offset));
-    CheckpointV1Serde checkpointMsgSerde = new CheckpointV1Serde();
-
-    return new IncomingMessageEnvelope(CHECKPOINT_SSP, offset, checkpointKeySerde.toBytes(checkpointKey),
-        checkpointMsgSerde.toBytes(checkpointMsg));
-  }
-
-  private static class ExceptionThrowingCheckpointV1Serde extends CheckpointV1Serde {
-    public CheckpointV1 fromBytes(byte[] bytes) {
-      throw new KafkaException("exception");
-    }
-  }
-
-  private static class ExceptionThrowingCheckpointKeySerde extends KafkaCheckpointLogKeySerde {
-    public KafkaCheckpointLogKey fromBytes(byte[] bytes) {
-      throw new KafkaException("exception");
-    }
-  }
-}
diff --git a/samza-kafka/src/test/scala/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.scala b/samza-kafka/src/test/scala/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.scala
deleted file mode 100644
index 835f53e..0000000
--- a/samza-kafka/src/test/scala/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.scala
+++ /dev/null
@@ -1,533 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.samza.checkpoint.kafka
-
-import java.util.Properties
-import kafka.integration.KafkaServerTestHarness
-import kafka.utils.{CoreUtils, TestUtils}
-import com.google.common.collect.ImmutableMap
-import org.apache.samza.checkpoint.{Checkpoint, CheckpointId, CheckpointV1, CheckpointV2}
-import org.apache.samza.config._
-import org.apache.samza.container.TaskName
-import org.apache.samza.container.grouper.stream.GroupByPartitionFactory
-import org.apache.samza.metrics.MetricsRegistry
-import org.apache.samza.serializers.{CheckpointV1Serde, CheckpointV2Serde}
-import org.apache.samza.system._
-import org.apache.samza.system.kafka.{KafkaStreamSpec, KafkaSystemFactory}
-import org.apache.samza.util.ScalaJavaUtil.JavaOptionals
-import org.apache.samza.util.{NoOpMetricsRegistry, ReflectionUtil}
-import org.apache.samza.{Partition, SamzaException}
-import org.junit.Assert._
-import org.junit._
-import org.mockito.Mockito
-import org.mockito.Matchers
-
-class TestKafkaCheckpointManager extends KafkaServerTestHarness {
-
-  protected def numBrokers: Int = 3
-
-  val checkpointSystemName = "kafka"
-  val sspGrouperFactoryName = classOf[GroupByPartitionFactory].getCanonicalName
-
-  val ssp = new SystemStreamPartition("kafka", "topic", new Partition(0))
-  val checkpoint1 = new CheckpointV1(ImmutableMap.of(ssp, "offset-1"))
-  val checkpoint2 = new CheckpointV1(ImmutableMap.of(ssp, "offset-2"))
-  val taskName = new TaskName("Partition 0")
-  var config: Config = null
-
-  @Before
-  override def setUp {
-    super.setUp
-    TestUtils.waitUntilTrue(() => servers.head.metadataCache.getAliveBrokers.size == numBrokers, "Wait for cache to update")
-    config = getConfig()
-  }
-
-  override def generateConfigs() = {
-    val props = TestUtils.createBrokerConfigs(numBrokers, zkConnect, enableControlledShutdown = true)
-    // do not use relative imports
-    props.map(_root_.kafka.server.KafkaConfig.fromProps)
-  }
-
-  @Test
-  def testWriteCheckpointShouldRecreateSystemProducerOnFailure(): Unit = {
-    val checkpointTopic = "checkpoint-topic-2"
-    val mockKafkaProducer: SystemProducer = Mockito.mock(classOf[SystemProducer])
-
-    class MockSystemFactory extends KafkaSystemFactory {
-      override def getProducer(systemName: String, config: Config, registry: MetricsRegistry): SystemProducer = {
-        mockKafkaProducer
-      }
-    }
-
-    Mockito.doThrow(new RuntimeException()).when(mockKafkaProducer).flush(taskName.getTaskName)
-
-    val props = new org.apache.samza.config.KafkaConfig(config).getCheckpointTopicProperties()
-    val spec = new KafkaStreamSpec("id", checkpointTopic, checkpointSystemName, 1, 1, props)
-    val checkPointManager = Mockito.spy(new KafkaCheckpointManager(spec, new MockSystemFactory, false, config, new NoOpMetricsRegistry))
-    val newKafkaProducer: SystemProducer = Mockito.mock(classOf[SystemProducer])
-
-    Mockito.doReturn(newKafkaProducer).when(checkPointManager).getSystemProducer()
-
-    checkPointManager.register(taskName)
-    checkPointManager.start
-    checkPointManager.writeCheckpoint(taskName, new CheckpointV1(ImmutableMap.of()))
-    checkPointManager.stop()
-
-    // Verifications after the test
-
-    Mockito.verify(mockKafkaProducer).stop()
-    Mockito.verify(newKafkaProducer).register(taskName.getTaskName)
-    Mockito.verify(newKafkaProducer).start()
-  }
-
-  @Test
-  def testCheckpointShouldBeNullIfCheckpointTopicDoesNotExistShouldBeCreatedOnWriteAndShouldBeReadableAfterWrite(): Unit = {
-    val checkpointTopic = "checkpoint-topic-1"
-    val kcm1 = createKafkaCheckpointManager(checkpointTopic)
-    kcm1.register(taskName)
-    kcm1.createResources
-    kcm1.start
-    kcm1.stop
-
-    // check that start actually creates the topic with log compaction enabled
-    val topicConfig = adminZkClient.getAllTopicConfigs().getOrElse(checkpointTopic, new Properties())
-
-    assertEquals(topicConfig, new KafkaConfig(config).getCheckpointTopicProperties())
-    assertEquals("compact", topicConfig.get("cleanup.policy"))
-    assertEquals("26214400", topicConfig.get("segment.bytes"))
-
-    // read before topic exists should result in a null checkpoint
-    val readCp = readCheckpoint(checkpointTopic, taskName)
-    assertNull(readCp)
-
-    writeCheckpoint(checkpointTopic, taskName, checkpoint1)
-    assertEquals(checkpoint1, readCheckpoint(checkpointTopic, taskName))
-
-    // writing a second message and reading it returns a more recent checkpoint
-    writeCheckpoint(checkpointTopic, taskName, checkpoint2)
-    assertEquals(checkpoint2, readCheckpoint(checkpointTopic, taskName))
-  }
-
-  @Test
-  def testCheckpointV1AndV2WriteAndReadV1(): Unit = {
-    val checkpointTopic = "checkpoint-topic-1"
-    val kcm1 = createKafkaCheckpointManager(checkpointTopic)
-    kcm1.register(taskName)
-    kcm1.createResources
-    kcm1.start
-    kcm1.stop
-
-    // check that start actually creates the topic with log compaction enabled
-    val topicConfig = adminZkClient.getAllTopicConfigs().getOrElse(checkpointTopic, new Properties())
-
-    assertEquals(topicConfig, new KafkaConfig(config).getCheckpointTopicProperties())
-    assertEquals("compact", topicConfig.get("cleanup.policy"))
-    assertEquals("26214400", topicConfig.get("segment.bytes"))
-
-    // read before topic exists should result in a null checkpoint
-    val readCp = readCheckpoint(checkpointTopic, taskName)
-    assertNull(readCp)
-
-    val checkpointV1 = new CheckpointV1(ImmutableMap.of(ssp, "offset-1"))
-    val checkpointV2 = new CheckpointV2(CheckpointId.create(), ImmutableMap.of(ssp, "offset-2"),
-      ImmutableMap.of("factory1", ImmutableMap.of("store1", "changelogOffset")))
-
-    // skips v2 checkpoints from checkpoint topic
-    writeCheckpoint(checkpointTopic, taskName, checkpointV2)
-    assertNull(readCheckpoint(checkpointTopic, taskName))
-
-    // reads latest v1 checkpoints
-    writeCheckpoint(checkpointTopic, taskName, checkpointV1)
-    assertEquals(checkpointV1, readCheckpoint(checkpointTopic, taskName))
-
-    // writing checkpoint v2 still returns the previous v1 checkpoint
-    writeCheckpoint(checkpointTopic, taskName, checkpointV2)
-    assertEquals(checkpointV1, readCheckpoint(checkpointTopic, taskName))
-  }
-
-  @Test
-  def testCheckpointV1AndV2WriteAndReadV2(): Unit = {
-    val checkpointTopic = "checkpoint-topic-1"
-    val kcm1 = createKafkaCheckpointManager(checkpointTopic)
-    kcm1.register(taskName)
-    kcm1.createResources
-    kcm1.start
-    kcm1.stop
-
-    // check that start actually creates the topic with log compaction enabled
-    val topicConfig = adminZkClient.getAllTopicConfigs().getOrElse(checkpointTopic, new Properties())
-
-    assertEquals(topicConfig, new KafkaConfig(config).getCheckpointTopicProperties())
-    assertEquals("compact", topicConfig.get("cleanup.policy"))
-    assertEquals("26214400", topicConfig.get("segment.bytes"))
-
-    // read before topic exists should result in a null checkpoint
-    val readCp = readCheckpoint(checkpointTopic, taskName)
-    assertNull(readCp)
-
-    val checkpointV1 = new CheckpointV1(ImmutableMap.of(ssp, "offset-1"))
-    val checkpointV2 = new CheckpointV2(CheckpointId.create(), ImmutableMap.of(ssp, "offset-2"),
-      ImmutableMap.of("factory1", ImmutableMap.of("store1", "changelogOffset")))
-
-    val overrideConfig = new MapConfig(new ImmutableMap.Builder[String, String]()
-      .put(JobConfig.JOB_NAME, "some-job-name")
-      .put(JobConfig.JOB_ID, "i001")
-      .put(s"systems.$checkpointSystemName.samza.factory", classOf[KafkaSystemFactory].getCanonicalName)
-      .put(s"systems.$checkpointSystemName.producer.bootstrap.servers", brokerList)
-      .put(s"systems.$checkpointSystemName.consumer.zookeeper.connect", zkConnect)
-      .put("task.checkpoint.system", checkpointSystemName)
-      .put(TaskConfig.CHECKPOINT_READ_VERSIONS, "2")
-      .build())
-
-    // Skips reading any v1 checkpoints
-    writeCheckpoint(checkpointTopic, taskName, checkpointV1)
-    assertNull(readCheckpoint(checkpointTopic, taskName, overrideConfig))
-
-    // writing a v2 checkpoint would allow reading it back
-    writeCheckpoint(checkpointTopic, taskName, checkpointV2)
-    assertEquals(checkpointV2, readCheckpoint(checkpointTopic, taskName, overrideConfig))
-
-    // writing v1 checkpoint is still skipped
-    writeCheckpoint(checkpointTopic, taskName, checkpointV1)
-    assertEquals(checkpointV2, readCheckpoint(checkpointTopic, taskName, overrideConfig))
-  }
-
-  @Test
-  def testCheckpointV1AndV2WriteAndReadV1V2PrecedenceList(): Unit = {
-    val checkpointTopic = "checkpoint-topic-1"
-    val kcm1 = createKafkaCheckpointManager(checkpointTopic)
-    kcm1.register(taskName)
-    kcm1.createResources
-    kcm1.start
-    kcm1.stop
-
-    // check that start actually creates the topic with log compaction enabled
-    val topicConfig = adminZkClient.getAllTopicConfigs().getOrElse(checkpointTopic, new Properties())
-
-    assertEquals(topicConfig, new KafkaConfig(config).getCheckpointTopicProperties())
-    assertEquals("compact", topicConfig.get("cleanup.policy"))
-    assertEquals("26214400", topicConfig.get("segment.bytes"))
-
-    // read before topic exists should result in a null checkpoint
-    val readCp = readCheckpoint(checkpointTopic, taskName)
-    assertNull(readCp)
-
-    val checkpointV1 = new CheckpointV1(ImmutableMap.of(ssp, "offset-1"))
-    val checkpointV2 = new CheckpointV2(CheckpointId.create(), ImmutableMap.of(ssp, "offset-2"),
-      ImmutableMap.of("factory1", ImmutableMap.of("store1", "changelogOffset")))
-
-    val overrideConfig = new MapConfig(new ImmutableMap.Builder[String, String]()
-      .put(JobConfig.JOB_NAME, "some-job-name")
-      .put(JobConfig.JOB_ID, "i001")
-      .put(s"systems.$checkpointSystemName.samza.factory", classOf[KafkaSystemFactory].getCanonicalName)
-      .put(s"systems.$checkpointSystemName.producer.bootstrap.servers", brokerList)
-      .put(s"systems.$checkpointSystemName.consumer.zookeeper.connect", zkConnect)
-      .put("task.checkpoint.system", checkpointSystemName)
-      .put(TaskConfig.CHECKPOINT_READ_VERSIONS, "2,1")
-      .build())
-
-    // still reads v1 checkpoints because the "2,1" precedence list includes version 1
-    writeCheckpoint(checkpointTopic, taskName, checkpointV1)
-    assertEquals(checkpointV1, readCheckpoint(checkpointTopic, taskName, overrideConfig))
-
-    // writing a v2 checkpoint would allow reading it back
-    writeCheckpoint(checkpointTopic, taskName, checkpointV2)
-    assertEquals(checkpointV2, readCheckpoint(checkpointTopic, taskName, overrideConfig))
-
-    // a newer v1 checkpoint is still ignored because v2 takes precedence
-    writeCheckpoint(checkpointTopic, taskName, checkpointV1)
-    assertEquals(checkpointV2, readCheckpoint(checkpointTopic, taskName, overrideConfig))
-
-    val newCheckpointV2 = new CheckpointV2(CheckpointId.create(), ImmutableMap.of(ssp, "offset-3"),
-      ImmutableMap.of("factory1", ImmutableMap.of("store1", "changelogOffset")))
-    // writing a newer v2 checkpoint makes reads return it
-    writeCheckpoint(checkpointTopic, taskName, newCheckpointV2)
-    assertEquals(newCheckpointV2, readCheckpoint(checkpointTopic, taskName, overrideConfig))
-  }
-
-  @Test
-  def testCheckpointValidationSkipped(): Unit = {
-    val checkpointTopic = "checkpoint-topic-1"
-    val kcm1 = createKafkaCheckpointManager(checkpointTopic, serde = new MockCheckpointSerde(),
-      failOnTopicValidation = false)
-    kcm1.register(taskName)
-    kcm1.start
-    kcm1.writeCheckpoint(taskName, new CheckpointV1(ImmutableMap.of(ssp, "offset-1")))
-    kcm1.readLastCheckpoint(taskName)
-    kcm1.stop
-  }
-
-  @Test
-  def testReadCheckpointShouldIgnoreUnknownCheckpointKeys(): Unit = {
-      val checkpointTopic = "checkpoint-topic-1"
-      val kcm1 = createKafkaCheckpointManager(checkpointTopic)
-      kcm1.register(taskName)
-      kcm1.createResources
-      kcm1.start
-      kcm1.stop
-
-      // check that start actually creates the topic with log compaction enabled
-      val topicConfig = adminZkClient.getAllTopicConfigs().getOrElse(checkpointTopic, new Properties())
-
-      assertEquals(topicConfig, new KafkaConfig(config).getCheckpointTopicProperties())
-      assertEquals("compact", topicConfig.get("cleanup.policy"))
-      assertEquals("26214400", topicConfig.get("segment.bytes"))
-
-      // read before topic exists should result in a null checkpoint
-      val readCp = readCheckpoint(checkpointTopic, taskName)
-      assertNull(readCp)
-    // skips unknown checkpoints from checkpoint topic
-    writeCheckpoint(checkpointTopic, taskName, checkpoint1, "checkpoint-v2", useMock = true)
-    assertNull(readCheckpoint(checkpointTopic, taskName, useMock = true))
-
-    // reads latest v1 checkpoints
-    writeCheckpoint(checkpointTopic, taskName, checkpoint1, useMock = true)
-    assertEquals(checkpoint1, readCheckpoint(checkpointTopic, taskName, useMock = true))
-
-    // after writing checkpoint2 under the v2 key, reads still return the previous v1 checkpoint
-    writeCheckpoint(checkpointTopic, taskName, checkpoint2, "checkpoint-v2", useMock = true)
-    assertEquals(checkpoint1, readCheckpoint(checkpointTopic, taskName, useMock = true))
-
-    // writing checkpoint2 with the correct v1 key makes reads return checkpoint2
-    writeCheckpoint(checkpointTopic, taskName, checkpoint2, useMock = true)
-    assertEquals(checkpoint2, readCheckpoint(checkpointTopic, taskName, useMock = true))
-  }
-
-  @Test
-  def testWriteCheckpointShouldRetryFiniteTimesOnFailure(): Unit = {
-    val checkpointTopic = "checkpoint-topic-2"
-    val mockKafkaProducer: SystemProducer = Mockito.mock(classOf[SystemProducer])
-    val mockKafkaSystemConsumer: SystemConsumer = Mockito.mock(classOf[SystemConsumer])
-
-    Mockito.doThrow(new RuntimeException()).when(mockKafkaProducer).flush(taskName.getTaskName)
-
-    val props = new org.apache.samza.config.KafkaConfig(config).getCheckpointTopicProperties()
-    val spec = new KafkaStreamSpec("id", checkpointTopic, checkpointSystemName, 1, 1, props)
-    val checkPointManager = new KafkaCheckpointManager(spec, new MockSystemFactory(mockKafkaSystemConsumer, mockKafkaProducer), false, config, new NoOpMetricsRegistry)
-    checkPointManager.MaxRetryDurationInMillis = 1
-
-    try {
-      checkPointManager.register(taskName)
-      checkPointManager.start
-      checkPointManager.writeCheckpoint(taskName, new CheckpointV1(ImmutableMap.of()))
-      fail("Expected SamzaException when the checkpoint write fails after retries.")
-    } catch {
-      case _: SamzaException => info("Got SamzaException as expected.")
-      case unexpectedException: Throwable => fail("Expected SamzaException but got %s" format unexpectedException)
-    } finally {
-      checkPointManager.stop()
-    }
-  }
-
-  @Test
-  def testFailOnTopicValidation(): Unit = {
-    // By default, should fail if there is a topic validation error
-    val checkpointTopic = "eight-partition-topic";
-    val kcm = createKafkaCheckpointManager(checkpointTopic)
-    kcm.register(taskName)
-    // create topic with the wrong number of partitions
-    createTopic(checkpointTopic, 8, new KafkaConfig(config).getCheckpointTopicProperties())
-    try {
-      kcm.createResources()
-      kcm.start()
-      fail("Expected an exception for invalid number of partitions in the checkpoint topic.")
-    } catch {
-      case _: StreamValidationException => // expected: topic exists with the wrong partition count
-    }
-    kcm.stop()
-  }
-
-  @Test
-  def testNoFailOnTopicValidationDisabled(): Unit = {
-    val checkpointTopic = "eight-partition-topic";
-    // create topic with the wrong number of partitions
-    createTopic(checkpointTopic, 8, new KafkaConfig(config).getCheckpointTopicProperties())
-    val failOnTopicValidation = false
-    val kcm = createKafkaCheckpointManager(checkpointTopic, new CheckpointV1Serde, failOnTopicValidation)
-    kcm.register(taskName)
-    kcm.createResources()
-    kcm.start()
-    kcm.stop()
-  }
-
-  @Test
-  def testConsumerStopsAfterInitialReadIfConfigSetTrue(): Unit = {
-    val mockKafkaSystemConsumer: SystemConsumer = Mockito.mock(classOf[SystemConsumer])
-
-    val checkpointTopic = "checkpoint-topic-test"
-    val props = new org.apache.samza.config.KafkaConfig(config).getCheckpointTopicProperties()
-    val spec = new KafkaStreamSpec("id", checkpointTopic, checkpointSystemName, 1, 1, props)
-
-    val configMapWithOverride = new java.util.HashMap[String, String](config)
-    configMapWithOverride.put(TaskConfig.INTERNAL_CHECKPOINT_MANAGER_CONSUMER_STOP_AFTER_FIRST_READ, "true")
-    val kafkaCheckpointManager = new KafkaCheckpointManager(spec, new MockSystemFactory(mockKafkaSystemConsumer), false, new MapConfig(configMapWithOverride), new NoOpMetricsRegistry)
-
-    kafkaCheckpointManager.register(taskName)
-    kafkaCheckpointManager.start()
-    kafkaCheckpointManager.readLastCheckpoint(taskName)
-
-    Mockito.verify(mockKafkaSystemConsumer, Mockito.times(1)).register(Matchers.any(), Matchers.any())
-    Mockito.verify(mockKafkaSystemConsumer, Mockito.times(1)).start()
-    Mockito.verify(mockKafkaSystemConsumer, Mockito.times(1)).poll(Matchers.any(), Matchers.any())
-    Mockito.verify(mockKafkaSystemConsumer, Mockito.times(1)).stop()
-
-    kafkaCheckpointManager.stop()
-
-    Mockito.verifyNoMoreInteractions(mockKafkaSystemConsumer)
-  }
-
-  @Test
-  def testConsumerDoesNotStopAfterInitialReadIfConfigSetFalse(): Unit = {
-    val mockKafkaSystemConsumer: SystemConsumer = Mockito.mock(classOf[SystemConsumer])
-
-    val checkpointTopic = "checkpoint-topic-test"
-    val props = new org.apache.samza.config.KafkaConfig(config).getCheckpointTopicProperties()
-    val spec = new KafkaStreamSpec("id", checkpointTopic, checkpointSystemName, 1, 1, props)
-
-    val configMapWithOverride = new java.util.HashMap[String, String](config)
-    configMapWithOverride.put(TaskConfig.INTERNAL_CHECKPOINT_MANAGER_CONSUMER_STOP_AFTER_FIRST_READ, "false")
-    val kafkaCheckpointManager = new KafkaCheckpointManager(spec, new MockSystemFactory(mockKafkaSystemConsumer), false, new MapConfig(configMapWithOverride), new NoOpMetricsRegistry)
-
-    kafkaCheckpointManager.register(taskName)
-    kafkaCheckpointManager.start()
-    kafkaCheckpointManager.readLastCheckpoint(taskName)
-
-    Mockito.verify(mockKafkaSystemConsumer, Mockito.times(0)).stop()
-
-    kafkaCheckpointManager.stop()
-
-    Mockito.verify(mockKafkaSystemConsumer, Mockito.times(1)).stop()
-  }
-
-  @After
-  override def tearDown(): Unit = {
-    if (servers != null) {
-      servers.foreach(_.shutdown())
-      servers.foreach(server => CoreUtils.delete(server.config.logDirs))
-    }
-    super.tearDown
-  }
-
-  private def getConfig(): Config = {
-    new MapConfig(new ImmutableMap.Builder[String, String]()
-      .put(JobConfig.JOB_NAME, "some-job-name")
-      .put(JobConfig.JOB_ID, "i001")
-      .put(s"systems.$checkpointSystemName.samza.factory", classOf[KafkaSystemFactory].getCanonicalName)
-      .put(s"systems.$checkpointSystemName.producer.bootstrap.servers", brokerList)
-      .put(s"systems.$checkpointSystemName.consumer.zookeeper.connect", zkConnect)
-      .put("task.checkpoint.system", checkpointSystemName)
-      .build())
-  }
-
-  private def createKafkaCheckpointManager(cpTopic: String, serde: CheckpointV1Serde = new CheckpointV1Serde,
-    failOnTopicValidation: Boolean = true, useMock: Boolean = false, checkpointKey: String = KafkaCheckpointLogKey.CHECKPOINT_V1_KEY_TYPE,
-    overrideConfig: Config = config) = {
-    val kafkaConfig = new org.apache.samza.config.KafkaConfig(overrideConfig)
-    val props = kafkaConfig.getCheckpointTopicProperties()
-    val systemName = kafkaConfig.getCheckpointSystem.getOrElse(
-      throw new SamzaException("No system defined for Kafka's checkpoint manager."))
-
-    val systemConfig = new SystemConfig(overrideConfig)
-    val systemFactoryClassName = JavaOptionals.toRichOptional(systemConfig.getSystemFactory(systemName)).toOption
-      .getOrElse(throw new SamzaException("Missing configuration: " + SystemConfig.SYSTEM_FACTORY_FORMAT format systemName))
-
-    val systemFactory = ReflectionUtil.getObj(systemFactoryClassName, classOf[SystemFactory])
-
-    val spec = new KafkaStreamSpec("id", cpTopic, checkpointSystemName, 1, 1, props)
-
-    if (useMock) {
-      new MockKafkaCheckpointManager(spec, systemFactory, failOnTopicValidation, serde, checkpointKey)
-    } else {
-      new KafkaCheckpointManager(spec, systemFactory, failOnTopicValidation, overrideConfig, new NoOpMetricsRegistry, serde)
-    }
-  }
-
-  private def readCheckpoint(checkpointTopic: String, taskName: TaskName, config: Config = config,
-    useMock: Boolean = false) : Checkpoint = {
-    val kcm = createKafkaCheckpointManager(checkpointTopic, overrideConfig = config, useMock = useMock)
-    kcm.register(taskName)
-    kcm.start
-    val checkpoint = kcm.readLastCheckpoint(taskName)
-    kcm.stop
-    checkpoint
-  }
-
-  private def writeCheckpoint(checkpointTopic: String, taskName: TaskName, checkpoint: Checkpoint,
-    checkpointKey: String = KafkaCheckpointLogKey.CHECKPOINT_V1_KEY_TYPE, useMock: Boolean = false): Unit = {
-    val kcm = createKafkaCheckpointManager(checkpointTopic, checkpointKey = checkpointKey, useMock = useMock)
-    kcm.register(taskName)
-    kcm.start
-    kcm.writeCheckpoint(taskName, checkpoint)
-    kcm.stop
-  }
-
-  private def createTopic(cpTopic: String, partNum: Int, props: Properties) {
-    adminZkClient.createTopic(cpTopic, partNum, 1, props)
-  }
-
-  class MockSystemFactory(
-    mockKafkaSystemConsumer: SystemConsumer = Mockito.mock(classOf[SystemConsumer]),
-    mockKafkaProducer: SystemProducer = Mockito.mock(classOf[SystemProducer])) extends KafkaSystemFactory {
-    override def getProducer(systemName: String, config: Config, registry: MetricsRegistry): SystemProducer = {
-      mockKafkaProducer
-    }
-
-    override def getConsumer(systemName: String, config: Config, registry: MetricsRegistry): SystemConsumer = {
-      mockKafkaSystemConsumer
-    }
-  }
-
-  class MockCheckpointSerde() extends CheckpointV1Serde {
-    override def fromBytes(bytes: Array[Byte]): CheckpointV1 = {
-      throw new SamzaException("Failed to deserialize")
-    }
-  }
-
-  class MockKafkaCheckpointManager(spec: KafkaStreamSpec, systemFactory: SystemFactory, failOnTopicValidation: Boolean,
-    serde: CheckpointV1Serde = new CheckpointV1Serde, checkpointKey: String)
-    extends KafkaCheckpointManager(spec, systemFactory, failOnTopicValidation, config,
-      new NoOpMetricsRegistry, serde) {
-
-    override def buildOutgoingMessageEnvelope[T <: Checkpoint](taskName: TaskName, checkpoint: T): OutgoingMessageEnvelope = {
-      val key = new KafkaCheckpointLogKey(checkpointKey, taskName, expectedGrouperFactory)
-      val keySerde = new KafkaCheckpointLogKeySerde
-      val checkpointMsgSerde = new CheckpointV1Serde
-      val checkpointV2MsgSerde = new CheckpointV2Serde
-      val keyBytes = try {
-        keySerde.toBytes(key)
-      } catch {
-        case e: Exception => throw new SamzaException(s"Exception when writing checkpoint-key for $taskName: $checkpoint", e)
-      }
-      val msgBytes = try {
-        checkpoint match {
-          case v1: CheckpointV1 =>
-            checkpointMsgSerde.toBytes(v1)
-          case v2: CheckpointV2 =>
-            checkpointV2MsgSerde.toBytes(v2)
-          case _ =>
-            throw new IllegalArgumentException("Unknown checkpoint key type for test, please use Checkpoint v1 or v2")
-        }
-      } catch {
-        case e: Exception => throw new SamzaException(s"Exception when writing checkpoint for $taskName: $checkpoint", e)
-      }
-      new OutgoingMessageEnvelope(checkpointSsp, keyBytes, msgBytes)
-    }
-  }
-}
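
For context, the CHECKPOINT_READ_VERSIONS behavior exercised by the deleted tests above
amounts to a precedence scan: read the latest checkpoint of each version from the topic,
then pick the first version in the configured list that has one. A minimal Java sketch of
that selection step, with hypothetical names rather than Samza's actual implementation:

    import java.util.List;
    import java.util.Map;

    public final class ReadVersionPrecedenceSketch {
      // Given the latest checkpoint found per version and the configured precedence
      // list (e.g. [2] for "2", or [2, 1] for "2,1"), return the checkpoint of the
      // first listed version that is present; return null when none is available.
      public static <C> C select(List<Short> readVersions, Map<Short, C> latestByVersion) {
        for (Short version : readVersions) {
          C checkpoint = latestByVersion.get(version);
          if (checkpoint != null) {
            return checkpoint;
          }
        }
        return null; // matches the null reads asserted above
      }
    }

Under "2" alone, a newly written v1 checkpoint is never returned; under "2,1", v1 is
returned only until a v2 checkpoint appears, after which v2 always wins.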
diff --git a/samza-test/src/test/java/org/apache/samza/test/harness/IntegrationTestHarness.java b/samza-test/src/test/java/org/apache/samza/test/harness/IntegrationTestHarness.java
index 57987af..c1db4d1 100644
--- a/samza-test/src/test/java/org/apache/samza/test/harness/IntegrationTestHarness.java
+++ b/samza-test/src/test/java/org/apache/samza/test/harness/IntegrationTestHarness.java
@@ -109,6 +109,7 @@ public class IntegrationTestHarness extends AbstractKafkaServerTestHarness {
     * it shouldn't impact the tests nor have any side effects.
     */
     adminClient.close(ADMIN_OPERATION_WAIT_DURATION_MS, TimeUnit.MILLISECONDS);
+    consumer.unsubscribe();
     consumer.close();
     producer.close();
     super.tearDown();
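
The one-line harness change above unsubscribes the shared test consumer before closing it.
A minimal sketch of that teardown order against a plain KafkaConsumer (the helper class and
its timeout are assumptions, not part of the patch):

    import java.time.Duration;
    import org.apache.kafka.clients.consumer.KafkaConsumer;

    public final class ConsumerTeardown {
      // Drop all topic subscriptions first, then close with a bounded wait so that
      // teardown cannot hang on a consumer that is still subscribed.
      public static void shutDown(KafkaConsumer<?, ?> consumer) {
        consumer.unsubscribe();
        consumer.close(Duration.ofSeconds(10));
      }
    }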
diff --git a/samza-test/src/test/java/org/apache/samza/test/kafka/KafkaCheckpointManagerIntegrationTest.java b/samza-test/src/test/java/org/apache/samza/test/kafka/KafkaCheckpointManagerIntegrationTest.java
new file mode 100644
index 0000000..612647c
--- /dev/null
+++ b/samza-test/src/test/java/org/apache/samza/test/kafka/KafkaCheckpointManagerIntegrationTest.java
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.test.kafka;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import com.google.common.collect.ImmutableMap;
+import org.apache.samza.application.TaskApplication;
+import org.apache.samza.application.descriptors.TaskApplicationDescriptor;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.config.JobCoordinatorConfig;
+import org.apache.samza.config.KafkaConfig;
+import org.apache.samza.config.TaskConfig;
+import org.apache.samza.serializers.StringSerde;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.kafka.descriptors.KafkaInputDescriptor;
+import org.apache.samza.system.kafka.descriptors.KafkaSystemDescriptor;
+import org.apache.samza.task.MessageCollector;
+import org.apache.samza.task.StreamTask;
+import org.apache.samza.task.StreamTaskFactory;
+import org.apache.samza.task.TaskCoordinator;
+import org.apache.samza.test.framework.StreamApplicationIntegrationTestHarness;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+
+/**
+ * 1) Run app and consume messages
+ * 2) Commit only for first message
+ * 3) Shutdown application
+ * 4) Run app a second time to use the checkpoint
+ * 5) Verify that the uncommitted message after the first message had to be re-processed
+ */
+public class KafkaCheckpointManagerIntegrationTest extends StreamApplicationIntegrationTestHarness {
+  private static final String SYSTEM = "kafka";
+  private static final String INPUT_STREAM = "inputStream";
+  private static final Map<String, String> CONFIGS = ImmutableMap.of(
+      JobCoordinatorConfig.JOB_COORDINATOR_FACTORY, "org.apache.samza.standalone.PassthroughJobCoordinatorFactory",
+      JobConfig.PROCESSOR_ID, "0",
+      TaskConfig.CHECKPOINT_MANAGER_FACTORY, "org.apache.samza.checkpoint.kafka.KafkaCheckpointManagerFactory",
+      KafkaConfig.CHECKPOINT_REPLICATION_FACTOR(), "1",
+      TaskConfig.COMMIT_MS, "-1"); // manual commit only
+  /**
+   * Keep track of which messages have been received by the application.
+   */
+  private static final Map<String, AtomicInteger> PROCESSED = new HashMap<>();
+
+  /**
+   * If message has this prefix, then request a commit after processing it.
+   */
+  private static final String COMMIT_PREFIX = "commit";
+  /**
+   * If message equals this string, then shut down the task if the task is configured to handle intermediate shutdown.
+   */
+  private static final String INTERMEDIATE_SHUTDOWN = "intermediateShutdown";
+  /**
+   * If message equals this string, then shut down the task.
+   */
+  private static final String END_OF_STREAM = "endOfStream";
+
+  @Before
+  public void setup() {
+    PROCESSED.clear();
+  }
+
+  @Test
+  public void testCheckpoint() {
+    createTopic(INPUT_STREAM, 2);
+    produceMessages(0);
+    produceMessages(1);
+
+    // run application once and verify processed messages before shutdown
+    runApplication(new CheckpointApplication(true), "CheckpointApplication", CONFIGS).getRunner().waitForFinish();
+    verifyProcessedMessagesFirstRun();
+
+    // run application a second time and verify that certain messages had to be re-processed
+    runApplication(new CheckpointApplication(false), "CheckpointApplication", CONFIGS).getRunner().waitForFinish();
+    verifyProcessedMessagesSecondRun();
+  }
+
+  private void produceMessages(int partitionId) {
+    String key = "key" + partitionId;
+    // commit first message
+    produceMessage(INPUT_STREAM, partitionId, key, commitMessage(partitionId, 0));
+    // don't commit second message
+    produceMessage(INPUT_STREAM, partitionId, key, noCommitMessage(partitionId, 1));
+    // do an initial shutdown so that the test can check that the second message gets re-processed
+    produceMessage(INPUT_STREAM, partitionId, key, INTERMEDIATE_SHUTDOWN);
+    // do a commit on the third message
+    produceMessage(INPUT_STREAM, partitionId, key, commitMessage(partitionId, 2));
+    // this will make the task shut down for the second run
+    produceMessage(INPUT_STREAM, partitionId, key, END_OF_STREAM);
+  }
+
+  /**
+   * Each partition should have seen two messages before shutting down.
+   */
+  private static void verifyProcessedMessagesFirstRun() {
+    assertEquals(4, PROCESSED.size());
+    assertEquals(1, PROCESSED.get(commitMessage(0, 0)).get());
+    assertEquals(1, PROCESSED.get(noCommitMessage(0, 1)).get());
+    assertEquals(1, PROCESSED.get(commitMessage(1, 0)).get());
+    assertEquals(1, PROCESSED.get(noCommitMessage(1, 1)).get());
+  }
+
+  /**
+   * For each partition: the uncommitted second message is re-processed (its count reaches 2) and the third message is processed for the first time.
+   */
+  private static void verifyProcessedMessagesSecondRun() {
+    assertEquals(6, PROCESSED.size());
+    assertEquals(1, PROCESSED.get(commitMessage(0, 0)).get());
+    assertEquals(2, PROCESSED.get(noCommitMessage(0, 1)).get());
+    assertEquals(1, PROCESSED.get(commitMessage(0, 2)).get());
+    assertEquals(1, PROCESSED.get(commitMessage(1, 0)).get());
+    assertEquals(2, PROCESSED.get(noCommitMessage(1, 1)).get());
+    assertEquals(1, PROCESSED.get(commitMessage(1, 2)).get());
+  }
+
+  private static String commitMessage(int partitionId, int messageId) {
+    return String.join("_", COMMIT_PREFIX, "partition", Integer.toString(partitionId), Integer.toString(messageId));
+  }
+
+  private static String noCommitMessage(int partitionId, int messageId) {
+    return String.join("_", "partition", Integer.toString(partitionId), Integer.toString(messageId));
+  }
+
+  private static class CheckpointApplication implements TaskApplication {
+    private final boolean handleIntermediateShutdown;
+
+    private CheckpointApplication(boolean handleIntermediateShutdown) {
+      this.handleIntermediateShutdown = handleIntermediateShutdown;
+    }
+
+    @Override
+    public void describe(TaskApplicationDescriptor appDescriptor) {
+      KafkaSystemDescriptor sd = new KafkaSystemDescriptor(SYSTEM);
+      KafkaInputDescriptor<String> isd = sd.getInputDescriptor(INPUT_STREAM, new StringSerde());
+      appDescriptor.withInputStream(isd)
+          .withTaskFactory((StreamTaskFactory) () -> new CheckpointTask(this.handleIntermediateShutdown));
+    }
+  }
+
+  private static class CheckpointTask implements StreamTask {
+    /**
+     * Determine if task should respond to {@link #INTERMEDIATE_SHUTDOWN}.
+     * Helps with testing that any uncommitted messages get reprocessed if the job starts again.
+     */
+    private final boolean handleIntermediateShutdown;
+    /**
+     * When requesting shutdown, there is no guarantee of an immediate shutdown, since there are multiple tasks in the
+     * container. Use this flag to stop processing messages once shutdown has been requested, so that the test's
+     * message counts stay deterministic.
+     */
+    private boolean stopProcessing = false;
+
+    private CheckpointTask(boolean handleIntermediateShutdown) {
+      this.handleIntermediateShutdown = handleIntermediateShutdown;
+    }
+
+    @Override
+    public void process(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator) {
+      if (!this.stopProcessing) {
+        String value = (String) envelope.getMessage();
+        if (INTERMEDIATE_SHUTDOWN.equals(value)) {
+          if (this.handleIntermediateShutdown) {
+            setShutdown(coordinator);
+          }
+        } else if (END_OF_STREAM.equals(value)) {
+          setShutdown(coordinator);
+        } else {
+          synchronized (this) {
+            PROCESSED.putIfAbsent(value, new AtomicInteger(0));
+            PROCESSED.get(value).incrementAndGet();
+          }
+          if (value.startsWith(COMMIT_PREFIX)) {
+            coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+          }
+        }
+      }
+    }
+
+    private void setShutdown(TaskCoordinator coordinator) {
+      this.stopProcessing = true;
+      coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+    }
+  }
+}
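
For reference, the CONFIGS map in the new integration test corresponds to a properties-style
job config along these lines (a sketch; the key strings are taken from the Samza config
constants referenced above):

    job.coordinator.factory=org.apache.samza.standalone.PassthroughJobCoordinatorFactory
    processor.id=0
    task.checkpoint.factory=org.apache.samza.checkpoint.kafka.KafkaCheckpointManagerFactory
    task.checkpoint.replication.factor=1
    task.commit.ms=-1

With task.commit.ms=-1 automatic commits are disabled, so the only checkpoints written are
the ones the task requests via coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK).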