You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by se...@apache.org on 2020/05/16 16:27:52 UTC

[flink] 03/13: [FLINK-17670][refactor] Refactor single test in CheckpointMetadataLoadingTest into finer grained tests.

This is an automated email from the ASF dual-hosted git repository.

sewen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 90e4708c9a5ab24b5f50423a181d75be375e2d1e
Author: Stephan Ewen <se...@apache.org>
AuthorDate: Wed May 13 16:00:29 2020 +0200

    [FLINK-17670][refactor] Refactor single test in CheckpointMetadataLoadingTest into finer grained tests.
---
 .../checkpoint/CheckpointMetadataLoadingTest.java  | 161 ++++++++++++++-------
 1 file changed, 106 insertions(+), 55 deletions(-)

diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointMetadataLoadingTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointMetadataLoadingTest.java
index 557e8ba..936fe23 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointMetadataLoadingTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointMetadataLoadingTest.java
@@ -33,6 +33,7 @@ import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLo
 import org.junit.Test;
 
 import java.io.ByteArrayOutputStream;
+import java.io.IOException;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
@@ -42,7 +43,6 @@ import static org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNew
 import static org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNewResultSubpartitionStateHandle;
 import static org.apache.flink.runtime.checkpoint.StateObjectCollection.singleton;
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.Mockito.mock;
@@ -54,37 +54,90 @@ import static org.mockito.Mockito.when;
  */
 public class CheckpointMetadataLoadingTest {
 
+	private final ClassLoader cl = getClass().getClassLoader();
+
 	/**
-	 * Tests loading and validation of savepoints with correct setup,
-	 * parallelism mismatch, and a missing task.
+	 * Tests correct savepoint loading.
 	 */
 	@Test
-	public void testLoadAndValidateSavepoint() throws Exception {
-		final Random rnd = new Random();
+	public void testAllStateRestored() throws Exception {
+		final JobID jobId = new JobID();
+		final OperatorID operatorId = new OperatorID();
+		final long checkpointId = Integer.MAX_VALUE + 123123L;
+		final int parallelism = 128128;
 
-		int parallelism = 128128;
-		long checkpointId = Integer.MAX_VALUE + 123123L;
-		JobVertexID jobVertexID = new JobVertexID();
-		OperatorID operatorID = OperatorID.fromJobVertexID(jobVertexID);
+		final CompletedCheckpointStorageLocation testSavepoint = createSavepointWithOperatorSubtaskState(checkpointId, operatorId, parallelism);
+		final Map<JobVertexID, ExecutionJobVertex> tasks = createTasks(operatorId, parallelism, parallelism);
 
-		OperatorSubtaskState subtaskState = new OperatorSubtaskState(
-				new OperatorStreamStateHandle(Collections.emptyMap(), new ByteStreamStateHandle("testHandler", new byte[0])),
-				null,
-				null,
-				null,
-				singleton(createNewInputChannelStateHandle(10, rnd)),
-				singleton(createNewResultSubpartitionStateHandle(10, rnd)));
+		final CompletedCheckpoint loaded = Checkpoints.loadAndValidateCheckpoint(jobId, tasks, testSavepoint, cl, false);
 
-		OperatorState state = new OperatorState(operatorID, parallelism, parallelism);
-		state.putState(0, subtaskState);
+		assertEquals(jobId, loaded.getJobId());
+		assertEquals(checkpointId, loaded.getCheckpointID());
+	}
+
+	/**
+	 * Tests that savepoint loading fails when there is a max-parallelism mismatch.
+	 */
+	@Test
+	public void testMaxParallelismMismatch() throws Exception {
+		final OperatorID operatorId = new OperatorID();
+		final int parallelism = 128128;
 
-		Map<OperatorID, OperatorState> taskStates = new HashMap<>();
-		taskStates.put(operatorID, state);
+		final CompletedCheckpointStorageLocation testSavepoint = createSavepointWithOperatorSubtaskState(242L, operatorId, parallelism);
+		final Map<JobVertexID, ExecutionJobVertex> tasks = createTasks(operatorId, parallelism, parallelism + 1);
 
-		JobID jobId = new JobID();
+		try {
+			Checkpoints.loadAndValidateCheckpoint(new JobID(), tasks, testSavepoint, cl, false);
+			fail("Did not throw expected Exception");
+		} catch (IllegalStateException expected) {
+			assertTrue(expected.getMessage().contains("Max parallelism mismatch"));
+		}
+	}
 
-		// Store savepoint
-		final CheckpointMetadata savepoint = new CheckpointMetadata(checkpointId, taskStates.values(), Collections.emptyList());
+	/**
+	 * Tests that savepoint loading fails when there is non-restored state, but it is not allowed.
+	 */
+	@Test
+	public void testNonRestoredStateWhenDisallowed() throws Exception {
+		final OperatorID operatorId = new OperatorID();
+		final int parallelism = 9;
+
+		final CompletedCheckpointStorageLocation testSavepoint = createSavepointWithOperatorSubtaskState(242L, operatorId, parallelism);
+		final Map<JobVertexID, ExecutionJobVertex> tasks = Collections.emptyMap();
+
+		try {
+			Checkpoints.loadAndValidateCheckpoint(new JobID(), tasks, testSavepoint, cl, false);
+			fail("Did not throw expected Exception");
+		} catch (IllegalStateException expected) {
+			assertTrue(expected.getMessage().contains("allowNonRestoredState"));
+		}
+	}
+
+	/**
+	 * Tests that savepoint loading succeeds when there is non-restored state and it is allowed.
+	 */
+	@Test
+	public void testNonRestoredStateWhenAllowed() throws Exception {
+		final OperatorID operatorId = new OperatorID();
+		final int parallelism = 9;
+
+		final CompletedCheckpointStorageLocation testSavepoint = createSavepointWithOperatorSubtaskState(242L, operatorId, parallelism);
+		final Map<JobVertexID, ExecutionJobVertex> tasks = Collections.emptyMap();
+
+		final CompletedCheckpoint loaded = Checkpoints.loadAndValidateCheckpoint(new JobID(), tasks, testSavepoint, cl, true);
+
+		assertTrue(loaded.getOperatorStates().isEmpty());
+	}
+
+	// ------------------------------------------------------------------------
+	//  setup utils
+	// ------------------------------------------------------------------------
+
+	private static CompletedCheckpointStorageLocation createSavepointWithOperatorState(
+			final long checkpointId,
+			final OperatorState state) throws IOException {
+
+		final CheckpointMetadata savepoint = new CheckpointMetadata(checkpointId, Collections.singletonList(state), Collections.emptyList());
 		final StreamStateHandle serializedMetadata;
 
 		try (ByteArrayOutputStream os = new ByteArrayOutputStream()) {
@@ -92,47 +145,45 @@ public class CheckpointMetadataLoadingTest {
 			serializedMetadata = new ByteStreamStateHandle("checkpoint", os.toByteArray());
 		}
 
-		final CompletedCheckpointStorageLocation storageLocation = new TestCompletedCheckpointStorageLocation(
-				serializedMetadata, "dummy/pointer");
-
-		ExecutionJobVertex vertex = mock(ExecutionJobVertex.class);
-		when(vertex.getParallelism()).thenReturn(parallelism);
-		when(vertex.getMaxParallelism()).thenReturn(parallelism);
-		when(vertex.getOperatorIDs()).thenReturn(Collections.singletonList(OperatorIDPair.generatedIDOnly(operatorID)));
+		return new TestCompletedCheckpointStorageLocation(serializedMetadata, "dummy/pointer");
+	}
 
-		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
-		tasks.put(jobVertexID, vertex);
+	private static CompletedCheckpointStorageLocation createSavepointWithOperatorSubtaskState(
+			final long checkpointId,
+			final OperatorID operatorId,
+			final int parallelism) throws IOException {
 
-		ClassLoader ucl = Thread.currentThread().getContextClassLoader();
+		final Random rnd = new Random();
 
-		// 1) Load and validate: everything correct
-		CompletedCheckpoint loaded = Checkpoints.loadAndValidateCheckpoint(jobId, tasks, storageLocation, ucl, false);
+		final OperatorSubtaskState subtaskState = new OperatorSubtaskState(
+			new OperatorStreamStateHandle(Collections.emptyMap(), new ByteStreamStateHandle("testHandler", new byte[0])),
+			null,
+			null,
+			null,
+			singleton(createNewInputChannelStateHandle(10, rnd)),
+			singleton(createNewResultSubpartitionStateHandle(10, rnd)));
 
-		assertEquals(jobId, loaded.getJobId());
-		assertEquals(checkpointId, loaded.getCheckpointID());
+		final OperatorState state = new OperatorState(operatorId, parallelism, parallelism);
+		state.putState(0, subtaskState);
 
-		// 2) Load and validate: max parallelism mismatch
-		when(vertex.getMaxParallelism()).thenReturn(222);
-		when(vertex.isMaxParallelismConfigured()).thenReturn(true);
+		return createSavepointWithOperatorState(checkpointId, state);
+	}
 
-		try {
-			Checkpoints.loadAndValidateCheckpoint(jobId, tasks, storageLocation, ucl, false);
-			fail("Did not throw expected Exception");
-		} catch (IllegalStateException expected) {
-			assertTrue(expected.getMessage().contains("Max parallelism mismatch"));
-		}
+	private static Map<JobVertexID, ExecutionJobVertex> createTasks(OperatorID operatorId, int parallelism, int maxParallelism) {
+		final JobVertexID vertexId = new JobVertexID(operatorId.getLowerPart(), operatorId.getUpperPart());
 
-		// 3) Load and validate: missing vertex
-		assertNotNull(tasks.remove(jobVertexID));
+		ExecutionJobVertex vertex = mock(ExecutionJobVertex.class);
+		when(vertex.getParallelism()).thenReturn(parallelism);
+		when(vertex.getMaxParallelism()).thenReturn(maxParallelism);
+		when(vertex.getOperatorIDs()).thenReturn(Collections.singletonList(OperatorIDPair.generatedIDOnly(operatorId)));
 
-		try {
-			Checkpoints.loadAndValidateCheckpoint(jobId, tasks, storageLocation, ucl, false);
-			fail("Did not throw expected Exception");
-		} catch (IllegalStateException expected) {
-			assertTrue(expected.getMessage().contains("allowNonRestoredState"));
+		if (parallelism != maxParallelism) {
+			when(vertex.isMaxParallelismConfigured()).thenReturn(true);
 		}
 
-		// 4) Load and validate: ignore missing vertex
-		Checkpoints.loadAndValidateCheckpoint(jobId, tasks, storageLocation, ucl, true);
+		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
+		tasks.put(vertexId, vertex);
+
+		return tasks;
 	}
 }