You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to commits@flink.apache.org by pn...@apache.org on 2019/10/28 08:22:31 UTC
[flink] 01/08: [hotfix] Extract utils of CheckpointCoordinatorTest into a separate utils class and correct codestyle
This is an automated email from the ASF dual-hosted git repository.
pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 126cff6ab4e433a69c774e126b2e85335e30ee7a
Author: ifndef-SleePy <mm...@gmail.com>
AuthorDate: Wed Sep 18 17:30:04 2019 +0800
[hotfix] Extract utils of CheckpointCoordinatorTest into a separate utils class and correct codestyle
---
.../CheckpointCoordinatorFailureTest.java | 5 +-
.../CheckpointCoordinatorMasterHooksTest.java | 4 +-
.../checkpoint/CheckpointCoordinatorTest.java | 517 +++------------------
.../CheckpointCoordinatorTestingUtils.java | 472 +++++++++++++++++++
.../checkpoint/CheckpointStateRestoreTest.java | 4 +-
.../runtime/messages/CheckpointMessagesTest.java | 13 +-
6 files changed, 542 insertions(+), 473 deletions(-)
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
index da181ba..b6b7930 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
@@ -49,6 +49,9 @@ import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
+/**
+ * Tests for failure of checkpoint coordinator.
+ */
@RunWith(PowerMockRunner.class)
@PrepareForTest(PendingCheckpoint.class)
public class CheckpointCoordinatorFailureTest extends TestLogger {
@@ -62,7 +65,7 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
JobID jid = new JobID();
final ExecutionAttemptID executionAttemptId = new ExecutionAttemptID();
- final ExecutionVertex vertex = CheckpointCoordinatorTest.mockExecutionVertex(executionAttemptId);
+ final ExecutionVertex vertex = CheckpointCoordinatorTestingUtils.mockExecutionVertex(executionAttemptId);
final long triggerTimestamp = 1L;
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
index 48b6583..8453cba 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
@@ -48,7 +48,7 @@ import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
-import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest.mockExecutionVertex;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionVertex;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
@@ -353,7 +353,7 @@ public class CheckpointCoordinatorMasterHooksTest {
// ------------------------------------------------------------------------
/**
- * This test makes sure that the checkpoint is already registered by the time
+ * This test makes sure that the checkpoint is already registered by the time.
* that the hooks are called
*/
@Test
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index 98d80fd..034fac4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -19,9 +19,7 @@
package org.apache.flink.runtime.checkpoint;
import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.core.fs.FSDataInputStream;
import org.apache.flink.core.fs.Path;
import org.apache.flink.runtime.concurrent.Executors;
import org.apache.flink.runtime.execution.ExecutionState;
@@ -39,7 +37,6 @@ import org.apache.flink.runtime.state.ChainedStateHandle;
import org.apache.flink.runtime.state.IncrementalRemoteKeyedStateHandle;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
-import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
import org.apache.flink.runtime.state.KeyGroupsStateHandle;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
@@ -55,8 +52,6 @@ import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLo
import org.apache.flink.runtime.testutils.CommonTestUtils;
import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore;
import org.apache.flink.util.ExceptionUtils;
-import org.apache.flink.util.InstantiationUtil;
-import org.apache.flink.util.Preconditions;
import org.apache.flink.util.SerializableObject;
import org.apache.flink.util.TestLogger;
@@ -76,7 +71,6 @@ import org.mockito.stubbing.Answer;
import org.mockito.verification.VerificationMode;
import java.io.IOException;
-import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
@@ -91,11 +85,20 @@ import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
-import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.compareKeyedState;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.comparePartitionableState;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.generateChainedPartitionableStateHandle;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.generateKeyGroupState;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.generatePartitionableStateHandle;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecution;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionJobVertex;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionVertex;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockSubtaskState;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.verifyStateRestore;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
@@ -439,7 +442,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
assertFalse(checkpoint.isDiscarded());
assertFalse(checkpoint.isFullyAcknowledged());
-
// decline checkpoint from the other task, this should cancel the checkpoint
// and trigger a new one
coord.receiveDeclineMessage(new DeclineCheckpoint(jid, attemptID1, checkpointId), TASK_MANAGER_LOCATION_INFO);
@@ -923,19 +925,19 @@ public class CheckpointCoordinatorTest extends TestLogger {
OperatorID opID2 = OperatorID.fromJobVertexID(ackVertex2.getJobvertexId());
OperatorID opID3 = OperatorID.fromJobVertexID(ackVertex3.getJobvertexId());
- TaskStateSnapshot taskOperatorSubtaskStates1_1 = spy(new TaskStateSnapshot());
- TaskStateSnapshot taskOperatorSubtaskStates1_2 = spy(new TaskStateSnapshot());
- TaskStateSnapshot taskOperatorSubtaskStates1_3 = spy(new TaskStateSnapshot());
+ TaskStateSnapshot taskOperatorSubtaskStates11 = spy(new TaskStateSnapshot());
+ TaskStateSnapshot taskOperatorSubtaskStates12 = spy(new TaskStateSnapshot());
+ TaskStateSnapshot taskOperatorSubtaskStates13 = spy(new TaskStateSnapshot());
- OperatorSubtaskState subtaskState1_1 = mock(OperatorSubtaskState.class);
- OperatorSubtaskState subtaskState1_2 = mock(OperatorSubtaskState.class);
- OperatorSubtaskState subtaskState1_3 = mock(OperatorSubtaskState.class);
- taskOperatorSubtaskStates1_1.putSubtaskStateByOperatorID(opID1, subtaskState1_1);
- taskOperatorSubtaskStates1_2.putSubtaskStateByOperatorID(opID2, subtaskState1_2);
- taskOperatorSubtaskStates1_3.putSubtaskStateByOperatorID(opID3, subtaskState1_3);
+ OperatorSubtaskState subtaskState11 = mock(OperatorSubtaskState.class);
+ OperatorSubtaskState subtaskState12 = mock(OperatorSubtaskState.class);
+ OperatorSubtaskState subtaskState13 = mock(OperatorSubtaskState.class);
+ taskOperatorSubtaskStates11.putSubtaskStateByOperatorID(opID1, subtaskState11);
+ taskOperatorSubtaskStates12.putSubtaskStateByOperatorID(opID2, subtaskState12);
+ taskOperatorSubtaskStates13.putSubtaskStateByOperatorID(opID3, subtaskState13);
// acknowledge one of the three tasks
- coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_2), TASK_MANAGER_LOCATION_INFO);
+ coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates12), TASK_MANAGER_LOCATION_INFO);
// start the second checkpoint
// trigger the first checkpoint. this should succeed
@@ -953,17 +955,17 @@ public class CheckpointCoordinatorTest extends TestLogger {
}
long checkpointId2 = pending2.getCheckpointId();
- TaskStateSnapshot taskOperatorSubtaskStates2_1 = spy(new TaskStateSnapshot());
- TaskStateSnapshot taskOperatorSubtaskStates2_2 = spy(new TaskStateSnapshot());
- TaskStateSnapshot taskOperatorSubtaskStates2_3 = spy(new TaskStateSnapshot());
+ TaskStateSnapshot taskOperatorSubtaskStates21 = spy(new TaskStateSnapshot());
+ TaskStateSnapshot taskOperatorSubtaskStates22 = spy(new TaskStateSnapshot());
+ TaskStateSnapshot taskOperatorSubtaskStates23 = spy(new TaskStateSnapshot());
- OperatorSubtaskState subtaskState2_1 = mock(OperatorSubtaskState.class);
- OperatorSubtaskState subtaskState2_2 = mock(OperatorSubtaskState.class);
- OperatorSubtaskState subtaskState2_3 = mock(OperatorSubtaskState.class);
+ OperatorSubtaskState subtaskState21 = mock(OperatorSubtaskState.class);
+ OperatorSubtaskState subtaskState22 = mock(OperatorSubtaskState.class);
+ OperatorSubtaskState subtaskState23 = mock(OperatorSubtaskState.class);
- taskOperatorSubtaskStates2_1.putSubtaskStateByOperatorID(opID1, subtaskState2_1);
- taskOperatorSubtaskStates2_2.putSubtaskStateByOperatorID(opID2, subtaskState2_2);
- taskOperatorSubtaskStates2_3.putSubtaskStateByOperatorID(opID3, subtaskState2_3);
+ taskOperatorSubtaskStates21.putSubtaskStateByOperatorID(opID1, subtaskState21);
+ taskOperatorSubtaskStates22.putSubtaskStateByOperatorID(opID2, subtaskState22);
+ taskOperatorSubtaskStates23.putSubtaskStateByOperatorID(opID3, subtaskState23);
// trigger messages should have been sent
verify(triggerVertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), any(CheckpointOptions.class));
@@ -972,13 +974,13 @@ public class CheckpointCoordinatorTest extends TestLogger {
// we acknowledge one more task from the first checkpoint and the second
// checkpoint completely. The second checkpoint should then subsume the first checkpoint
- coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_3), TASK_MANAGER_LOCATION_INFO);
+ coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates23), TASK_MANAGER_LOCATION_INFO);
- coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_1), TASK_MANAGER_LOCATION_INFO);
+ coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates21), TASK_MANAGER_LOCATION_INFO);
- coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_1), TASK_MANAGER_LOCATION_INFO);
+ coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates11), TASK_MANAGER_LOCATION_INFO);
- coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_2), TASK_MANAGER_LOCATION_INFO);
+ coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates22), TASK_MANAGER_LOCATION_INFO);
// now, the second checkpoint should be confirmed, and the first discarded
// actually both pending checkpoints are discarded, and the second has been transformed
@@ -990,13 +992,13 @@ public class CheckpointCoordinatorTest extends TestLogger {
assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints());
// validate that all received subtask states in the first checkpoint have been discarded
- verify(subtaskState1_1, times(1)).discardState();
- verify(subtaskState1_2, times(1)).discardState();
+ verify(subtaskState11, times(1)).discardState();
+ verify(subtaskState12, times(1)).discardState();
// validate that all subtask states in the second checkpoint are not discarded
- verify(subtaskState2_1, never()).discardState();
- verify(subtaskState2_2, never()).discardState();
- verify(subtaskState2_3, never()).discardState();
+ verify(subtaskState21, never()).discardState();
+ verify(subtaskState22, never()).discardState();
+ verify(subtaskState23, never()).discardState();
// validate the committed checkpoints
List<CompletedCheckpoint> scs = coord.getSuccessfulCheckpoints();
@@ -1010,15 +1012,15 @@ public class CheckpointCoordinatorTest extends TestLogger {
verify(commitVertex.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointId2), eq(timestamp2));
// send the last remaining ack for the first checkpoint. This should not do anything
- coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_3), TASK_MANAGER_LOCATION_INFO);
- verify(subtaskState1_3, times(1)).discardState();
+ coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates13), TASK_MANAGER_LOCATION_INFO);
+ verify(subtaskState13, times(1)).discardState();
coord.shutdown(JobStatus.FINISHED);
// validate that the states in the second checkpoint have been discarded
- verify(subtaskState2_1, times(1)).discardState();
- verify(subtaskState2_2, times(1)).discardState();
- verify(subtaskState2_3, times(1)).discardState();
+ verify(subtaskState21, times(1)).discardState();
+ verify(subtaskState22, times(1)).discardState();
+ verify(subtaskState23, times(1)).discardState();
}
catch (Exception e) {
@@ -1377,7 +1379,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
coord.stopCheckpointScheduler();
-
// for 400 ms, no further calls may come.
// there may be the case that one trigger was fired and about to
// acquire the lock, such that after cancelling it will still do
@@ -1385,7 +1386,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
int numCallsSoFar = numCalls.get();
Thread.sleep(400);
assertTrue(numCallsSoFar == numCalls.get() ||
- numCallsSoFar+1 == numCalls.get());
+ numCallsSoFar + 1 == numCalls.get());
// start another sequence of periodic scheduling
numCalls.set(0);
@@ -2200,7 +2201,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
*
* @throws Exception
*/
- @Test(expected=IllegalStateException.class)
+ @Test(expected = IllegalStateException.class)
public void testRestoreLatestCheckpointFailureWhenMaxParallelismChanges() throws Exception {
final JobID jid = new JobID();
final long timestamp = System.currentTimeMillis();
@@ -2275,7 +2276,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
coord.receiveAcknowledgeMessage(acknowledgeCheckpoint, TASK_MANAGER_LOCATION_INFO);
}
-
for (int index = 0; index < jobVertex2.getParallelism(); index++) {
KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(null, null, keyGroupState, null);
@@ -2391,9 +2391,9 @@ public class CheckpointCoordinatorTest extends TestLogger {
long checkpointId = checkpointIDCounter.getLast();
- KeyGroupRange keyGroupRange = KeyGroupRange.of(0,0);
+ KeyGroupRange keyGroupRange = KeyGroupRange.of(0, 0);
List<SerializableObject> testStates = Collections.singletonList(new SerializableObject());
- KeyedStateHandle serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
+ KeyedStateHandle serializedKeyGroupStates = generateKeyGroupState(keyGroupRange, testStates);
TaskStateSnapshot subtaskStatesForCheckpoint = new TaskStateSnapshot();
@@ -2416,10 +2416,9 @@ public class CheckpointCoordinatorTest extends TestLogger {
timestamp = System.currentTimeMillis();
CompletableFuture<CompletedCheckpoint> savepointFuture = coord.triggerSavepoint(timestamp, savepointDir);
-
KeyGroupRange keyGroupRangeForSavepoint = KeyGroupRange.of(1, 1);
List<SerializableObject> testStatesForSavepoint = Collections.singletonList(new SerializableObject());
- KeyedStateHandle serializedKeyGroupStatesForSavepoint = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRangeForSavepoint, testStatesForSavepoint);
+ KeyedStateHandle serializedKeyGroupStatesForSavepoint = generateKeyGroupState(keyGroupRangeForSavepoint, testStatesForSavepoint);
TaskStateSnapshot subtaskStatesForSavepoint = new TaskStateSnapshot();
@@ -2677,16 +2676,18 @@ public class CheckpointCoordinatorTest extends TestLogger {
OperatorID operatorID = OperatorID.fromJobVertexID(jobVertexID);
return new Tuple2<>(jobVertexID, operatorID);
}
-
+
/**
- * old topology
+ * <p>
+ * old topology.
* [operator1,operator2] * parallelism1 -> [operator3,operator4] * parallelism2
+ * </p>
*
- *
+ * <p>
* new topology
*
* [operator5,operator1,operator3] * newParallelism1 -> [operator3, operator6] * newParallelism2
- *
+ * </p>
* scaleType:
* 0 increase parallelism
* 1 decrease parallelism
@@ -2956,7 +2957,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
Collection<KeyedStateHandle> keyedStateBackend = headOpState.getManagedKeyedState();
Collection<KeyedStateHandle> keyGroupStateRaw = headOpState.getRawKeyedState();
-
compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend);
compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw);
}
@@ -3019,383 +3019,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
}
}
- // ------------------------------------------------------------------------
- // Utilities
- // ------------------------------------------------------------------------
-
- public static KeyGroupsStateHandle generateKeyGroupState(
- JobVertexID jobVertexID,
- KeyGroupRange keyGroupPartition, boolean rawState) throws IOException {
-
- List<Integer> testStatesLists = new ArrayList<>(keyGroupPartition.getNumberOfKeyGroups());
-
- // generate state for one keygroup
- for (int keyGroupIndex : keyGroupPartition) {
- int vertexHash = jobVertexID.hashCode();
- int seed = rawState ? (vertexHash * (31 + keyGroupIndex)) : (vertexHash + keyGroupIndex);
- Random random = new Random(seed);
- int simulatedStateValue = random.nextInt();
- testStatesLists.add(simulatedStateValue);
- }
-
- return generateKeyGroupState(keyGroupPartition, testStatesLists);
- }
-
- public static KeyGroupsStateHandle generateKeyGroupState(
- KeyGroupRange keyGroupRange,
- List<? extends Serializable> states) throws IOException {
-
- Preconditions.checkArgument(keyGroupRange.getNumberOfKeyGroups() == states.size());
-
- Tuple2<byte[], List<long[]>> serializedDataWithOffsets =
- serializeTogetherAndTrackOffsets(Collections.<List<? extends Serializable>>singletonList(states));
-
- KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, serializedDataWithOffsets.f1.get(0));
-
- ByteStreamStateHandle allSerializedStatesHandle = new ByteStreamStateHandle(
- String.valueOf(UUID.randomUUID()),
- serializedDataWithOffsets.f0);
-
- return new KeyGroupsStateHandle(keyGroupRangeOffsets, allSerializedStatesHandle);
- }
-
- public static Tuple2<byte[], List<long[]>> serializeTogetherAndTrackOffsets(
- List<List<? extends Serializable>> serializables) throws IOException {
-
- List<long[]> offsets = new ArrayList<>(serializables.size());
- List<byte[]> serializedGroupValues = new ArrayList<>();
-
- int runningGroupsOffset = 0;
- for(List<? extends Serializable> list : serializables) {
-
- long[] currentOffsets = new long[list.size()];
- offsets.add(currentOffsets);
-
- for (int i = 0; i < list.size(); ++i) {
- currentOffsets[i] = runningGroupsOffset;
- byte[] serializedValue = InstantiationUtil.serializeObject(list.get(i));
- serializedGroupValues.add(serializedValue);
- runningGroupsOffset += serializedValue.length;
- }
- }
-
- //write all generated values in a single byte array, which is index by groupOffsetsInFinalByteArray
- byte[] allSerializedValuesConcatenated = new byte[runningGroupsOffset];
- runningGroupsOffset = 0;
- for (byte[] serializedGroupValue : serializedGroupValues) {
- System.arraycopy(
- serializedGroupValue,
- 0,
- allSerializedValuesConcatenated,
- runningGroupsOffset,
- serializedGroupValue.length);
- runningGroupsOffset += serializedGroupValue.length;
- }
- return new Tuple2<>(allSerializedValuesConcatenated, offsets);
- }
-
- public static OperatorStateHandle generatePartitionableStateHandle(
- JobVertexID jobVertexID,
- int index,
- int namedStates,
- int partitionsPerState,
- boolean rawState) throws IOException {
-
- Map<String, List<? extends Serializable>> statesListsMap = new HashMap<>(namedStates);
-
- for (int i = 0; i < namedStates; ++i) {
- List<Integer> testStatesLists = new ArrayList<>(partitionsPerState);
- // generate state
- int seed = jobVertexID.hashCode() * index + i * namedStates;
- if (rawState) {
- seed = (seed + 1) * 31;
- }
- Random random = new Random(seed);
- for (int j = 0; j < partitionsPerState; ++j) {
- int simulatedStateValue = random.nextInt();
- testStatesLists.add(simulatedStateValue);
- }
- statesListsMap.put("state-" + i, testStatesLists);
- }
-
- return generatePartitionableStateHandle(statesListsMap);
- }
-
- public static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle(
- JobVertexID jobVertexID,
- int index,
- int namedStates,
- int partitionsPerState,
- boolean rawState) throws IOException {
-
- Map<String, List<? extends Serializable>> statesListsMap = new HashMap<>(namedStates);
-
- for (int i = 0; i < namedStates; ++i) {
- List<Integer> testStatesLists = new ArrayList<>(partitionsPerState);
- // generate state
- int seed = jobVertexID.hashCode() * index + i * namedStates;
- if (rawState) {
- seed = (seed + 1) * 31;
- }
- Random random = new Random(seed);
- for (int j = 0; j < partitionsPerState; ++j) {
- int simulatedStateValue = random.nextInt();
- testStatesLists.add(simulatedStateValue);
- }
- statesListsMap.put("state-" + i, testStatesLists);
- }
-
- return ChainedStateHandle.wrapSingleHandle(generatePartitionableStateHandle(statesListsMap));
- }
-
- private static OperatorStateHandle generatePartitionableStateHandle(
- Map<String, List<? extends Serializable>> states) throws IOException {
-
- List<List<? extends Serializable>> namedStateSerializables = new ArrayList<>(states.size());
-
- for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) {
- namedStateSerializables.add(entry.getValue());
- }
-
- Tuple2<byte[], List<long[]>> serializationWithOffsets = serializeTogetherAndTrackOffsets(namedStateSerializables);
-
- Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>(states.size());
-
- int idx = 0;
- for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) {
- offsetsMap.put(
- entry.getKey(),
- new OperatorStateHandle.StateMetaInfo(
- serializationWithOffsets.f1.get(idx),
- OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
- ++idx;
- }
-
- ByteStreamStateHandle streamStateHandle = new ByteStreamStateHandle(
- String.valueOf(UUID.randomUUID()),
- serializationWithOffsets.f0);
-
- return new OperatorStreamStateHandle(offsetsMap, streamStateHandle);
- }
-
- static ExecutionJobVertex mockExecutionJobVertex(
- JobVertexID jobVertexID,
- int parallelism,
- int maxParallelism) {
-
- return mockExecutionJobVertex(
- jobVertexID,
- Collections.singletonList(OperatorID.fromJobVertexID(jobVertexID)),
- parallelism,
- maxParallelism
- );
- }
-
- static ExecutionJobVertex mockExecutionJobVertex(
- JobVertexID jobVertexID,
- List<OperatorID> jobVertexIDs,
- int parallelism,
- int maxParallelism) {
- final ExecutionJobVertex executionJobVertex = mock(ExecutionJobVertex.class);
-
- ExecutionVertex[] executionVertices = new ExecutionVertex[parallelism];
-
- for (int i = 0; i < parallelism; i++) {
- executionVertices[i] = mockExecutionVertex(
- new ExecutionAttemptID(),
- jobVertexID,
- jobVertexIDs,
- parallelism,
- maxParallelism,
- ExecutionState.RUNNING);
-
- when(executionVertices[i].getParallelSubtaskIndex()).thenReturn(i);
- }
-
- when(executionJobVertex.getJobVertexId()).thenReturn(jobVertexID);
- when(executionJobVertex.getTaskVertices()).thenReturn(executionVertices);
- when(executionJobVertex.getParallelism()).thenReturn(parallelism);
- when(executionJobVertex.getMaxParallelism()).thenReturn(maxParallelism);
- when(executionJobVertex.isMaxParallelismConfigured()).thenReturn(true);
- when(executionJobVertex.getOperatorIDs()).thenReturn(jobVertexIDs);
- when(executionJobVertex.getUserDefinedOperatorIDs()).thenReturn(Arrays.asList(new OperatorID[jobVertexIDs.size()]));
-
- return executionJobVertex;
- }
-
- static ExecutionVertex mockExecutionVertex(ExecutionAttemptID attemptID) {
- JobVertexID jobVertexID = new JobVertexID();
- return mockExecutionVertex(
- attemptID,
- jobVertexID,
- Collections.singletonList(OperatorID.fromJobVertexID(jobVertexID)),
- 1,
- 1,
- ExecutionState.RUNNING);
- }
-
- private static ExecutionVertex mockExecutionVertex(
- ExecutionAttemptID attemptID,
- JobVertexID jobVertexID,
- List<OperatorID> jobVertexIDs,
- int parallelism,
- int maxParallelism,
- ExecutionState state,
- ExecutionState ... successiveStates) {
-
- ExecutionVertex vertex = mock(ExecutionVertex.class);
-
- final Execution exec = spy(new Execution(
- mock(Executor.class),
- vertex,
- 1,
- 1L,
- 1L,
- Time.milliseconds(500L)
- ));
- when(exec.getAttemptId()).thenReturn(attemptID);
- when(exec.getState()).thenReturn(state, successiveStates);
-
- when(vertex.getJobvertexId()).thenReturn(jobVertexID);
- when(vertex.getCurrentExecutionAttempt()).thenReturn(exec);
- when(vertex.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism);
- when(vertex.getMaxParallelism()).thenReturn(maxParallelism);
-
- ExecutionJobVertex jobVertex = mock(ExecutionJobVertex.class);
- when(jobVertex.getOperatorIDs()).thenReturn(jobVertexIDs);
-
- when(vertex.getJobVertex()).thenReturn(jobVertex);
-
- return vertex;
- }
-
- static TaskStateSnapshot mockSubtaskState(
- JobVertexID jobVertexID,
- int index,
- KeyGroupRange keyGroupRange) throws IOException {
-
- OperatorStateHandle partitionableState = generatePartitionableStateHandle(jobVertexID, index, 2, 8, false);
- KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID, keyGroupRange, false);
-
- TaskStateSnapshot subtaskStates = spy(new TaskStateSnapshot());
- OperatorSubtaskState subtaskState = spy(new OperatorSubtaskState(
- partitionableState, null, partitionedKeyGroupState, null)
- );
-
- subtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID), subtaskState);
-
- return subtaskStates;
- }
-
- public static void verifyStateRestore(
- JobVertexID jobVertexID, ExecutionJobVertex executionJobVertex,
- List<KeyGroupRange> keyGroupPartitions) throws Exception {
-
- for (int i = 0; i < executionJobVertex.getParallelism(); i++) {
-
- JobManagerTaskRestore taskRestore = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskRestore();
- Assert.assertEquals(1L, taskRestore.getRestoreCheckpointId());
- TaskStateSnapshot stateSnapshot = taskRestore.getTaskStateSnapshot();
-
- OperatorSubtaskState operatorState = stateSnapshot.getSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID));
-
- ChainedStateHandle<OperatorStateHandle> expectedOpStateBackend =
- generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8, false);
-
- assertTrue(CommonTestUtils.isStreamContentEqual(
- expectedOpStateBackend.get(0).openInputStream(),
- operatorState.getManagedOperatorState().iterator().next().openInputStream()));
-
- KeyGroupsStateHandle expectPartitionedKeyGroupState = generateKeyGroupState(
- jobVertexID, keyGroupPartitions.get(i), false);
- compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), operatorState.getManagedKeyedState());
- }
- }
-
- public static void compareKeyedState(
- Collection<KeyGroupsStateHandle> expectPartitionedKeyGroupState,
- Collection<? extends KeyedStateHandle> actualPartitionedKeyGroupState) throws Exception {
-
- KeyGroupsStateHandle expectedHeadOpKeyGroupStateHandle = expectPartitionedKeyGroupState.iterator().next();
- int expectedTotalKeyGroups = expectedHeadOpKeyGroupStateHandle.getKeyGroupRange().getNumberOfKeyGroups();
- int actualTotalKeyGroups = 0;
- for(KeyedStateHandle keyedStateHandle: actualPartitionedKeyGroupState) {
- assertTrue(keyedStateHandle instanceof KeyGroupsStateHandle);
-
- actualTotalKeyGroups += keyedStateHandle.getKeyGroupRange().getNumberOfKeyGroups();
- }
-
- assertEquals(expectedTotalKeyGroups, actualTotalKeyGroups);
-
- try (FSDataInputStream inputStream = expectedHeadOpKeyGroupStateHandle.openInputStream()) {
- for (int groupId : expectedHeadOpKeyGroupStateHandle.getKeyGroupRange()) {
- long offset = expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
- inputStream.seek(offset);
- int expectedKeyGroupState =
- InstantiationUtil.deserializeObject(inputStream, Thread.currentThread().getContextClassLoader());
- for (KeyedStateHandle oneActualKeyedStateHandle : actualPartitionedKeyGroupState) {
-
- assertTrue(oneActualKeyedStateHandle instanceof KeyGroupsStateHandle);
-
- KeyGroupsStateHandle oneActualKeyGroupStateHandle = (KeyGroupsStateHandle) oneActualKeyedStateHandle;
- if (oneActualKeyGroupStateHandle.getKeyGroupRange().contains(groupId)) {
- long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
- try (FSDataInputStream actualInputStream = oneActualKeyGroupStateHandle.openInputStream()) {
- actualInputStream.seek(actualOffset);
- int actualGroupState = InstantiationUtil.
- deserializeObject(actualInputStream, Thread.currentThread().getContextClassLoader());
- assertEquals(expectedKeyGroupState, actualGroupState);
- }
- }
- }
- }
- }
- }
-
- public static void comparePartitionableState(
- List<ChainedStateHandle<OperatorStateHandle>> expected,
- List<List<Collection<OperatorStateHandle>>> actual) throws Exception {
-
- List<String> expectedResult = new ArrayList<>();
- for (ChainedStateHandle<OperatorStateHandle> chainedStateHandle : expected) {
- for (int i = 0; i < chainedStateHandle.getLength(); ++i) {
- OperatorStateHandle operatorStateHandle = chainedStateHandle.get(i);
- collectResult(i, operatorStateHandle, expectedResult);
- }
- }
- Collections.sort(expectedResult);
-
- List<String> actualResult = new ArrayList<>();
- for (List<Collection<OperatorStateHandle>> collectionList : actual) {
- if (collectionList != null) {
- for (int i = 0; i < collectionList.size(); ++i) {
- Collection<OperatorStateHandle> stateHandles = collectionList.get(i);
- Assert.assertNotNull(stateHandles);
- for (OperatorStateHandle operatorStateHandle : stateHandles) {
- collectResult(i, operatorStateHandle, actualResult);
- }
- }
- }
- }
-
- Collections.sort(actualResult);
- Assert.assertEquals(expectedResult, actualResult);
- }
-
- private static void collectResult(int opIdx, OperatorStateHandle operatorStateHandle, List<String> resultCollector) throws Exception {
- try (FSDataInputStream in = operatorStateHandle.openInputStream()) {
- for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
- for (long offset : entry.getValue().getOffsets()) {
- in.seek(offset);
- Integer state = InstantiationUtil.
- deserializeObject(in, Thread.currentThread().getContextClassLoader());
- resultCollector.add(opIdx + " : " + entry.getKey() + " : " + state);
- }
- }
- }
- }
-
-
@Test
public void testCreateKeyGroupPartitions() {
testCreateKeyGroupPartitions(1, 1);
@@ -4071,36 +3694,4 @@ public class CheckpointCoordinatorTest extends TestLogger {
coord.receiveAcknowledgeMessage(acknowledgeCheckpoint, TASK_MANAGER_LOCATION_INFO);
}
}
-
- private Execution mockExecution() {
- Execution mock = mock(Execution.class);
- when(mock.getAttemptId()).thenReturn(new ExecutionAttemptID());
- when(mock.getState()).thenReturn(ExecutionState.RUNNING);
- return mock;
- }
-
- private ExecutionVertex mockExecutionVertex(Execution execution, JobVertexID vertexId, int subtask, int parallelism) {
- ExecutionVertex mock = mock(ExecutionVertex.class);
- when(mock.getJobvertexId()).thenReturn(vertexId);
- when(mock.getParallelSubtaskIndex()).thenReturn(subtask);
- when(mock.getCurrentExecutionAttempt()).thenReturn(execution);
- when(mock.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism);
- when(mock.getMaxParallelism()).thenReturn(parallelism);
- return mock;
- }
-
- private ExecutionJobVertex mockExecutionJobVertex(JobVertexID id, ExecutionVertex[] vertices) {
- ExecutionJobVertex vertex = mock(ExecutionJobVertex.class);
- when(vertex.getParallelism()).thenReturn(vertices.length);
- when(vertex.getMaxParallelism()).thenReturn(vertices.length);
- when(vertex.getJobVertexId()).thenReturn(id);
- when(vertex.getTaskVertices()).thenReturn(vertices);
- when(vertex.getOperatorIDs()).thenReturn(Collections.singletonList(OperatorID.fromJobVertexID(id)));
- when(vertex.getUserDefinedOperatorIDs()).thenReturn(Collections.<OperatorID>singletonList(null));
-
- for (ExecutionVertex v : vertices) {
- when(v.getJobVertex()).thenReturn(vertex);
- }
- return vertex;
- }
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java
new file mode 100644
index 0000000..9578ac7
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java
@@ -0,0 +1,563 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.api.common.time.Time;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.runtime.execution.ExecutionState;
+import org.apache.flink.runtime.executiongraph.Execution;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.OperatorStreamStateHandle;
+import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.runtime.testutils.CommonTestUtils;
+import org.apache.flink.util.InstantiationUtil;
+import org.apache.flink.util.Preconditions;
+
+import org.junit.Assert;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.UUID;
+import java.util.concurrent.Executor;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
+/**
+ * Testing utils for checkpoint coordinator.
+ *
+ * <p>Contains generators for seeded pseudo-random operator and keyed test state, helpers that compare
+ * generated state against restored state, and Mockito mocks of {@link ExecutionJobVertex},
+ * {@link ExecutionVertex} and {@link Execution}.
+ */
+public class CheckpointCoordinatorTestingUtils {
+
+	private CheckpointCoordinatorTestingUtils() {
+		// static utility class, not meant to be instantiated
+	}
+
+	/**
+	 * Generates an operator state handle with {@code namedStates} named states ("state-0",
+	 * "state-1", ...), each holding {@code partitionsPerState} pseudo-random {@code int} values
+	 * drawn from a {@link Random} seeded from {@code jobVertexID.hashCode()}, {@code index}, the
+	 * state index and {@code rawState}.
+	 */
+	public static OperatorStateHandle generatePartitionableStateHandle(
+			JobVertexID jobVertexID,
+			int index,
+			int namedStates,
+			int partitionsPerState,
+			boolean rawState) throws IOException {
+
+		return generatePartitionableStateHandle(
+			generateStatesListsMap(jobVertexID, index, namedStates, partitionsPerState, rawState));
+	}
+
+	/**
+	 * Same as {@link #generatePartitionableStateHandle(JobVertexID, int, int, int, boolean)}, but
+	 * wraps the generated handle into a single-element {@link ChainedStateHandle}.
+	 */
+	static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle(
+			JobVertexID jobVertexID,
+			int index,
+			int namedStates,
+			int partitionsPerState,
+			boolean rawState) throws IOException {
+
+		return ChainedStateHandle.wrapSingleHandle(
+			generatePartitionableStateHandle(jobVertexID, index, namedStates, partitionsPerState, rawState));
+	}
+
+	/**
+	 * Builds the seeded per-state value lists shared by the two partitionable state generators above.
+	 */
+	private static Map<String, List<? extends Serializable>> generateStatesListsMap(
+			JobVertexID jobVertexID,
+			int index,
+			int namedStates,
+			int partitionsPerState,
+			boolean rawState) {
+
+		Map<String, List<? extends Serializable>> statesListsMap = new HashMap<>(namedStates);
+
+		for (int i = 0; i < namedStates; ++i) {
+			int seed = jobVertexID.hashCode() * index + i * namedStates;
+			if (rawState) {
+				// derive a different seed for raw state so it does not simply repeat the non-raw values
+				seed = (seed + 1) * 31;
+			}
+			Random random = new Random(seed);
+
+			List<Integer> partitionValues = new ArrayList<>(partitionsPerState);
+			for (int j = 0; j < partitionsPerState; ++j) {
+				partitionValues.add(random.nextInt());
+			}
+			statesListsMap.put("state-" + i, partitionValues);
+		}
+
+		return statesListsMap;
+	}
+
+	/**
+	 * Serializes the given named state lists into a single {@link ByteStreamStateHandle} and returns
+	 * an {@link OperatorStateHandle} that maps every state name to the offsets of its entries in that
+	 * handle, using {@link OperatorStateHandle.Mode#SPLIT_DISTRIBUTE} for all states.
+	 */
+	static OperatorStateHandle generatePartitionableStateHandle(
+			Map<String, List<? extends Serializable>> states) throws IOException {
+
+		// collect names and values in one pass so that the i-th offsets array belongs to the i-th name
+		List<String> stateNames = new ArrayList<>(states.size());
+		List<List<? extends Serializable>> namedStateSerializables = new ArrayList<>(states.size());
+		for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) {
+			stateNames.add(entry.getKey());
+			namedStateSerializables.add(entry.getValue());
+		}
+
+		Tuple2<byte[], List<long[]>> serializationWithOffsets = serializeTogetherAndTrackOffsets(namedStateSerializables);
+
+		Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>(states.size());
+		for (int idx = 0; idx < stateNames.size(); ++idx) {
+			offsetsMap.put(
+				stateNames.get(idx),
+				new OperatorStateHandle.StateMetaInfo(
+					serializationWithOffsets.f1.get(idx),
+					OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
+		}
+
+		ByteStreamStateHandle streamStateHandle = new ByteStreamStateHandle(
+			String.valueOf(UUID.randomUUID()),
+			serializationWithOffsets.f0);
+
+		return new OperatorStreamStateHandle(offsetsMap, streamStateHandle);
+	}
+
+	/**
+	 * Java-serializes all elements of all given lists back to back into one byte array.
+	 *
+	 * @return the concatenated bytes ({@code f0}) and, for every input list, the start offset of
+	 *         each of its elements within those bytes ({@code f1})
+	 */
+	static Tuple2<byte[], List<long[]>> serializeTogetherAndTrackOffsets(
+			List<List<? extends Serializable>> serializables) throws IOException {
+
+		List<long[]> offsets = new ArrayList<>(serializables.size());
+		ByteArrayOutputStream allSerializedValues = new ByteArrayOutputStream();
+
+		for (List<? extends Serializable> list : serializables) {
+			long[] currentOffsets = new long[list.size()];
+			offsets.add(currentOffsets);
+
+			for (int i = 0; i < list.size(); ++i) {
+				currentOffsets[i] = allSerializedValues.size();
+				byte[] serializedValue = InstantiationUtil.serializeObject(list.get(i));
+				allSerializedValues.write(serializedValue, 0, serializedValue.length);
+			}
+		}
+
+		return new Tuple2<>(allSerializedValues.toByteArray(), offsets);
+	}
+
+	/**
+	 * Asserts that every subtask {@code i} of {@code executionJobVertex} received a restore of checkpoint
+	 * 1 whose first managed operator state handle has the same content as
+	 * {@code generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8, false)} and whose managed keyed
+	 * state matches {@code generateKeyGroupState(jobVertexID, keyGroupPartitions.get(i), false)}.
+	 */
+	public static void verifyStateRestore(
+			JobVertexID jobVertexID, ExecutionJobVertex executionJobVertex,
+			List<KeyGroupRange> keyGroupPartitions) throws Exception {
+
+		for (int i = 0; i < executionJobVertex.getParallelism(); i++) {
+
+			JobManagerTaskRestore taskRestore = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskRestore();
+			Assert.assertEquals(1L, taskRestore.getRestoreCheckpointId());
+			TaskStateSnapshot stateSnapshot = taskRestore.getTaskStateSnapshot();
+
+			OperatorSubtaskState operatorState = stateSnapshot.getSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID));
+
+			ChainedStateHandle<OperatorStateHandle> expectedOpStateBackend =
+				generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8, false);
+
+			// close both streams once their contents have been compared
+			try (FSDataInputStream expectedOpStateStream = expectedOpStateBackend.get(0).openInputStream();
+				FSDataInputStream actualOpStateStream =
+					operatorState.getManagedOperatorState().iterator().next().openInputStream()) {
+				assertTrue(CommonTestUtils.isStreamContentEqual(expectedOpStateStream, actualOpStateStream));
+			}
+
+			KeyGroupsStateHandle expectPartitionedKeyGroupState = generateKeyGroupState(
+				jobVertexID, keyGroupPartitions.get(i), false);
+			compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), operatorState.getManagedKeyedState());
+		}
+	}
+
+	/**
+	 * Asserts that the actual keyed state holds the same values as the first handle of
+	 * {@code expectPartitionedKeyGroupState}: all actual handles must be {@link KeyGroupsStateHandle}s,
+	 * together cover the same number of key groups, and contain every expected key group with an
+	 * equal deserialized {@code int} value.
+	 */
+	static void compareKeyedState(
+			Collection<KeyGroupsStateHandle> expectPartitionedKeyGroupState,
+			Collection<? extends KeyedStateHandle> actualPartitionedKeyGroupState) throws Exception {
+
+		Assert.assertFalse("No expected keyed state handle given.", expectPartitionedKeyGroupState.isEmpty());
+
+		KeyGroupsStateHandle expectedHeadOpKeyGroupStateHandle = expectPartitionedKeyGroupState.iterator().next();
+		int expectedTotalKeyGroups = expectedHeadOpKeyGroupStateHandle.getKeyGroupRange().getNumberOfKeyGroups();
+		int actualTotalKeyGroups = 0;
+		for (KeyedStateHandle keyedStateHandle : actualPartitionedKeyGroupState) {
+			assertTrue(keyedStateHandle instanceof KeyGroupsStateHandle);
+
+			actualTotalKeyGroups += keyedStateHandle.getKeyGroupRange().getNumberOfKeyGroups();
+		}
+
+		assertEquals(expectedTotalKeyGroups, actualTotalKeyGroups);
+
+		try (FSDataInputStream inputStream = expectedHeadOpKeyGroupStateHandle.openInputStream()) {
+			for (int groupId : expectedHeadOpKeyGroupStateHandle.getKeyGroupRange()) {
+				long offset = expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
+				inputStream.seek(offset);
+				int expectedKeyGroupState =
+					InstantiationUtil.deserializeObject(inputStream, Thread.currentThread().getContextClassLoader());
+
+				// without this flag a key group missing from every actual handle would pass unnoticed
+				boolean foundInActualState = false;
+				for (KeyedStateHandle oneActualKeyedStateHandle : actualPartitionedKeyGroupState) {
+
+					assertTrue(oneActualKeyedStateHandle instanceof KeyGroupsStateHandle);
+
+					KeyGroupsStateHandle oneActualKeyGroupStateHandle = (KeyGroupsStateHandle) oneActualKeyedStateHandle;
+					if (oneActualKeyGroupStateHandle.getKeyGroupRange().contains(groupId)) {
+						foundInActualState = true;
+						long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
+						try (FSDataInputStream actualInputStream = oneActualKeyGroupStateHandle.openInputStream()) {
+							actualInputStream.seek(actualOffset);
+							int actualGroupState =
+								InstantiationUtil.deserializeObject(actualInputStream, Thread.currentThread().getContextClassLoader());
+							assertEquals(expectedKeyGroupState, actualGroupState);
+						}
+					}
+				}
+				assertTrue("Key group " + groupId + " is missing from the actual keyed state.", foundInActualState);
+			}
+		}
+	}
+
+	/**
+	 * Asserts that {@code expected} and {@code actual} contain the same multiset of partition values,
+	 * each tagged with its chain position (operator index) and state name. {@code null} outer entries
+	 * of {@code actual} are skipped; the inner collections must be non-null.
+	 */
+	static void comparePartitionableState(
+			List<ChainedStateHandle<OperatorStateHandle>> expected,
+			List<List<Collection<OperatorStateHandle>>> actual) throws Exception {
+
+		List<String> expectedResult = new ArrayList<>();
+		for (ChainedStateHandle<OperatorStateHandle> chainedStateHandle : expected) {
+			for (int i = 0; i < chainedStateHandle.getLength(); ++i) {
+				OperatorStateHandle operatorStateHandle = chainedStateHandle.get(i);
+				collectResult(i, operatorStateHandle, expectedResult);
+			}
+		}
+		Collections.sort(expectedResult);
+
+		List<String> actualResult = new ArrayList<>();
+		for (List<Collection<OperatorStateHandle>> collectionList : actual) {
+			if (collectionList != null) {
+				for (int i = 0; i < collectionList.size(); ++i) {
+					Collection<OperatorStateHandle> stateHandles = collectionList.get(i);
+					Assert.assertNotNull(stateHandles);
+					for (OperatorStateHandle operatorStateHandle : stateHandles) {
+						collectResult(i, operatorStateHandle, actualResult);
+					}
+				}
+			}
+		}
+
+		Collections.sort(actualResult);
+		Assert.assertEquals(expectedResult, actualResult);
+	}
+
+	/**
+	 * Reads every partition value of {@code operatorStateHandle} (deserialized as an {@code Integer}) and
+	 * adds it to {@code resultCollector} formatted as {@code "<opIdx> : <stateName> : <value>"}.
+	 */
+	static void collectResult(int opIdx, OperatorStateHandle operatorStateHandle, List<String> resultCollector) throws Exception {
+		try (FSDataInputStream in = operatorStateHandle.openInputStream()) {
+			for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
+				for (long offset : entry.getValue().getOffsets()) {
+					in.seek(offset);
+					Integer state =
+						InstantiationUtil.deserializeObject(in, Thread.currentThread().getContextClassLoader());
+					resultCollector.add(opIdx + " : " + entry.getKey() + " : " + state);
+				}
+			}
+		}
+	}
+
+	/**
+	 * Mocks an {@link ExecutionJobVertex} with a single operator whose ID is derived from {@code jobVertexID}.
+	 */
+	static ExecutionJobVertex mockExecutionJobVertex(
+			JobVertexID jobVertexID,
+			int parallelism,
+			int maxParallelism) {
+
+		return mockExecutionJobVertex(
+			jobVertexID,
+			Collections.singletonList(OperatorID.fromJobVertexID(jobVertexID)),
+			parallelism,
+			maxParallelism);
+	}
+
+	/**
+	 * Mocks an {@link ExecutionJobVertex} with the given operator IDs, no user-defined operator IDs,
+	 * and {@code parallelism} mocked task vertices in state RUNNING, the i-th of which reports
+	 * subtask index i.
+	 */
+	static ExecutionJobVertex mockExecutionJobVertex(
+			JobVertexID jobVertexID,
+			List<OperatorID> operatorIDs,
+			int parallelism,
+			int maxParallelism) {
+		final ExecutionJobVertex executionJobVertex = mock(ExecutionJobVertex.class);
+
+		ExecutionVertex[] executionVertices = new ExecutionVertex[parallelism];
+
+		for (int i = 0; i < parallelism; i++) {
+			executionVertices[i] = mockExecutionVertex(
+				new ExecutionAttemptID(),
+				jobVertexID,
+				operatorIDs,
+				parallelism,
+				maxParallelism,
+				ExecutionState.RUNNING);
+
+			when(executionVertices[i].getParallelSubtaskIndex()).thenReturn(i);
+		}
+
+		when(executionJobVertex.getJobVertexId()).thenReturn(jobVertexID);
+		when(executionJobVertex.getTaskVertices()).thenReturn(executionVertices);
+		when(executionJobVertex.getParallelism()).thenReturn(parallelism);
+		when(executionJobVertex.getMaxParallelism()).thenReturn(maxParallelism);
+		when(executionJobVertex.isMaxParallelismConfigured()).thenReturn(true);
+		when(executionJobVertex.getOperatorIDs()).thenReturn(operatorIDs);
+		// one null entry per operator, i.e. no operator has a user-defined ID
+		when(executionJobVertex.getUserDefinedOperatorIDs()).thenReturn(Arrays.asList(new OperatorID[operatorIDs.size()]));
+
+		return executionJobVertex;
+	}
+
+	/**
+	 * Mocks a RUNNING {@link ExecutionVertex} of a fresh job vertex with parallelism and max parallelism 1.
+	 */
+	static ExecutionVertex mockExecutionVertex(ExecutionAttemptID attemptID) {
+		JobVertexID jobVertexID = new JobVertexID();
+		return mockExecutionVertex(
+			attemptID,
+			jobVertexID,
+			Collections.singletonList(OperatorID.fromJobVertexID(jobVertexID)),
+			1,
+			1,
+			ExecutionState.RUNNING);
+	}
+
+	/**
+	 * Mocks an {@link ExecutionVertex} whose current execution attempt is a Mockito spy of a real
+	 * {@link Execution}. The spy reports {@code attemptID}; its {@code getState()} returns {@code state}
+	 * first and then {@code successiveStates} on subsequent calls. {@code getJobVertex()} returns a
+	 * separate mocked {@link ExecutionJobVertex} whose only stubbed method is {@code getOperatorIDs()}.
+	 */
+	static ExecutionVertex mockExecutionVertex(
+			ExecutionAttemptID attemptID,
+			JobVertexID jobVertexID,
+			List<OperatorID> operatorIDs,
+			int parallelism,
+			int maxParallelism,
+			ExecutionState state,
+			ExecutionState... successiveStates) {
+
+		ExecutionVertex vertex = mock(ExecutionVertex.class);
+
+		final Execution exec = spy(new Execution(
+			mock(Executor.class),
+			vertex,
+			1,
+			1L,
+			1L,
+			Time.milliseconds(500L)
+		));
+		when(exec.getAttemptId()).thenReturn(attemptID);
+		when(exec.getState()).thenReturn(state, successiveStates);
+
+		when(vertex.getJobvertexId()).thenReturn(jobVertexID);
+		when(vertex.getCurrentExecutionAttempt()).thenReturn(exec);
+		when(vertex.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism);
+		when(vertex.getMaxParallelism()).thenReturn(maxParallelism);
+
+		ExecutionJobVertex jobVertex = mock(ExecutionJobVertex.class);
+		when(jobVertex.getOperatorIDs()).thenReturn(operatorIDs);
+
+		when(vertex.getJobVertex()).thenReturn(jobVertex);
+
+		return vertex;
+	}
+
+	/**
+	 * Creates a spied {@link TaskStateSnapshot} holding one spied {@link OperatorSubtaskState} for the
+	 * operator derived from {@code jobVertexID}, with generated operator state (2 named states, 8 partitions
+	 * each) and generated keyed state for {@code keyGroupRange}, both non-raw.
+	 */
+	static TaskStateSnapshot mockSubtaskState(
+			JobVertexID jobVertexID,
+			int index,
+			KeyGroupRange keyGroupRange) throws IOException {
+
+		OperatorStateHandle partitionableState = generatePartitionableStateHandle(jobVertexID, index, 2, 8, false);
+		KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID, keyGroupRange, false);
+
+		TaskStateSnapshot subtaskStates = spy(new TaskStateSnapshot());
+		OperatorSubtaskState subtaskState = spy(new OperatorSubtaskState(
+			partitionableState, null, partitionedKeyGroupState, null)
+		);
+
+		subtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID), subtaskState);
+
+		return subtaskStates;
+	}
+
+	/**
+	 * Generates a {@link KeyGroupsStateHandle} holding one pseudo-random {@code int} per key group in
+	 * {@code keyGroupPartition}, drawn from a {@link Random} seeded from {@code jobVertexID.hashCode()},
+	 * the key group index and {@code rawState}.
+	 */
+	public static KeyGroupsStateHandle generateKeyGroupState(
+			JobVertexID jobVertexID,
+			KeyGroupRange keyGroupPartition,
+			boolean rawState) throws IOException {
+
+		List<Integer> testStatesLists = new ArrayList<>(keyGroupPartition.getNumberOfKeyGroups());
+		int vertexHash = jobVertexID.hashCode();
+
+		// generate one state value per key group
+		for (int keyGroupIndex : keyGroupPartition) {
+			int seed = rawState ? (vertexHash * (31 + keyGroupIndex)) : (vertexHash + keyGroupIndex);
+			testStatesLists.add(new Random(seed).nextInt());
+		}
+
+		return generateKeyGroupState(keyGroupPartition, testStatesLists);
+	}
+
+	/**
+	 * Serializes {@code states} (one entry per key group, in key group order) into a single
+	 * {@link ByteStreamStateHandle} and wraps it into a {@link KeyGroupsStateHandle} for {@code keyGroupRange}.
+	 *
+	 * @throws IllegalArgumentException if the number of states differs from the number of key groups
+	 */
+	public static KeyGroupsStateHandle generateKeyGroupState(
+			KeyGroupRange keyGroupRange,
+			List<? extends Serializable> states) throws IOException {
+
+		Preconditions.checkArgument(
+			keyGroupRange.getNumberOfKeyGroups() == states.size(),
+			"Expected one state per key group (%s key groups), but got %s states.",
+			keyGroupRange.getNumberOfKeyGroups(),
+			states.size());
+
+		Tuple2<byte[], List<long[]>> serializedDataWithOffsets =
+			serializeTogetherAndTrackOffsets(Collections.<List<? extends Serializable>>singletonList(states));
+
+		KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, serializedDataWithOffsets.f1.get(0));
+
+		ByteStreamStateHandle allSerializedStatesHandle = new ByteStreamStateHandle(
+			String.valueOf(UUID.randomUUID()),
+			serializedDataWithOffsets.f0);
+
+		return new KeyGroupsStateHandle(keyGroupRangeOffsets, allSerializedStatesHandle);
+	}
+
+	/**
+	 * Mocks an {@link Execution} with a fresh {@link ExecutionAttemptID} in state RUNNING.
+	 */
+	static Execution mockExecution() {
+		Execution execution = mock(Execution.class);
+		when(execution.getAttemptId()).thenReturn(new ExecutionAttemptID());
+		when(execution.getState()).thenReturn(ExecutionState.RUNNING);
+		return execution;
+	}
+
+	/**
+	 * Mocks an {@link ExecutionVertex} for subtask {@code subtask} of {@code vertexId} whose current attempt
+	 * is {@code execution}; both its parallelism and max parallelism equal {@code parallelism}.
+	 */
+	static ExecutionVertex mockExecutionVertex(Execution execution, JobVertexID vertexId, int subtask, int parallelism) {
+		ExecutionVertex executionVertex = mock(ExecutionVertex.class);
+		when(executionVertex.getJobvertexId()).thenReturn(vertexId);
+		when(executionVertex.getParallelSubtaskIndex()).thenReturn(subtask);
+		when(executionVertex.getCurrentExecutionAttempt()).thenReturn(execution);
+		when(executionVertex.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism);
+		when(executionVertex.getMaxParallelism()).thenReturn(parallelism);
+		return executionVertex;
+	}
+
+	/**
+	 * Mocks an {@link ExecutionJobVertex} owning the given (mocked) vertices, with parallelism and max
+	 * parallelism equal to their count, and stubs {@code getJobVertex()} on every vertex to return it.
+	 */
+	static ExecutionJobVertex mockExecutionJobVertex(JobVertexID id, ExecutionVertex[] vertices) {
+		ExecutionJobVertex jobVertex = mock(ExecutionJobVertex.class);
+		when(jobVertex.getParallelism()).thenReturn(vertices.length);
+		when(jobVertex.getMaxParallelism()).thenReturn(vertices.length);
+		when(jobVertex.getJobVertexId()).thenReturn(id);
+		when(jobVertex.getTaskVertices()).thenReturn(vertices);
+		when(jobVertex.getOperatorIDs()).thenReturn(Collections.singletonList(OperatorID.fromJobVertexID(id)));
+		// a single operator without a user-defined ID
+		when(jobVertex.getUserDefinedOperatorIDs()).thenReturn(Collections.<OperatorID>singletonList(null));
+
+		for (ExecutionVertex vertex : vertices) {
+			when(vertex.getJobVertex()).thenReturn(jobVertex);
+		}
+		return jobVertex;
+	}
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index af2e5a3..42849b5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -79,9 +79,9 @@ public class CheckpointStateRestoreTest {
public void testSetState() {
try {
- KeyGroupRange keyGroupRange = KeyGroupRange.of(0,0);
+ KeyGroupRange keyGroupRange = KeyGroupRange.of(0, 0);
List<SerializableObject> testStates = Collections.singletonList(new SerializableObject());
- final KeyedStateHandle serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
+ final KeyedStateHandle serializedKeyGroupStates = CheckpointCoordinatorTestingUtils.generateKeyGroupState(keyGroupRange, testStates);
final JobID jid = new JobID();
final JobVertexID statefulId = new JobVertexID();
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
index f18b4c8..0fb46e0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
@@ -21,7 +21,7 @@ package org.apache.flink.runtime.messages;
import org.apache.flink.api.common.JobID;
import org.apache.flink.core.fs.FSDataInputStream;
import org.apache.flink.core.testutils.CommonTestUtils;
-import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest;
+import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils;
import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
@@ -45,6 +45,9 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;
+/**
+ * Tests for checkpoint messages.
+ */
public class CheckpointMessagesTest {
@Test
@@ -69,15 +72,15 @@ public class CheckpointMessagesTest {
AcknowledgeCheckpoint noState = new AcknowledgeCheckpoint(
new JobID(), new ExecutionAttemptID(), 569345L);
- KeyGroupRange keyGroupRange = KeyGroupRange.of(42,42);
+ KeyGroupRange keyGroupRange = KeyGroupRange.of(42, 42);
TaskStateSnapshot checkpointStateHandles = new TaskStateSnapshot();
checkpointStateHandles.putSubtaskStateByOperatorID(
new OperatorID(),
new OperatorSubtaskState(
- CheckpointCoordinatorTest.generatePartitionableStateHandle(new JobVertexID(), 0, 2, 8, false),
+ CheckpointCoordinatorTestingUtils.generatePartitionableStateHandle(new JobVertexID(), 0, 2, 8, false),
null,
- CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())),
+ CheckpointCoordinatorTestingUtils.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())),
null
)
);
@@ -105,7 +108,7 @@ public class CheckpointMessagesTest {
assertNotNull(copy.toString());
}
- public static class MyHandle implements StreamStateHandle {
+ private static class MyHandle implements StreamStateHandle {
private static final long serialVersionUID = 8128146204128728332L;