Posted to commits@flink.apache.org by se...@apache.org on 2020/05/16 16:27:52 UTC
[flink] 03/13: [FLINK-17670][refactor] Refactor single test in CheckpointMetadataLoadingTest into finer grained tests.
This is an automated email from the ASF dual-hosted git repository.
sewen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 90e4708c9a5ab24b5f50423a181d75be375e2d1e
Author: Stephan Ewen <se...@apache.org>
AuthorDate: Wed May 13 16:00:29 2020 +0200
[FLINK-17670][refactor] Refactor single test in CheckpointMetadataLoadingTest into finer grained tests.
---
.../checkpoint/CheckpointMetadataLoadingTest.java | 161 ++++++++++++++-------
1 file changed, 106 insertions(+), 55 deletions(-)
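
At a glance, the refactoring replaces one long test method that walked through four scenarios in sequence with four independent @Test methods backed by shared setup helpers. A minimal, self-contained JUnit 4 sketch of that pattern (all names in this sketch are hypothetical, not taken from the diff):

    import static org.junit.Assert.assertEquals;
    import static org.junit.Assert.assertTrue;
    import static org.junit.Assert.fail;

    import org.junit.Test;

    public class FinerGrainedStyleExample {

        /** The happy path gets its own test instead of being step 1 of a long method. */
        @Test
        public void testValidValueAccepted() {
            assertEquals(4, requirePositive(4));
        }

        /** Each failure mode gets its own test with a targeted message assertion. */
        @Test
        public void testNegativeValueRejected() {
            try {
                requirePositive(-1);
                fail("Did not throw expected Exception");
            } catch (IllegalStateException expected) {
                assertTrue(expected.getMessage().contains("must be positive"));
            }
        }

        // shared helper, analogous to the createTasks(...) and
        // createSavepointWithOperatorSubtaskState(...) setup utils in the diff below
        private static int requirePositive(int value) {
            if (value <= 0) {
                throw new IllegalStateException("value must be positive");
            }
            return value;
        }
    }

The actual diff follows.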
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointMetadataLoadingTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointMetadataLoadingTest.java
index 557e8ba..936fe23 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointMetadataLoadingTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointMetadataLoadingTest.java
@@ -33,6 +33,7 @@ import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLo
import org.junit.Test;
import java.io.ByteArrayOutputStream;
+import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@@ -42,7 +43,6 @@ import static org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNew
import static org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNewResultSubpartitionStateHandle;
import static org.apache.flink.runtime.checkpoint.StateObjectCollection.singleton;
import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
@@ -54,37 +54,90 @@ import static org.mockito.Mockito.when;
*/
public class CheckpointMetadataLoadingTest {
+ private final ClassLoader cl = getClass().getClassLoader();
+
/**
- * Tests loading and validation of savepoints with correct setup,
- * parallelism mismatch, and a missing task.
+ * Tests correct savepoint loading.
*/
@Test
- public void testLoadAndValidateSavepoint() throws Exception {
- final Random rnd = new Random();
+ public void testAllStateRestored() throws Exception {
+ final JobID jobId = new JobID();
+ final OperatorID operatorId = new OperatorID();
+ final long checkpointId = Integer.MAX_VALUE + 123123L;
+ final int parallelism = 128128;
- int parallelism = 128128;
- long checkpointId = Integer.MAX_VALUE + 123123L;
- JobVertexID jobVertexID = new JobVertexID();
- OperatorID operatorID = OperatorID.fromJobVertexID(jobVertexID);
+ final CompletedCheckpointStorageLocation testSavepoint = createSavepointWithOperatorSubtaskState(checkpointId, operatorId, parallelism);
+ final Map<JobVertexID, ExecutionJobVertex> tasks = createTasks(operatorId, parallelism, parallelism);
- OperatorSubtaskState subtaskState = new OperatorSubtaskState(
- new OperatorStreamStateHandle(Collections.emptyMap(), new ByteStreamStateHandle("testHandler", new byte[0])),
- null,
- null,
- null,
- singleton(createNewInputChannelStateHandle(10, rnd)),
- singleton(createNewResultSubpartitionStateHandle(10, rnd)));
+ final CompletedCheckpoint loaded = Checkpoints.loadAndValidateCheckpoint(jobId, tasks, testSavepoint, cl, false);
- OperatorState state = new OperatorState(operatorID, parallelism, parallelism);
- state.putState(0, subtaskState);
+ assertEquals(jobId, loaded.getJobId());
+ assertEquals(checkpointId, loaded.getCheckpointID());
+ }
+
+ /**
+ * Tests that savepoint loading fails when there is a max-parallelism mismatch.
+ */
+ @Test
+ public void testMaxParallelismMismatch() throws Exception {
+ final OperatorID operatorId = new OperatorID();
+ final int parallelism = 128128;
- Map<OperatorID, OperatorState> taskStates = new HashMap<>();
- taskStates.put(operatorID, state);
+ final CompletedCheckpointStorageLocation testSavepoint = createSavepointWithOperatorSubtaskState(242L, operatorId, parallelism);
+ final Map<JobVertexID, ExecutionJobVertex> tasks = createTasks(operatorId, parallelism, parallelism + 1);
- JobID jobId = new JobID();
+ try {
+ Checkpoints.loadAndValidateCheckpoint(new JobID(), tasks, testSavepoint, cl, false);
+ fail("Did not throw expected Exception");
+ } catch (IllegalStateException expected) {
+ assertTrue(expected.getMessage().contains("Max parallelism mismatch"));
+ }
+ }
- // Store savepoint
- final CheckpointMetadata savepoint = new CheckpointMetadata(checkpointId, taskStates.values(), Collections.emptyList());
+ /**
+ * Tests that savepoint loading fails when there is non-restored state, but it is not allowed.
+ */
+ @Test
+ public void testNonRestoredStateWhenDisallowed() throws Exception {
+ final OperatorID operatorId = new OperatorID();
+ final int parallelism = 9;
+
+ final CompletedCheckpointStorageLocation testSavepoint = createSavepointWithOperatorSubtaskState(242L, operatorId, parallelism);
+ final Map<JobVertexID, ExecutionJobVertex> tasks = Collections.emptyMap();
+
+ try {
+ Checkpoints.loadAndValidateCheckpoint(new JobID(), tasks, testSavepoint, cl, false);
+ fail("Did not throw expected Exception");
+ } catch (IllegalStateException expected) {
+ assertTrue(expected.getMessage().contains("allowNonRestoredState"));
+ }
+ }
+
+ /**
+ * Tests that savepoint loading succeeds when there is non-restored state and it is allowed.
+ */
+ @Test
+ public void testNonRestoredStateWhenAllowed() throws Exception {
+ final OperatorID operatorId = new OperatorID();
+ final int parallelism = 9;
+
+ final CompletedCheckpointStorageLocation testSavepoint = createSavepointWithOperatorSubtaskState(242L, operatorId, parallelism);
+ final Map<JobVertexID, ExecutionJobVertex> tasks = Collections.emptyMap();
+
+ final CompletedCheckpoint loaded = Checkpoints.loadAndValidateCheckpoint(new JobID(), tasks, testSavepoint, cl, true);
+
+ assertTrue(loaded.getOperatorStates().isEmpty());
+ }
+
+ // ------------------------------------------------------------------------
+ // setup utils
+ // ------------------------------------------------------------------------
+
+ private static CompletedCheckpointStorageLocation createSavepointWithOperatorState(
+ final long checkpointId,
+ final OperatorState state) throws IOException {
+
+ final CheckpointMetadata savepoint = new CheckpointMetadata(checkpointId, Collections.singletonList(state), Collections.emptyList());
final StreamStateHandle serializedMetadata;
try (ByteArrayOutputStream os = new ByteArrayOutputStream()) {
@@ -92,47 +145,45 @@ public class CheckpointMetadataLoadingTest {
serializedMetadata = new ByteStreamStateHandle("checkpoint", os.toByteArray());
}
- final CompletedCheckpointStorageLocation storageLocation = new TestCompletedCheckpointStorageLocation(
- serializedMetadata, "dummy/pointer");
-
- ExecutionJobVertex vertex = mock(ExecutionJobVertex.class);
- when(vertex.getParallelism()).thenReturn(parallelism);
- when(vertex.getMaxParallelism()).thenReturn(parallelism);
- when(vertex.getOperatorIDs()).thenReturn(Collections.singletonList(OperatorIDPair.generatedIDOnly(operatorID)));
+ return new TestCompletedCheckpointStorageLocation(serializedMetadata, "dummy/pointer");
+ }
- Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
- tasks.put(jobVertexID, vertex);
+ private static CompletedCheckpointStorageLocation createSavepointWithOperatorSubtaskState(
+ final long checkpointId,
+ final OperatorID operatorId,
+ final int parallelism) throws IOException {
- ClassLoader ucl = Thread.currentThread().getContextClassLoader();
+ final Random rnd = new Random();
- // 1) Load and validate: everything correct
- CompletedCheckpoint loaded = Checkpoints.loadAndValidateCheckpoint(jobId, tasks, storageLocation, ucl, false);
+ final OperatorSubtaskState subtaskState = new OperatorSubtaskState(
+ new OperatorStreamStateHandle(Collections.emptyMap(), new ByteStreamStateHandle("testHandler", new byte[0])),
+ null,
+ null,
+ null,
+ singleton(createNewInputChannelStateHandle(10, rnd)),
+ singleton(createNewResultSubpartitionStateHandle(10, rnd)));
- assertEquals(jobId, loaded.getJobId());
- assertEquals(checkpointId, loaded.getCheckpointID());
+ final OperatorState state = new OperatorState(operatorId, parallelism, parallelism);
+ state.putState(0, subtaskState);
- // 2) Load and validate: max parallelism mismatch
- when(vertex.getMaxParallelism()).thenReturn(222);
- when(vertex.isMaxParallelismConfigured()).thenReturn(true);
+ return createSavepointWithOperatorState(checkpointId, state);
+ }
- try {
- Checkpoints.loadAndValidateCheckpoint(jobId, tasks, storageLocation, ucl, false);
- fail("Did not throw expected Exception");
- } catch (IllegalStateException expected) {
- assertTrue(expected.getMessage().contains("Max parallelism mismatch"));
- }
+ private static Map<JobVertexID, ExecutionJobVertex> createTasks(OperatorID operatorId, int parallelism, int maxParallelism) {
+ final JobVertexID vertexId = new JobVertexID(operatorId.getLowerPart(), operatorId.getUpperPart());
- // 3) Load and validate: missing vertex
- assertNotNull(tasks.remove(jobVertexID));
+ ExecutionJobVertex vertex = mock(ExecutionJobVertex.class);
+ when(vertex.getParallelism()).thenReturn(parallelism);
+ when(vertex.getMaxParallelism()).thenReturn(maxParallelism);
+ when(vertex.getOperatorIDs()).thenReturn(Collections.singletonList(OperatorIDPair.generatedIDOnly(operatorId)));
- try {
- Checkpoints.loadAndValidateCheckpoint(jobId, tasks, storageLocation, ucl, false);
- fail("Did not throw expected Exception");
- } catch (IllegalStateException expected) {
- assertTrue(expected.getMessage().contains("allowNonRestoredState"));
+ if (parallelism != maxParallelism) {
+ when(vertex.isMaxParallelismConfigured()).thenReturn(true);
}
- // 4) Load and validate: ignore missing vertex
- Checkpoints.loadAndValidateCheckpoint(jobId, tasks, storageLocation, ucl, true);
+ Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
+ tasks.put(vertexId, vertex);
+
+ return tasks;
}
}
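
One payoff of the extracted setup utils is that further scenarios now cost only a few lines each. As an illustration only, a hypothetical extra test (not part of this commit) that could live inside CheckpointMetadataLoadingTest and reuse its helpers:

    @Test
    public void testOperatorStateCarriedThrough() throws Exception {
        final OperatorID operatorId = new OperatorID();
        final int parallelism = 4;

        final CompletedCheckpointStorageLocation testSavepoint =
                createSavepointWithOperatorSubtaskState(17L, operatorId, parallelism);
        final Map<JobVertexID, ExecutionJobVertex> tasks = createTasks(operatorId, parallelism, parallelism);

        final CompletedCheckpoint loaded =
                Checkpoints.loadAndValidateCheckpoint(new JobID(), tasks, testSavepoint, cl, false);

        // the helper wrote exactly one operator's state, so exactly one must come back
        assertEquals(1, loaded.getOperatorStates().size());
    }

Because the savepoint serialization and the mocked ExecutionJobVertex construction live in static helpers, each test states only the parameters that matter to its scenario.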