You are viewing a plain text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text version.
Posted to commits@flink.apache.org by al...@apache.org on 2016/08/31 17:28:42 UTC
[24/27] flink git commit: [FLINK-4380] Add tests for new Key-Group/Max-Parallelism
[FLINK-4380] Add tests for new Key-Group/Max-Parallelism
This tests the rescaling features in CheckpointCoordinator and
SavepointCoordinator.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/516ad011
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/516ad011
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/516ad011
Branch: refs/heads/master
Commit: 516ad011865ca5beece273ca9b985e2861b3435a
Parents: 847ead0
Author: Till Rohrmann <tr...@apache.org>
Authored: Thu Aug 11 12:14:18 2016 +0200
Committer: Aljoscha Krettek <al...@gmail.com>
Committed: Wed Aug 31 19:10:01 2016 +0200
----------------------------------------------------------------------
.../checkpoint/CheckpointCoordinatorTest.java | 733 ++++++++++++++++++-
.../runtime/tasks/OneInputStreamTaskTest.java | 280 ++++++-
.../test/checkpointing/RescalingITCase.java | 1 -
3 files changed, 1007 insertions(+), 7 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/516ad011/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index 50330fa..495dced 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -18,28 +18,45 @@
package org.apache.flink.runtime.checkpoint;
+import com.google.common.collect.Iterables;
import org.apache.flink.api.common.JobID;
+import org.apache.flink.core.fs.FSDataInputStream;
import org.apache.flink.runtime.checkpoint.savepoint.HeapSavepointStore;
import org.apache.flink.runtime.checkpoint.stats.DisabledCheckpointStatsTracker;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete;
import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint;
+import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.util.InstantiationUtil;
+import org.apache.flink.util.Preconditions;
+import org.junit.Assert;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
+import scala.concurrent.ExecutionContext;
import scala.concurrent.Future;
+import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
@@ -47,12 +64,14 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -124,7 +143,7 @@ public class CheckpointCoordinatorTest {
final ExecutionAttemptID triggerAttemptID1 = new ExecutionAttemptID();
final ExecutionAttemptID triggerAttemptID2 = new ExecutionAttemptID();
ExecutionVertex triggerVertex1 = mockExecutionVertex(triggerAttemptID1);
- ExecutionVertex triggerVertex2 = mockExecutionVertex(triggerAttemptID2, ExecutionState.FINISHED);
+ ExecutionVertex triggerVertex2 = mockExecutionVertex(triggerAttemptID2, new JobVertexID(), 1, 1, ExecutionState.FINISHED);
// create some mock Execution vertices that need to ack the checkpoint
final ExecutionAttemptID ackAttemptID1 = new ExecutionAttemptID();
@@ -1529,7 +1548,7 @@ public class CheckpointCoordinatorTest {
coord.startCheckpointScheduler();
// after a while, there should be exactly as many checkpoints
- // as concurrently permitted
+ // as concurrently permitted
long now = System.currentTimeMillis();
long timeout = now + 60000;
long minDuration = now + 100;
@@ -1622,7 +1641,7 @@ public class CheckpointCoordinatorTest {
}
while (System.currentTimeMillis() < timeout &&
coord.getNumberOfPendingCheckpoints() == 0);
-
+
assertTrue(coord.getNumberOfPendingCheckpoints() > 0);
}
catch (Exception e) {
@@ -1738,4 +1757,712 @@ public class CheckpointCoordinatorTest {
return vertex;
}
+	/**
+	 * Tests that the checkpointed partitioned and non-partitioned state is assigned properly to
+	 * the {@link Execution} upon recovery.
+	 *
+	 * <p>Two job vertices with different parallelism and max parallelism acknowledge a single
+	 * checkpoint with deterministically generated state. After restoring the latest checkpoint
+	 * onto the same vertices, every subtask must have been assigned exactly the state it
+	 * acknowledged.
+	 *
+	 * @throws Exception if setting up the coordinator, acknowledging or restoring fails
+	 */
+	@Test
+	public void testRestoreLatestCheckpointedState() throws Exception {
+		final JobID jid = new JobID();
+		final long timestamp = System.currentTimeMillis();
+
+		final JobVertexID jobVertexID1 = new JobVertexID();
+		final JobVertexID jobVertexID2 = new JobVertexID();
+		int parallelism1 = 3;
+		int parallelism2 = 2;
+		int maxParallelism1 = 42;
+		int maxParallelism2 = 13;
+
+		final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex(
+			jobVertexID1,
+			parallelism1,
+			maxParallelism1);
+		final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex(
+			jobVertexID2,
+			parallelism2,
+			maxParallelism2);
+
+		List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2);
+
+		allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices()));
+		allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices()));
+
+		ExecutionVertex[] arrayExecutionVertices = allExecutionVertices.toArray(new ExecutionVertex[0]);
+
+		// set up the coordinator over all execution vertices
+		CheckpointCoordinator coord = new CheckpointCoordinator(
+			jid,
+			600000,
+			600000,
+			0,
+			Integer.MAX_VALUE,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			cl,
+			new StandaloneCheckpointIDCounter(),
+			new StandaloneCompletedCheckpointStore(1, cl),
+			new HeapSavepointStore(),
+			new DisabledCheckpointStatsTracker());
+
+		// trigger the checkpoint
+		coord.triggerCheckpoint(timestamp);
+
+		assertEquals(1, coord.getPendingCheckpoints().keySet().size());
+		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
+
+		List<KeyGroupRange> keyGroupPartitions1 =
+			CheckpointCoordinator.createKeyGroupPartitions(maxParallelism1, parallelism1);
+		List<KeyGroupRange> keyGroupPartitions2 =
+			CheckpointCoordinator.createKeyGroupPartitions(maxParallelism2, parallelism2);
+
+		// acknowledge the checkpoint from every subtask with deterministically generated state
+		sendAckMessageToCoordinator(coord, checkpointId, jid, jobVertex1, jobVertexID1, keyGroupPartitions1);
+		sendAckMessageToCoordinator(coord, checkpointId, jid, jobVertex2, jobVertexID2, keyGroupPartitions2);
+
+		List<CompletedCheckpoint> completedCheckpoints = coord.getSuccessfulCheckpoints();
+
+		assertEquals(1, completedCheckpoints.size());
+
+		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
+
+		tasks.put(jobVertexID1, jobVertex1);
+		tasks.put(jobVertexID2, jobVertex2);
+
+		coord.restoreLatestCheckpointedState(tasks, true, true);
+
+		// verify the restored state
+		verifiyStateRestore(jobVertexID1, jobVertex1, keyGroupPartitions1);
+		verifiyStateRestore(jobVertexID2, jobVertex2, keyGroupPartitions2);
+	}
+
+	/**
+	 * Tests that the checkpoint restoration fails if the max parallelism of the job vertices has
+	 * changed.
+	 *
+	 * <p>Only the final {@code restoreLatestCheckpointedState} call is expected to throw. An
+	 * {@link IllegalStateException} raised earlier (while triggering or acknowledging) fails the
+	 * test instead of letting it pass vacuously.
+	 *
+	 * @throws Exception if setting up the coordinator or acknowledging the checkpoint fails
+	 */
+	@Test
+	public void testRestoreLatestCheckpointFailureWhenMaxParallelismChanges() throws Exception {
+		final JobID jid = new JobID();
+		final long timestamp = System.currentTimeMillis();
+
+		final JobVertexID jobVertexID1 = new JobVertexID();
+		final JobVertexID jobVertexID2 = new JobVertexID();
+		int parallelism1 = 3;
+		int parallelism2 = 2;
+		int maxParallelism1 = 42;
+		int maxParallelism2 = 13;
+
+		final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex(
+			jobVertexID1,
+			parallelism1,
+			maxParallelism1);
+		final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex(
+			jobVertexID2,
+			parallelism2,
+			maxParallelism2);
+
+		List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2);
+
+		allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices()));
+		allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices()));
+
+		ExecutionVertex[] arrayExecutionVertices = allExecutionVertices.toArray(new ExecutionVertex[0]);
+
+		// set up the coordinator over all execution vertices
+		CheckpointCoordinator coord = new CheckpointCoordinator(
+			jid,
+			600000,
+			600000,
+			0,
+			Integer.MAX_VALUE,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			cl,
+			new StandaloneCheckpointIDCounter(),
+			new StandaloneCompletedCheckpointStore(1, cl),
+			new HeapSavepointStore(),
+			new DisabledCheckpointStatsTracker());
+
+		// trigger the checkpoint
+		coord.triggerCheckpoint(timestamp);
+
+		assertEquals(1, coord.getPendingCheckpoints().keySet().size());
+		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
+
+		List<KeyGroupRange> keyGroupPartitions1 =
+			CheckpointCoordinator.createKeyGroupPartitions(maxParallelism1, parallelism1);
+		List<KeyGroupRange> keyGroupPartitions2 =
+			CheckpointCoordinator.createKeyGroupPartitions(maxParallelism2, parallelism2);
+
+		// acknowledge the checkpoint from every subtask with deterministically generated state
+		sendAckMessageToCoordinator(coord, checkpointId, jid, jobVertex1, jobVertexID1, keyGroupPartitions1);
+		sendAckMessageToCoordinator(coord, checkpointId, jid, jobVertex2, jobVertexID2, keyGroupPartitions2);
+
+		List<CompletedCheckpoint> completedCheckpoints = coord.getSuccessfulCheckpoints();
+
+		assertEquals(1, completedCheckpoints.size());
+
+		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
+
+		int newMaxParallelism1 = 20;
+		int newMaxParallelism2 = 42;
+
+		final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex(
+			jobVertexID1,
+			parallelism1,
+			newMaxParallelism1);
+
+		final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex(
+			jobVertexID2,
+			parallelism2,
+			newMaxParallelism2);
+
+		tasks.put(jobVertexID1, newJobVertex1);
+		tasks.put(jobVertexID2, newJobVertex2);
+
+		try {
+			coord.restoreLatestCheckpointedState(tasks, true, true);
+			fail("The restoration should have failed because the max parallelism changed.");
+		} catch (IllegalStateException expected) {
+			// expected: the restored vertices no longer have the checkpointed max parallelism
+		}
+	}
+
+	/**
+	 * Tests that the checkpoint restoration fails if the parallelism of a job vertex with
+	 * non-partitioned state has changed.
+	 *
+	 * <p>Only the final {@code restoreLatestCheckpointedState} call is expected to throw. An
+	 * {@link IllegalStateException} raised earlier (while triggering or acknowledging) fails the
+	 * test instead of letting it pass vacuously.
+	 *
+	 * @throws Exception if setting up the coordinator or acknowledging the checkpoint fails
+	 */
+	@Test
+	public void testRestoreLatestCheckpointFailureWhenParallelismChanges() throws Exception {
+		final JobID jid = new JobID();
+		final long timestamp = System.currentTimeMillis();
+
+		final JobVertexID jobVertexID1 = new JobVertexID();
+		final JobVertexID jobVertexID2 = new JobVertexID();
+		int parallelism1 = 3;
+		int parallelism2 = 2;
+		int maxParallelism1 = 42;
+		int maxParallelism2 = 13;
+
+		final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex(
+			jobVertexID1,
+			parallelism1,
+			maxParallelism1);
+		final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex(
+			jobVertexID2,
+			parallelism2,
+			maxParallelism2);
+
+		List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2);
+
+		allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices()));
+		allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices()));
+
+		ExecutionVertex[] arrayExecutionVertices = allExecutionVertices.toArray(new ExecutionVertex[0]);
+
+		// set up the coordinator over all execution vertices
+		CheckpointCoordinator coord = new CheckpointCoordinator(
+			jid,
+			600000,
+			600000,
+			0,
+			Integer.MAX_VALUE,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			cl,
+			new StandaloneCheckpointIDCounter(),
+			new StandaloneCompletedCheckpointStore(1, cl),
+			new HeapSavepointStore(),
+			new DisabledCheckpointStatsTracker());
+
+		// trigger the checkpoint
+		coord.triggerCheckpoint(timestamp);
+
+		assertEquals(1, coord.getPendingCheckpoints().keySet().size());
+		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
+
+		List<KeyGroupRange> keyGroupPartitions1 =
+			CheckpointCoordinator.createKeyGroupPartitions(maxParallelism1, parallelism1);
+		List<KeyGroupRange> keyGroupPartitions2 =
+			CheckpointCoordinator.createKeyGroupPartitions(maxParallelism2, parallelism2);
+
+		// acknowledge the checkpoint from every subtask with both non-partitioned and
+		// partitioned state
+		sendAckMessageToCoordinator(coord, checkpointId, jid, jobVertex1, jobVertexID1, keyGroupPartitions1);
+		sendAckMessageToCoordinator(coord, checkpointId, jid, jobVertex2, jobVertexID2, keyGroupPartitions2);
+
+		List<CompletedCheckpoint> completedCheckpoints = coord.getSuccessfulCheckpoints();
+
+		assertEquals(1, completedCheckpoints.size());
+
+		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
+
+		int newParallelism1 = 4;
+		int newParallelism2 = 3;
+
+		final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex(
+			jobVertexID1,
+			newParallelism1,
+			maxParallelism1);
+
+		final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex(
+			jobVertexID2,
+			newParallelism2,
+			maxParallelism2);
+
+		tasks.put(jobVertexID1, newJobVertex1);
+		tasks.put(jobVertexID2, newJobVertex2);
+
+		try {
+			coord.restoreLatestCheckpointedState(tasks, true, true);
+			fail("The restoration should have failed because the parallelism of a vertex with " +
+				"non-partitioned state changed.");
+		} catch (IllegalStateException expected) {
+			// expected: vertices with non-partitioned state cannot change their parallelism
+		}
+	}
+
+	/**
+	 * Tests the checkpoint restoration with changing parallelism of job vertex with partitioned
+	 * state.
+	 *
+	 * <p>Vertex 1 keeps its parallelism and carries both kinds of state. Vertex 2 only carries
+	 * partitioned (key-group) state and is rescaled from 2 to 13 subtasks. Its key groups must be
+	 * redistributed according to the new key-group partitioning.
+	 *
+	 * @throws Exception if setting up the coordinator, acknowledging or restoring fails
+	 */
+	@Test
+	public void testRestoreLatestCheckpointedStateWithChangingParallelism() throws Exception {
+		final JobID jid = new JobID();
+		final long timestamp = System.currentTimeMillis();
+
+		final JobVertexID jobVertexID1 = new JobVertexID();
+		final JobVertexID jobVertexID2 = new JobVertexID();
+		int parallelism1 = 3;
+		int parallelism2 = 2;
+		int maxParallelism1 = 42;
+		int maxParallelism2 = 13;
+
+		final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex(
+			jobVertexID1,
+			parallelism1,
+			maxParallelism1);
+		final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex(
+			jobVertexID2,
+			parallelism2,
+			maxParallelism2);
+
+		List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2);
+
+		allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices()));
+		allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices()));
+
+		ExecutionVertex[] arrayExecutionVertices = allExecutionVertices.toArray(new ExecutionVertex[0]);
+
+		// set up the coordinator over all execution vertices
+		CheckpointCoordinator coord = new CheckpointCoordinator(
+			jid,
+			600000,
+			600000,
+			0,
+			Integer.MAX_VALUE,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			cl,
+			new StandaloneCheckpointIDCounter(),
+			new StandaloneCompletedCheckpointStore(1, cl),
+			new HeapSavepointStore(),
+			new DisabledCheckpointStatsTracker());
+
+		// trigger the checkpoint
+		coord.triggerCheckpoint(timestamp);
+
+		assertEquals(1, coord.getPendingCheckpoints().keySet().size());
+		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
+
+		List<KeyGroupRange> keyGroupPartitions1 =
+			CheckpointCoordinator.createKeyGroupPartitions(maxParallelism1, parallelism1);
+		List<KeyGroupRange> keyGroupPartitions2 =
+			CheckpointCoordinator.createKeyGroupPartitions(maxParallelism2, parallelism2);
+
+		// vertex 1 acknowledges with both non-partitioned and partitioned state
+		sendAckMessageToCoordinator(coord, checkpointId, jid, jobVertex1, jobVertexID1, keyGroupPartitions1);
+
+		// vertex 2 acknowledges without non-partitioned state (null), so that its parallelism
+		// may change on restore; only its key-group state is checkpointed
+		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
+			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index));
+
+			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
+				jid,
+				jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+				checkpointId,
+				null,
+				keyGroupState);
+
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
+		}
+
+		List<CompletedCheckpoint> completedCheckpoints = coord.getSuccessfulCheckpoints();
+
+		assertEquals(1, completedCheckpoints.size());
+
+		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
+
+		int newParallelism2 = 13;
+
+		List<KeyGroupRange> newKeyGroupPartitions2 =
+			CheckpointCoordinator.createKeyGroupPartitions(maxParallelism2, newParallelism2);
+
+		final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex(
+			jobVertexID1,
+			parallelism1,
+			maxParallelism1);
+
+		final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex(
+			jobVertexID2,
+			newParallelism2,
+			maxParallelism2);
+
+		tasks.put(jobVertexID1, newJobVertex1);
+		tasks.put(jobVertexID2, newJobVertex2);
+		coord.restoreLatestCheckpointedState(tasks, true, true);
+
+		// verify the restored state
+		verifiyStateRestore(jobVertexID1, newJobVertex1, keyGroupPartitions1);
+
+		// every rescaled subtask of vertex 2 gets no non-partitioned state and exactly the
+		// key-group state of its new key-group range
+		for (int i = 0; i < newJobVertex2.getParallelism(); i++) {
+			List<KeyGroupsStateHandle> originalKeyGroupState = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i));
+
+			Execution execution = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt();
+			ChainedStateHandle<StreamStateHandle> operatorState = execution.getChainedStateHandle();
+			List<KeyGroupsStateHandle> keyGroupState = execution.getKeyGroupsStateHandles();
+
+			assertNull(operatorState);
+			comparePartitionedState(originalKeyGroupState, keyGroupState);
+		}
+	}
+
+	// ------------------------------------------------------------------------
+	//  Utilities
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Acknowledges checkpoint {@code checkpointId} on behalf of every subtask of
+	 * {@code jobVertex}.
+	 *
+	 * <p>Each acknowledgement carries two pieces of state:
+	 * <ul>
+	 *   <li>the non-partitioned state generated by
+	 *   {@link #generateStateForVertex(JobVertexID, int)} for the subtask index;</li>
+	 *   <li>the key-group state generated by
+	 *   {@link #generateKeyGroupState(JobVertexID, KeyGroupRange)} for the subtask's entry in
+	 *   {@code keyGroupPartitions}.</li>
+	 * </ul>
+	 */
+	static void sendAckMessageToCoordinator(
+			CheckpointCoordinator coord,
+			long checkpointId,
+			JobID jid,
+			ExecutionJobVertex jobVertex,
+			JobVertexID jobVertexID,
+			List<KeyGroupRange> keyGroupPartitions) throws Exception {
+
+		for (int subtask = 0; subtask < jobVertex.getParallelism(); ++subtask) {
+			ChainedStateHandle<StreamStateHandle> nonPartitionedState =
+				generateStateForVertex(jobVertexID, subtask);
+			List<KeyGroupsStateHandle> partitionedState =
+				generateKeyGroupState(jobVertexID, keyGroupPartitions.get(subtask));
+			ExecutionAttemptID attemptId =
+				jobVertex.getTaskVertices()[subtask].getCurrentExecutionAttempt().getAttemptId();
+
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(
+				jid,
+				attemptId,
+				checkpointId,
+				nonPartitionedState,
+				partitionedState));
+		}
+	}
+
+	/**
+	 * Generates deterministic key-group state for a subtask of the given job vertex.
+	 *
+	 * <p>The value for each key group in {@code keyGroupPartition} is the first
+	 * {@link Random#nextInt()} of a {@link Random} seeded with
+	 * {@code jobVertexID.hashCode() + keyGroupIndex}. The same inputs therefore always yield the
+	 * same state, so the expected state can be regenerated when verifying a restore.
+	 *
+	 * @param jobVertexID the job vertex whose hash code seeds the generated values
+	 * @param keyGroupPartition the key groups to generate state for
+	 * @return the generated key-group state handles
+	 * @throws IOException if serializing a generated value fails
+	 */
+	public static List<KeyGroupsStateHandle> generateKeyGroupState(
+			JobVertexID jobVertexID,
+			KeyGroupRange keyGroupPartition) throws IOException {
+
+		List<Integer> testStatesLists = new ArrayList<>(keyGroupPartition.getNumberOfKeyGroups());
+		// generate one deterministic value per key group
+		for (int keyGroupIndex : keyGroupPartition) {
+			Random random = new Random(jobVertexID.hashCode() + keyGroupIndex);
+			testStatesLists.add(random.nextInt());
+		}
+
+		return generateKeyGroupState(keyGroupPartition, testStatesLists);
+	}
+
+	/**
+	 * Serializes the given per-key-group states into one byte array and wraps it into a
+	 * {@link KeyGroupsStateHandle} whose offsets point at each key group's serialized value.
+	 *
+	 * @param keyGroupRange the key groups covered by the state
+	 * @param states one state value per key group, in the iteration order of {@code keyGroupRange}
+	 * @return a list containing exactly one {@link KeyGroupsStateHandle}
+	 * @throws IOException if serializing a state value fails
+	 * @throws IllegalArgumentException if the number of states differs from the number of key groups
+	 */
+	public static List<KeyGroupsStateHandle> generateKeyGroupState(KeyGroupRange keyGroupRange, List<? extends Serializable> states) throws IOException {
+		Preconditions.checkArgument(
+			keyGroupRange.getNumberOfKeyGroups() == states.size(),
+			"Expected one state per key group (%s key groups), but got %s states.",
+			keyGroupRange.getNumberOfKeyGroups(),
+			states.size());
+
+		long[] offsets = new long[keyGroupRange.getNumberOfKeyGroups()];
+		List<byte[]> serializedGroupValues = new ArrayList<>(offsets.length);
+
+		KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, offsets);
+
+		// serialize every key group's state and record where it will start in the final array
+		int totalSerializedLength = 0;
+		int idx = 0;
+		for (int keyGroup : keyGroupRange) {
+			keyGroupRangeOffsets.setKeyGroupOffset(keyGroup, totalSerializedLength);
+			byte[] serializedValue = InstantiationUtil.serializeObject(states.get(idx));
+			totalSerializedLength += serializedValue.length;
+			serializedGroupValues.add(serializedValue);
+			++idx;
+		}
+
+		// concatenate all serialized values; the offsets recorded above index into this array
+		byte[] allSerializedValuesConcatenated = new byte[totalSerializedLength];
+		int writePosition = 0;
+		for (byte[] serializedGroupValue : serializedGroupValues) {
+			System.arraycopy(
+				serializedGroupValue,
+				0,
+				allSerializedValuesConcatenated,
+				writePosition,
+				serializedGroupValue.length);
+			writePosition += serializedGroupValue.length;
+		}
+
+		ByteStreamStateHandle allSerializedStatesHandle = new ByteStreamStateHandle(
+			allSerializedValuesConcatenated);
+		KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle(
+			keyGroupRangeOffsets,
+			allSerializedStatesHandle);
+		List<KeyGroupsStateHandle> keyGroupsStateHandleList = new ArrayList<>();
+		keyGroupsStateHandleList.add(keyGroupsStateHandle);
+		return keyGroupsStateHandleList;
+	}
+
+	/**
+	 * Creates the deterministic non-partitioned state of one subtask. The state is a single
+	 * serialized {@code int}: the first value drawn from a {@link Random} seeded with
+	 * {@code jobVertexID.hashCode() + index}.
+	 */
+	public static ChainedStateHandle<StreamStateHandle> generateStateForVertex(
+			JobVertexID jobVertexID,
+			int index) throws IOException {
+
+		final int seed = jobVertexID.hashCode() + index;
+		return generateChainedStateHandle(new Random(seed).nextInt());
+	}
+
+ public static ChainedStateHandle<StreamStateHandle> generateChainedStateHandle(
+ Serializable value) throws IOException {
+ return ChainedStateHandle.wrapSingleHandle(ByteStreamStateHandle.fromSerializable(value));
+ }
+
+	/**
+	 * Creates a Mockito mock of an {@link ExecutionJobVertex} with {@code parallelism} mocked
+	 * {@link ExecutionVertex} subtasks.
+	 *
+	 * <p>Every subtask is {@link ExecutionState#RUNNING} and has a fresh
+	 * {@link ExecutionAttemptID}. Each one reports its own subtask index together with the given
+	 * parallelism and max parallelism.
+	 */
+	public static ExecutionJobVertex mockExecutionJobVertex(
+			JobVertexID jobVertexID,
+			int parallelism,
+			int maxParallelism) {
+
+		final ExecutionJobVertex jobVertexMock = mock(ExecutionJobVertex.class);
+		final ExecutionVertex[] subtasks = new ExecutionVertex[parallelism];
+
+		for (int subtaskIndex = 0; subtaskIndex < parallelism; subtaskIndex++) {
+			final ExecutionVertex subtask = mockExecutionVertex(
+				new ExecutionAttemptID(),
+				jobVertexID,
+				parallelism,
+				maxParallelism,
+				ExecutionState.RUNNING);
+			subtasks[subtaskIndex] = subtask;
+			when(subtask.getParallelSubtaskIndex()).thenReturn(subtaskIndex);
+		}
+
+		when(jobVertexMock.getJobVertexId()).thenReturn(jobVertexID);
+		when(jobVertexMock.getTaskVertices()).thenReturn(subtasks);
+		when(jobVertexMock.getParallelism()).thenReturn(parallelism);
+		when(jobVertexMock.getMaxParallelism()).thenReturn(maxParallelism);
+
+		return jobVertexMock;
+	}
+
+ private static ExecutionVertex mockExecutionVertex(ExecutionAttemptID attemptID) {
+ return mockExecutionVertex(
+ attemptID,
+ new JobVertexID(),
+ 1,
+ 1,
+ ExecutionState.RUNNING);
+ }
+
+ /**
+ * Creates a mock {@link ExecutionVertex} whose current execution attempt is a Mockito spy of a
+ * real {@link Execution}.
+ *
+ * @param attemptID the id returned by the spied execution's {@code getAttemptId()}
+ * @param jobVertexID the id returned by the vertex's {@code getJobvertexId()}
+ * @param parallelism the value returned by {@code getTotalNumberOfParallelSubtasks()}
+ * @param maxParallelism the value returned by {@code getMaxParallelism()}
+ * @param state the first value returned by the execution's {@code getState()}
+ * @param successiveStates the values returned by subsequent {@code getState()} calls
+ * @return the mocked execution vertex
+ */
+ private static ExecutionVertex mockExecutionVertex(
+ ExecutionAttemptID attemptID,
+ JobVertexID jobVertexID,
+ int parallelism,
+ int maxParallelism,
+ ExecutionState state,
+ ExecutionState ... successiveStates) {
+
+ ExecutionVertex vertex = mock(ExecutionVertex.class);
+
+ // A spy (not a plain mock) keeps the real Execution behaviour. Presumably this lets state
+ // assigned during restore be read back via getChainedStateHandle() /
+ // getKeyGroupsStateHandles(), as the verification helpers in this test do.
+ // NOTE(review): when(spy.method()) invokes the real method once while stubbing;
+ // doReturn(...).when(spy).method() would avoid that call.
+ final Execution exec = spy(new Execution(
+ mock(ExecutionContext.class),
+ vertex,
+ 1,
+ 1L,
+ null
+ ));
+ when(exec.getAttemptId()).thenReturn(attemptID);
+ when(exec.getState()).thenReturn(state, successiveStates);
+
+ when(vertex.getJobvertexId()).thenReturn(jobVertexID);
+ when(vertex.getCurrentExecutionAttempt()).thenReturn(exec);
+ when(vertex.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism);
+ when(vertex.getMaxParallelism()).thenReturn(maxParallelism);
+
+ return vertex;
+ }
+
+	/**
+	 * Verifies that every subtask of {@code executionJobVertex} was assigned exactly the state
+	 * generated for it by {@link #generateStateForVertex(JobVertexID, int)} and
+	 * {@link #generateKeyGroupState(JobVertexID, KeyGroupRange)}.
+	 *
+	 * @param jobVertexID the job vertex id the expected state was generated for
+	 * @param executionJobVertex the (mocked) job vertex whose subtasks were restored
+	 * @param keyGroupPartitions the key-group range of each subtask, indexed by subtask index
+	 * @throws Exception if generating or reading back state fails
+	 */
+	public static void verifiyStateRestore(
+			JobVertexID jobVertexID, ExecutionJobVertex executionJobVertex,
+			List<KeyGroupRange> keyGroupPartitions) throws Exception {
+
+		for (int i = 0; i < executionJobVertex.getParallelism(); i++) {
+			Execution execution = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt();
+
+			ChainedStateHandle<StreamStateHandle> expectNonPartitionedState = generateStateForVertex(jobVertexID, i);
+			ChainedStateHandle<StreamStateHandle> actualNonPartitionedState = execution.getChainedStateHandle();
+			// fail with a descriptive assertion instead of an NPE when nothing was restored
+			assertNotNull("No non-partitioned state was restored for subtask " + i, actualNonPartitionedState);
+			assertEquals(expectNonPartitionedState.get(0), actualNonPartitionedState.get(0));
+
+			List<KeyGroupsStateHandle> expectPartitionedKeyGroupState = generateKeyGroupState(
+				jobVertexID,
+				keyGroupPartitions.get(i));
+			List<KeyGroupsStateHandle> actualPartitionedKeyGroupState = execution.getKeyGroupsStateHandles();
+			assertNotNull("No partitioned state was restored for subtask " + i, actualPartitionedKeyGroupState);
+			comparePartitionedState(expectPartitionedKeyGroupState, actualPartitionedKeyGroupState);
+		}
+	}
+
+	/**
+	 * Asserts that the restored key-group state matches the expected key-group state.
+	 *
+	 * <p>{@code actualPartitionedKeyGroupState} must cover the same total number of key groups as
+	 * {@code expectPartitionedKeyGroupState}. For every expected key group, some actual handle
+	 * must contain that key group with an equal value. The actual state may be split differently
+	 * across handles, e.g. after rescaling.
+	 *
+	 * <p>Values are deserialized as {@code int}s, matching what
+	 * {@link #generateKeyGroupState(JobVertexID, KeyGroupRange)} produces.
+	 *
+	 * @param expectPartitionedKeyGroupState the expected key-group state handles
+	 * @param actualPartitionedKeyGroupState the restored key-group state handles
+	 * @throws Exception if opening or reading a state stream fails
+	 */
+	public static void comparePartitionedState(
+			List<KeyGroupsStateHandle> expectPartitionedKeyGroupState,
+			List<KeyGroupsStateHandle> actualPartitionedKeyGroupState) throws Exception {
+
+		int expectedTotalKeyGroups = 0;
+		for (KeyGroupsStateHandle expectedHandle : expectPartitionedKeyGroupState) {
+			expectedTotalKeyGroups += expectedHandle.getNumberOfKeyGroups();
+		}
+		int actualTotalKeyGroups = 0;
+		for (KeyGroupsStateHandle actualHandle : actualPartitionedKeyGroupState) {
+			actualTotalKeyGroups += actualHandle.getNumberOfKeyGroups();
+		}
+
+		assertEquals(expectedTotalKeyGroups, actualTotalKeyGroups);
+
+		for (KeyGroupsStateHandle expectedHandle : expectPartitionedKeyGroupState) {
+			// close every opened stream instead of leaking it
+			try (FSDataInputStream inputStream = expectedHandle.getStateHandle().openInputStream()) {
+				for (int groupId : expectedHandle.keyGroups()) {
+					inputStream.seek(expectedHandle.getOffsetForKeyGroup(groupId));
+					int expectedKeyGroupState = InstantiationUtil.deserializeObject(inputStream);
+
+					boolean found = false;
+					for (KeyGroupsStateHandle actualHandle : actualPartitionedKeyGroupState) {
+						if (actualHandle.containsKeyGroup(groupId)) {
+							found = true;
+							try (FSDataInputStream actualInputStream = actualHandle.getStateHandle().openInputStream()) {
+								actualInputStream.seek(actualHandle.getOffsetForKeyGroup(groupId));
+								int actualGroupState = InstantiationUtil.deserializeObject(actualInputStream);
+
+								assertEquals(expectedKeyGroupState, actualGroupState);
+							}
+						}
+					}
+					// a missing key group must fail even if the total key-group counts happen to match
+					assertTrue("Key group " + groupId + " is missing from the restored state.", found);
+				}
+			}
+		}
+	}
+
+	/**
+	 * Checks {@link CheckpointCoordinator#createKeyGroupPartitions(int, int)} on a set of edge
+	 * cases and on 1000 pseudo-random (max parallelism, parallelism) pairs. The pairs are drawn
+	 * with a fixed seed so the test is reproducible.
+	 */
+	@Test
+	public void testCreateKeyGroupPartitions() {
+		final int[][] edgeCases = {
+			{1, 1},
+			{13, 1},
+			{13, 2},
+			{Short.MAX_VALUE, 1},
+			{Short.MAX_VALUE, 13},
+			{Short.MAX_VALUE, Short.MAX_VALUE},
+		};
+		for (int[] edgeCase : edgeCases) {
+			testCreateKeyGroupPartitions(edgeCase[0], edgeCase[1]);
+		}
+
+		final Random random = new Random(1234);
+		for (int iteration = 0; iteration < 1000; iteration++) {
+			final int maxParallelism = 1 + random.nextInt(Short.MAX_VALUE - 1);
+			final int parallelism = 1 + random.nextInt(maxParallelism);
+			testCreateKeyGroupPartitions(maxParallelism, parallelism);
+		}
+	}
+
+ private void testCreateKeyGroupPartitions(int maxParallelism, int parallelism) {
+ List<KeyGroupRange> ranges = CheckpointCoordinator.createKeyGroupPartitions(maxParallelism, parallelism);
+ for (int i = 0; i < maxParallelism; ++i) {
+ KeyGroupRange range = ranges.get(KeyGroupRange.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, i));
+ if (!range.contains(i)) {
+ Assert.fail("Could not find expected key-group " + i + " in range " + range);
+ }
+ }
+ }
+
}
http://git-wip-us.apache.org/repos/asf/flink/blob/516ad011/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
index 5fcc59e..f757943 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
@@ -18,26 +18,54 @@
package org.apache.flink.streaming.runtime.tasks;
+import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.FSDataOutputStream;
import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
+import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.graph.StreamEdge;
+import org.apache.flink.streaming.api.graph.StreamNode;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.StreamMap;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.TestHarnessUtil;
+import org.apache.flink.util.InstantiationUtil;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.TestLogger;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
+import scala.concurrent.duration.Deadline;
+import scala.concurrent.duration.FiniteDuration;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
+import java.util.Random;
import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.TimeUnit;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
/**
* Tests for {@link OneInputStreamTask}.
@@ -51,7 +79,7 @@ import java.util.concurrent.ConcurrentLinkedQueue;
@RunWith(PowerMockRunner.class)
@PrepareForTest({ResultPartitionWriter.class})
@PowerMockIgnore({"javax.management.*", "com.sun.jndi.*"})
-public class OneInputStreamTaskTest {
+public class OneInputStreamTaskTest extends TestLogger {
/**
* This test verifies that open() and close() are correctly called. This test also verifies
@@ -82,7 +110,7 @@ public class OneInputStreamTaskTest {
testHarness.waitForTaskCompletion();
- Assert.assertTrue("RichFunction methods where not called.", TestOpenCloseMapFunction.closeCalled);
+ assertTrue("RichFunction methods where not called.", TestOpenCloseMapFunction.closeCalled);
TestHarnessUtil.assertOutputEquals("Output was not correct.",
expectedOutput,
@@ -165,7 +193,7 @@ public class OneInputStreamTaskTest {
testHarness.waitForTaskCompletion();
List<String> resultElements = TestHarnessUtil.getRawElementsFromOutput(testHarness.getOutput());
- Assert.assertEquals(2, resultElements.size());
+ assertEquals(2, resultElements.size());
}
/**
@@ -293,6 +321,252 @@ public class OneInputStreamTaskTest {
TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
}
+ /**
+  * Tests that the stream task can snapshot the state of a chain of operators and that a task
+  * started from the acknowledged snapshot calls restore on each operator of the chain.
+  */
+ @Test
+ public void testSnapshottingAndRestoring() throws Exception {
+     final Deadline deadline = new FiniteDuration(2, TimeUnit.MINUTES).fromNow();
+
+     final IdentityKeySelector<String> keySelector = new IdentityKeySelector<>();
+
+     final long checkpointId = 1L;
+     final long checkpointTimestamp = 1L;
+     final long recoveryTimestamp = 3L;
+     final long seed = 2L;
+     final int numberChainedTasks = 11;
+
+     // ---- phase 1: run a task, trigger one checkpoint and capture the acknowledged state ----
+
+     final OneInputStreamTask<String, String> streamTask = new OneInputStreamTask<>();
+     final OneInputStreamTaskTestHarness<String, String> testHarness = new OneInputStreamTaskTestHarness<>(
+             streamTask, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO);
+     testHarness.configureForKeyedStream(keySelector, BasicTypeInfo.STRING_TYPE_INFO);
+     configureChainedTestingStreamOperator(
+             testHarness.getStreamConfig(), numberChainedTasks, seed, recoveryTimestamp);
+
+     final AcknowledgeStreamMockEnvironment env = new AcknowledgeStreamMockEnvironment(
+             testHarness.jobConfig,
+             testHarness.taskConfig,
+             testHarness.executionConfig,
+             testHarness.memorySize,
+             new MockInputSplitProvider(),
+             testHarness.bufferSize);
+
+     // the counter is static and shared, so start this phase from zero
+     TestingStreamOperator.numberRestoreCalls = 0;
+
+     testHarness.invoke(env);
+     testHarness.waitForTaskRunning(deadline.timeLeft().toMillis());
+     streamTask.triggerCheckpoint(checkpointId, checkpointTimestamp);
+     testHarness.endInput();
+     testHarness.waitForTaskCompletion(deadline.timeLeft().toMillis());
+
+     // the first task had no initial state, so nothing may have been restored
+     assertEquals(0, TestingStreamOperator.numberRestoreCalls);
+     assertEquals(checkpointId, env.getCheckpointId());
+
+     // ---- phase 2: start a fresh task from the acknowledged state ----
+
+     final OneInputStreamTask<String, String> restoredTask = new OneInputStreamTask<>();
+     restoredTask.setInitialState(env.getState(), env.getKeyGroupStates());
+
+     final OneInputStreamTaskTestHarness<String, String> restoredTaskHarness = new OneInputStreamTaskTestHarness<>(
+             restoredTask, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO);
+     restoredTaskHarness.configureForKeyedStream(keySelector, BasicTypeInfo.STRING_TYPE_INFO);
+     configureChainedTestingStreamOperator(
+             restoredTaskHarness.getStreamConfig(), numberChainedTasks, seed, recoveryTimestamp);
+
+     TestingStreamOperator.numberRestoreCalls = 0;
+
+     restoredTaskHarness.invoke();
+     restoredTaskHarness.endInput();
+     restoredTaskHarness.waitForTaskCompletion(deadline.timeLeft().toMillis());
+
+     // the number of restore calls must match the length of the operator chain
+     assertEquals(numberChainedTasks, TestingStreamOperator.numberRestoreCalls);
+
+     // leave the shared static counter clean
+     TestingStreamOperator.numberRestoreCalls = 0;
+ }
+
+ //==============================================================================================
+ // Utility functions and classes
+ //==============================================================================================
+
+ private void configureChainedTestingStreamOperator(
+ StreamConfig streamConfig,
+ int numberChainedTasks,
+ long seed,
+ long recoveryTimestamp) {
+
+ Preconditions.checkArgument(numberChainedTasks >= 1, "The operator chain must at least " +
+ "contain one operator.");
+
+ Random random = new Random(seed);
+
+ TestingStreamOperator<Integer, Integer> previousOperator = new TestingStreamOperator<>(random.nextLong(), recoveryTimestamp);
+ streamConfig.setStreamOperator(previousOperator);
+
+ // create the chain of operators
+ Map<Integer, StreamConfig> chainedTaskConfigs = new HashMap<>(numberChainedTasks - 1);
+ List<StreamEdge> outputEdges = new ArrayList<>(numberChainedTasks - 1);
+
+ for (int chainedIndex = 1; chainedIndex < numberChainedTasks; chainedIndex++) {
+ TestingStreamOperator<Integer, Integer> chainedOperator = new TestingStreamOperator<>(random.nextLong(), recoveryTimestamp);
+ StreamConfig chainedConfig = new StreamConfig(new Configuration());
+ chainedConfig.setStreamOperator(chainedOperator);
+ chainedTaskConfigs.put(chainedIndex, chainedConfig);
+
+ StreamEdge outputEdge = new StreamEdge(
+ new StreamNode(
+ null,
+ chainedIndex - 1,
+ null,
+ null,
+ null,
+ null,
+ null
+ ),
+ new StreamNode(
+ null,
+ chainedIndex,
+ null,
+ null,
+ null,
+ null,
+ null
+ ),
+ 0,
+ Collections.<String>emptyList(),
+ null
+ );
+
+ outputEdges.add(outputEdge);
+ }
+
+ streamConfig.setChainedOutputs(outputEdges);
+ streamConfig.setTransitiveChainedTaskConfigs(chainedTaskConfigs);
+ }
+
+ /**
+  * {@link KeySelector} that uses each record itself as its key.
+  *
+  * @param <T> type of both the record and its key
+  */
+ private static final class IdentityKeySelector<T> implements KeySelector<T, T> {
+
+     private static final long serialVersionUID = -3555913664416688425L;
+
+     @Override
+     public T getKey(T record) throws Exception {
+         return record;
+     }
+ }
+
+ /**
+  * {@link StreamMockEnvironment} that remembers the most recent checkpoint acknowledgement so
+  * the test can inspect it and hand the acknowledged state to a restored task.
+  */
+ private static class AcknowledgeStreamMockEnvironment extends StreamMockEnvironment {
+     // Written in acknowledgeCheckpoint (called from inside the running task, presumably not the
+     // test thread) and read by the test through the getters, hence volatile.
+     private volatile long checkpointId;
+     private volatile ChainedStateHandle<StreamStateHandle> state;
+     private volatile List<KeyGroupsStateHandle> keyGroupStates;
+
+     AcknowledgeStreamMockEnvironment(Configuration jobConfig, Configuration taskConfig,
+             ExecutionConfig executionConfig, long memorySize,
+             MockInputSplitProvider inputSplitProvider, int bufferSize) {
+         super(jobConfig, taskConfig, executionConfig, memorySize, inputSplitProvider, bufferSize);
+     }
+
+     /** Returns the id of the last acknowledged checkpoint, or 0 if none was acknowledged. */
+     public long getCheckpointId() {
+         return checkpointId;
+     }
+
+     /** Returns the chained state of the last acknowledgement, or null if none was acknowledged. */
+     public ChainedStateHandle<StreamStateHandle> getState() {
+         return state;
+     }
+
+     /**
+      * Returns the non-null key-group state handles of the last acknowledgement. Returns an empty
+      * list (instead of throwing a NullPointerException) if nothing has been acknowledged yet or
+      * the acknowledgement carried no key-group states.
+      */
+     List<KeyGroupsStateHandle> getKeyGroupStates() {
+         // read the volatile field once so the null check and the iteration see the same list
+         final List<KeyGroupsStateHandle> acknowledged = keyGroupStates;
+         final List<KeyGroupsStateHandle> result = new ArrayList<>();
+         if (acknowledged != null) {
+             for (KeyGroupsStateHandle handle : acknowledged) {
+                 if (handle != null) {
+                     result.add(handle);
+                 }
+             }
+         }
+         return result;
+     }
+
+     @Override
+     public void acknowledgeCheckpoint(long checkpointId, ChainedStateHandle<StreamStateHandle> state,
+             List<KeyGroupsStateHandle> keyGroupStates) {
+         this.checkpointId = checkpointId;
+         this.state = state;
+         this.keyGroupStates = keyGroupStates;
+     }
+ }
+
+ /**
+  * Test operator that writes two ints drawn from a {@link Random} seeded with {@code seed} as its
+  * snapshot. On restore it checks that the same two values are read back from the stream. Every
+  * call to {@code restoreState} increments the static {@link #numberRestoreCalls} counter.
+  */
+ private static class TestingStreamOperator<IN, OUT>
+         extends AbstractStreamOperator<OUT> implements OneInputStreamOperator<IN, OUT> {
+
+     private static final long serialVersionUID = 774614855940397174L;
+
+     /**
+      * Number of restoreState calls across all instances. Volatile because it is incremented
+      * inside the running task but read and reset by the test thread.
+      */
+     public static volatile int numberRestoreCalls = 0;
+
+     private final long seed;
+
+     /**
+      * Kept so the constructor signature stays unchanged. {@code restoreState} does not receive a
+      * recovery timestamp, so there is currently nothing to compare this value against.
+      */
+     private final long recoveryTimestamp;
+
+     private transient Random random;
+
+     TestingStreamOperator(long seed, long recoveryTimestamp) {
+         this.seed = seed;
+         this.recoveryTimestamp = recoveryTimestamp;
+     }
+
+     @Override
+     public void processElement(StreamRecord<IN> element) throws Exception {
+         // records are intentionally ignored; only snapshot/restore is exercised
+     }
+
+     @Override
+     public void processWatermark(Watermark mark) throws Exception {
+         // watermarks are intentionally ignored; only snapshot/restore is exercised
+     }
+
+     @Override
+     public void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception {
+         final Random rnd = seededRandom();
+
+         Serializable functionState = rnd.nextInt();
+         Integer operatorState = rnd.nextInt();
+
+         InstantiationUtil.serializeObject(out, functionState);
+         InstantiationUtil.serializeObject(out, operatorState);
+     }
+
+     @Override
+     public void restoreState(FSDataInputStream in) throws Exception {
+         numberRestoreCalls++;
+
+         final Random rnd = seededRandom();
+
+         assertNotNull(in);
+
+         Serializable functionState = InstantiationUtil.deserializeObject(in);
+         Integer operatorState = InstantiationUtil.deserializeObject(in);
+
+         // a fresh Random with the same seed replays the values written by snapshotState
+         assertEquals(rnd.nextInt(), functionState);
+         assertEquals(rnd.nextInt(), (int) operatorState);
+     }
+
+     /** Lazily creates the seeded Random; the field is transient and so is null after deserialization. */
+     private Random seededRandom() {
+         if (random == null) {
+             random = new Random(seed);
+         }
+         return random;
+     }
+ }
+
+
// This must only be used in one test, otherwise the static fields will be changed
// by several tests concurrently
private static class TestOpenCloseMapFunction extends RichMapFunction<String, String> {
http://git-wip-us.apache.org/repos/asf/flink/blob/516ad011/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
index 8d1baeb..39f3086 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
@@ -352,7 +352,6 @@ public class RescalingITCase extends TestLogger {
for (int key = 0; key < numberKeys; key++) {
int keyGroupIndex = keyGroupAssigner.getKeyGroupIndex(key);
-// expectedResult.add(Tuple2.of(keyGroupIndex % parallelism, numberElements * key));
expectedResult.add(Tuple2.of(KeyGroupRange.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, keyGroupIndex) , numberElements * key));
}