You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by sr...@apache.org on 2017/08/15 12:57:12 UTC

[4/7] flink git commit: [FLINK-7213] Introduce state management by OperatorID in TaskManager

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
index 923b912..09e9a1b 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
@@ -32,7 +32,9 @@ import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
 import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -49,6 +51,7 @@ import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
 import org.apache.flink.runtime.memory.MemoryManager;
@@ -56,7 +59,6 @@ import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
-import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.DoneFuture;
 import org.apache.flink.runtime.state.KeyGroupRange;
@@ -65,7 +67,6 @@ import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateBackendFactory;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.taskmanager.CheckpointResponder;
 import org.apache.flink.runtime.taskmanager.Task;
 import org.apache.flink.runtime.taskmanager.TaskExecutionState;
@@ -128,6 +129,7 @@ import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyCollectionOf;
 import static org.mockito.Matchers.anyLong;
 import static org.mockito.Matchers.anyString;
 import static org.mockito.Matchers.eq;
@@ -158,6 +160,7 @@ public class StreamTaskTest extends TestLogger {
 	public void testEarlyCanceling() throws Exception {
 		Deadline deadline = new FiniteDuration(2, TimeUnit.MINUTES).fromNow();
 		StreamConfig cfg = new StreamConfig(new Configuration());
+		cfg.setOperatorID(new OperatorID(4711L, 42L));
 		cfg.setStreamOperator(new SlowlyDeserializingOperator());
 		cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
 
@@ -203,6 +206,7 @@ public class StreamTaskTest extends TestLogger {
 		taskManagerConfig.setString(CoreOptions.STATE_BACKEND, MockStateBackend.class.getName());
 
 		StreamConfig cfg = new StreamConfig(new Configuration());
+		cfg.setOperatorID(new OperatorID(4711L, 42L));
 		cfg.setStreamOperator(new StreamSource<>(new MockSourceFunction()));
 		cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
 
@@ -227,6 +231,7 @@ public class StreamTaskTest extends TestLogger {
 		taskManagerConfig.setString(CoreOptions.STATE_BACKEND, MockStateBackend.class.getName());
 
 		StreamConfig cfg = new StreamConfig(new Configuration());
+		cfg.setOperatorID(new OperatorID(4711L, 42L));
 		cfg.setStreamOperator(new StreamSource<>(new MockSourceFunction()));
 		cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
 
@@ -324,6 +329,13 @@ public class StreamTaskTest extends TestLogger {
 		when(streamOperator2.snapshotLegacyOperatorState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(streamStateHandle2);
 		when(streamOperator3.snapshotLegacyOperatorState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(streamStateHandle3);
 
+		OperatorID operatorID1 = new OperatorID();
+		OperatorID operatorID2 = new OperatorID();
+		OperatorID operatorID3 = new OperatorID();
+		when(streamOperator1.getOperatorID()).thenReturn(operatorID1);
+		when(streamOperator2.getOperatorID()).thenReturn(operatorID2);
+		when(streamOperator3.getOperatorID()).thenReturn(operatorID3);
+
 		// set up the task
 
 		StreamOperator<?>[] streamOperators = {streamOperator1, streamOperator2, streamOperator3};
@@ -399,6 +411,13 @@ public class StreamTaskTest extends TestLogger {
 		when(streamOperator2.snapshotLegacyOperatorState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(streamStateHandle2);
 		when(streamOperator3.snapshotLegacyOperatorState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(streamStateHandle3);
 
+		OperatorID operatorID1 = new OperatorID();
+		OperatorID operatorID2 = new OperatorID();
+		OperatorID operatorID3 = new OperatorID();
+		when(streamOperator1.getOperatorID()).thenReturn(operatorID1);
+		when(streamOperator2.getOperatorID()).thenReturn(operatorID2);
+		when(streamOperator3.getOperatorID()).thenReturn(operatorID3);
+
 		StreamOperator<?>[] streamOperators = {streamOperator1, streamOperator2, streamOperator3};
 
 		OperatorChain<Void, AbstractStreamOperator<Void>> operatorChain = mock(OperatorChain.class);
@@ -455,7 +474,7 @@ public class StreamTaskTest extends TestLogger {
 
 				return null;
 			}
-		}).when(mockEnvironment).acknowledgeCheckpoint(anyLong(), any(CheckpointMetrics.class), any(SubtaskState.class));
+		}).when(mockEnvironment).acknowledgeCheckpoint(anyLong(), any(CheckpointMetrics.class), any(TaskStateSnapshot.class));
 
 		StreamTask<?, AbstractStreamOperator<?>> streamTask = mock(StreamTask.class, Mockito.CALLS_REAL_METHODS);
 		CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, timestamp);
@@ -505,18 +524,19 @@ public class StreamTaskTest extends TestLogger {
 
 		acknowledgeCheckpointLatch.await();
 
-		ArgumentCaptor<SubtaskState> subtaskStateCaptor = ArgumentCaptor.forClass(SubtaskState.class);
+		ArgumentCaptor<TaskStateSnapshot> subtaskStateCaptor = ArgumentCaptor.forClass(TaskStateSnapshot.class);
 
 		// check that the checkpoint has been completed
 		verify(mockEnvironment).acknowledgeCheckpoint(eq(checkpointId), any(CheckpointMetrics.class), subtaskStateCaptor.capture());
 
-		SubtaskState subtaskState = subtaskStateCaptor.getValue();
+		TaskStateSnapshot subtaskStates = subtaskStateCaptor.getValue();
+		OperatorSubtaskState subtaskState = subtaskStates.getSubtaskStateMappings().iterator().next().getValue();
 
 		// check that the subtask state contains the expected state handles
-		assertEquals(managedKeyedStateHandle, subtaskState.getManagedKeyedState());
-		assertEquals(rawKeyedStateHandle, subtaskState.getRawKeyedState());
-		assertEquals(new ChainedStateHandle<>(Collections.singletonList(managedOperatorStateHandle)), subtaskState.getManagedOperatorState());
-		assertEquals(new ChainedStateHandle<>(Collections.singletonList(rawOperatorStateHandle)), subtaskState.getRawOperatorState());
+		assertEquals(Collections.singletonList(managedKeyedStateHandle), subtaskState.getManagedKeyedState());
+		assertEquals(Collections.singletonList(rawKeyedStateHandle), subtaskState.getRawKeyedState());
+		assertEquals(Collections.singletonList(managedOperatorStateHandle), subtaskState.getManagedOperatorState());
+		assertEquals(Collections.singletonList(rawOperatorStateHandle), subtaskState.getRawOperatorState());
 
 		// check that the state handles have not been discarded
 		verify(managedKeyedStateHandle, never()).discardState();
@@ -558,18 +578,26 @@ public class StreamTaskTest extends TestLogger {
 		Environment mockEnvironment = mock(Environment.class);
 		when(mockEnvironment.getTaskInfo()).thenReturn(mockTaskInfo);
 
-		whenNew(SubtaskState.class).withAnyArguments().thenAnswer(new Answer<SubtaskState>() {
-			@Override
-			public SubtaskState answer(InvocationOnMock invocation) throws Throwable {
+		whenNew(OperatorSubtaskState.class).
+			withArguments(
+				any(StreamStateHandle.class),
+				anyCollectionOf(OperatorStateHandle.class),
+				anyCollectionOf(OperatorStateHandle.class),
+				anyCollectionOf(KeyedStateHandle.class),
+				anyCollectionOf(KeyedStateHandle.class)).
+			thenAnswer(new Answer<OperatorSubtaskState>() {
+				@Override
+			public OperatorSubtaskState answer(InvocationOnMock invocation) throws Throwable {
 				createSubtask.trigger();
 				completeSubtask.await();
-
-				return new SubtaskState(
-					(ChainedStateHandle<StreamStateHandle>) invocation.getArguments()[0],
-					(ChainedStateHandle<OperatorStateHandle>) invocation.getArguments()[1],
-					(ChainedStateHandle<OperatorStateHandle>) invocation.getArguments()[2],
-					(KeyedStateHandle) invocation.getArguments()[3],
-					(KeyedStateHandle) invocation.getArguments()[4]);
+				Object[] arguments = invocation.getArguments();
+				return new OperatorSubtaskState(
+					(StreamStateHandle) arguments[0],
+					(OperatorStateHandle) arguments[1],
+					(OperatorStateHandle) arguments[2],
+					(KeyedStateHandle) arguments[3],
+					(KeyedStateHandle) arguments[4]
+				);
 			}
 		});
 
@@ -577,7 +605,9 @@ public class StreamTaskTest extends TestLogger {
 		CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, timestamp);
 		streamTask.setEnvironment(mockEnvironment);
 
-		StreamOperator<?> streamOperator = mock(StreamOperator.class, withSettings().extraInterfaces(StreamCheckpointedOperator.class));
+		final StreamOperator<?> streamOperator = mock(StreamOperator.class, withSettings().extraInterfaces(StreamCheckpointedOperator.class));
+		final OperatorID operatorID = new OperatorID();
+		when(streamOperator.getOperatorID()).thenReturn(operatorID);
 
 		KeyedStateHandle managedKeyedStateHandle = mock(KeyedStateHandle.class);
 		KeyedStateHandle rawKeyedStateHandle = mock(KeyedStateHandle.class);
@@ -636,7 +666,7 @@ public class StreamTaskTest extends TestLogger {
 		}
 
 		// check that the checkpoint has not been acknowledged
-		verify(mockEnvironment, never()).acknowledgeCheckpoint(eq(checkpointId), any(CheckpointMetrics.class), any(SubtaskState.class));
+		verify(mockEnvironment, never()).acknowledgeCheckpoint(eq(checkpointId), any(CheckpointMetrics.class), any(TaskStateSnapshot.class));
 
 		// check that the state handles have been discarded
 		verify(managedKeyedStateHandle).discardState();
@@ -676,7 +706,7 @@ public class StreamTaskTest extends TestLogger {
 				checkpointCompletedLatch.trigger();
 				return null;
 			}
-		}).when(mockEnvironment).acknowledgeCheckpoint(anyLong(), any(CheckpointMetrics.class), any(SubtaskState.class));
+		}).when(mockEnvironment).acknowledgeCheckpoint(anyLong(), any(CheckpointMetrics.class), any(TaskStateSnapshot.class));
 
 		when(mockEnvironment.getTaskInfo()).thenReturn(mockTaskInfo);
 
@@ -688,6 +718,9 @@ public class StreamTaskTest extends TestLogger {
 		StreamOperator<?> statelessOperator =
 				mock(StreamOperator.class, withSettings().extraInterfaces(StreamCheckpointedOperator.class));
 
+		final OperatorID operatorID = new OperatorID();
+		when(statelessOperator.getOperatorID()).thenReturn(operatorID);
+
 		// mock the returned empty snapshot result (all state handles are null)
 		OperatorSnapshotResult statelessOperatorSnapshotResult = new OperatorSnapshotResult();
 		when(statelessOperator.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class)))
@@ -803,7 +836,7 @@ public class StreamTaskTest extends TestLogger {
 			Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
 			Collections.<InputGateDeploymentDescriptor>emptyList(),
 			0,
-			new TaskStateHandles(),
+			new TaskStateSnapshot(),
 			mock(MemoryManager.class),
 			mock(IOManager.class),
 			network,

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
index a02fe4e..19d48e1 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
@@ -24,6 +24,7 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.io.network.partition.consumer.StreamTestSingleInputGate;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
@@ -142,6 +143,7 @@ public class StreamTaskTestHarness<OUT> {
 		streamConfig.setNumberOfOutputs(1);
 		streamConfig.setTypeSerializerOut(outputSerializer);
 		streamConfig.setVertexID(0);
+		streamConfig.setOperatorID(new OperatorID(4711L, 123L));
 
 		StreamOperator<OUT> dummyOperator = new AbstractStreamOperator<OUT>() {
 			private static final long serialVersionUID = 1L;

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
index 66531ac..d785c0d 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
@@ -23,6 +23,7 @@ import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.streaming.api.functions.co.CoMapFunction;
 import org.apache.flink.streaming.api.functions.co.RichCoMapFunction;
 import org.apache.flink.streaming.api.graph.StreamConfig;
@@ -64,6 +65,7 @@ public class TwoInputStreamTaskTest {
 		StreamConfig streamConfig = testHarness.getStreamConfig();
 		CoStreamMap<String, Integer, String> coMapOperator = new CoStreamMap<String, Integer, String>(new TestOpenCloseMapFunction());
 		streamConfig.setStreamOperator(coMapOperator);
+		streamConfig.setOperatorID(new OperatorID());
 
 		long initialTime = 0L;
 		ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<Object>();
@@ -110,6 +112,7 @@ public class TwoInputStreamTaskTest {
 		StreamConfig streamConfig = testHarness.getStreamConfig();
 		CoStreamMap<String, Integer, String> coMapOperator = new CoStreamMap<String, Integer, String>(new IdentityMap());
 		streamConfig.setStreamOperator(coMapOperator);
+		streamConfig.setOperatorID(new OperatorID());
 
 		ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<Object>();
 		long initialTime = 0L;
@@ -216,6 +219,7 @@ public class TwoInputStreamTaskTest {
 		StreamConfig streamConfig = testHarness.getStreamConfig();
 		CoStreamMap<String, Integer, String> coMapOperator = new CoStreamMap<String, Integer, String>(new IdentityMap());
 		streamConfig.setStreamOperator(coMapOperator);
+		streamConfig.setOperatorID(new OperatorID());
 
 		ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<Object>();
 		long initialTime = 0L;
@@ -296,6 +300,7 @@ public class TwoInputStreamTaskTest {
 		StreamConfig streamConfig = testHarness.getStreamConfig();
 		CoStreamMap<String, Integer, String> coMapOperator = new CoStreamMap<String, Integer, String>(new IdentityMap());
 		streamConfig.setStreamOperator(coMapOperator);
+		streamConfig.setOperatorID(new OperatorID());
 
 		ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<Object>();
 		long initialTime = 0L;

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
index 47e8726..15802353 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
@@ -32,9 +32,11 @@ import org.apache.flink.migration.streaming.runtime.tasks.StreamTaskState;
 import org.apache.flink.migration.util.MigrationInstantiationUtil;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.OperatorStateRepartitioner;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
 import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner;
 import org.apache.flink.runtime.checkpoint.StateAssignmentOperation;
 import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
@@ -154,6 +156,7 @@ public class AbstractStreamOperatorTestHarness<OUT> {
 		Configuration underlyingConfig = environment.getTaskConfiguration();
 		this.config = new StreamConfig(underlyingConfig);
 		this.config.setCheckpointingEnabled(true);
+		this.config.setOperatorID(new OperatorID());
 		this.executionConfig = environment.getExecutionConfig();
 		this.closableRegistry = new CloseableRegistry();
 		this.checkpointLock = new Object();
@@ -336,7 +339,7 @@ public class AbstractStreamOperatorTestHarness<OUT> {
 	}
 
 	/**
-	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorStateHandles)}.
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorSubtaskState)}.
 	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)}
 	 * if it was not called before.
 	 *
@@ -393,13 +396,12 @@ public class AbstractStreamOperatorTestHarness<OUT> {
 					rawOperatorState,
 					numSubtasks).get(subtaskIndex);
 
-			OperatorStateHandles massagedOperatorStateHandles = new OperatorStateHandles(
-					0,
-					operatorStateHandles.getLegacyOperatorState(),
-					localManagedKeyGroupState,
-					localRawKeyGroupState,
-					localManagedOperatorState,
-					localRawOperatorState);
+			OperatorSubtaskState massagedOperatorStateHandles = new OperatorSubtaskState(
+				operatorStateHandles.getLegacyOperatorState(),
+				nullToEmptyCollection(localManagedOperatorState),
+				nullToEmptyCollection(localRawOperatorState),
+				nullToEmptyCollection(localManagedKeyGroupState),
+				nullToEmptyCollection(localRawKeyGroupState));
 
 			operator.initializeState(massagedOperatorStateHandles);
 		} else {
@@ -408,6 +410,10 @@ public class AbstractStreamOperatorTestHarness<OUT> {
 		initializeCalled = true;
 	}
 
+	private static <T> Collection<T> nullToEmptyCollection(Collection<T> collection) { // Normalizes a possibly-null collection so callers can pass it on without a null check.
+		return collection != null ? collection : Collections.<T>emptyList(); // A non-null argument is returned as-is (same instance); null becomes Collections.emptyList().
+	}
+
 	/**
 	 * Takes the different {@link OperatorStateHandles} created by calling {@link #snapshot(long, long)}
 	 * on different instances of {@link AbstractStreamOperatorTestHarness} (each one representing one subtask)

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
index bf1bb1b..cc23545 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
@@ -354,7 +354,7 @@ public class SavepointITCase extends TestLogger {
 
 					errMsg = "Initial operator state mismatch.";
 					assertEquals(errMsg, subtaskState.getLegacyOperatorState(),
-						tdd.getTaskStateHandles().getLegacyOperatorState().get(chainIndexAndJobVertex.f0));
+						tdd.getTaskStateHandles().getSubtaskStateByOperatorID(operatorState.getOperatorID()).getLegacyOperatorState());
 				}
 			}