You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by uc...@apache.org on 2016/10/07 12:15:34 UTC
flink git commit: [FLINK-4731] Fix HeapKeyedStateBackend Scale-In
Repository: flink
Updated Branches:
refs/heads/master 97c71675a -> 8d953bf26
[FLINK-4731] Fix HeapKeyedStateBackend Scale-In
Adds additional tests in RescalingITCase for scale-in
This closes #2584.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/8d953bf2
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/8d953bf2
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/8d953bf2
Branch: refs/heads/master
Commit: 8d953bf2626012e3e497334641962bd8f96098de
Parents: 97c7167
Author: Stefan Richter <s....@data-artisans.com>
Authored: Sun Oct 2 16:56:41 2016 +0200
Committer: Ufuk Celebi <uc...@apache.org>
Committed: Fri Oct 7 14:14:27 2016 +0200
----------------------------------------------------------------------
.../savepoint/SavepointV1Serializer.java | 4 +-
.../filesystem/FsCheckpointStreamFactory.java | 11 +-
.../state/heap/HeapKeyedStateBackend.java | 164 ++++++------
.../state/memory/ByteStreamStateHandle.java | 79 +++---
.../memory/MemCheckpointStreamFactory.java | 3 +-
.../checkpoint/CheckpointCoordinatorTest.java | 20 +-
.../checkpoint/savepoint/SavepointV1Test.java | 9 +-
.../jobmanager/JobManagerHARecoveryTest.java | 4 +-
.../ZooKeeperSubmittedJobGraphsStoreITCase.java | 7 +-
.../runtime/testutils/CommonTestUtils.java | 23 ++
.../TestByteStreamStateHandleDeepCompare.java | 54 ++++
.../test/checkpointing/RescalingITCase.java | 253 ++++++++++++++++---
12 files changed, 465 insertions(+), 166 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
index f120e1d..666176b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
@@ -290,6 +290,7 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
} else if (stateHandle instanceof ByteStreamStateHandle) {
dos.writeByte(BYTE_STREAM_STATE_HANDLE);
ByteStreamStateHandle byteStreamStateHandle = (ByteStreamStateHandle) stateHandle;
+ dos.writeUTF(byteStreamStateHandle.getHandleName());
byte[] internalData = byteStreamStateHandle.getData();
dos.writeInt(internalData.length);
dos.write(byteStreamStateHandle.getData());
@@ -310,10 +311,11 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
String pathString = dis.readUTF();
return new FileStateHandle(new Path(pathString), size);
} else if (BYTE_STREAM_STATE_HANDLE == type) {
+ String handleName = dis.readUTF();
int numBytes = dis.readInt();
byte[] data = new byte[numBytes];
dis.readFully(data);
- return new ByteStreamStateHandle(data);
+ return new ByteStreamStateHandle(handleName, data);
} else {
throw new IOException("Unknown implementation of StreamStateHandle, code: " + type);
}
http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStreamFactory.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStreamFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStreamFactory.java
index e4f7eba..fcc97b3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStreamFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStreamFactory.java
@@ -177,7 +177,6 @@ public class FsCheckpointStreamFactory implements CheckpointStreamFactory {
this.localStateThreshold = localStateThreshold;
}
-
@Override
public void write(int b) throws IOException {
if (pos >= writeBuffer.length) {
@@ -219,7 +218,7 @@ public class FsCheckpointStreamFactory implements CheckpointStreamFactory {
@Override
public long getPos() throws IOException {
- return outStream == null ? pos : outStream.getPos();
+ return pos + (outStream == null ? 0 : outStream.getPos());
}
@Override
@@ -233,7 +232,7 @@ public class FsCheckpointStreamFactory implements CheckpointStreamFactory {
Exception latestException = null;
for (int attempt = 0; attempt < 10; attempt++) {
try {
- statePath = new Path(basePath, UUID.randomUUID().toString());
+ statePath = createStatePath();
outStream = fs.create(statePath, false);
break;
}
@@ -297,7 +296,7 @@ public class FsCheckpointStreamFactory implements CheckpointStreamFactory {
if (outStream == null && pos <= localStateThreshold) {
closed = true;
byte[] bytes = Arrays.copyOf(writeBuffer, pos);
- return new ByteStreamStateHandle(bytes);
+ return new ByteStreamStateHandle(createStatePath().toString(), bytes);
}
else {
flush();
@@ -318,5 +317,9 @@ public class FsCheckpointStreamFactory implements CheckpointStreamFactory {
}
}
}
+
+ private Path createStatePath() {
+ return new Path(basePath, UUID.randomUUID().toString());
+ }
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index 040677b..b283494 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -30,7 +30,9 @@ import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.runtime.query.TaskKvStateRegistry;
import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
@@ -46,6 +48,7 @@ import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@@ -106,12 +109,10 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
// ------------------------------------------------------------------------
// state backend operations
// ------------------------------------------------------------------------
-
+ @SuppressWarnings("unchecked")
@Override
public <N, V> ValueState<V> createValueState(TypeSerializer<N> namespaceSerializer, ValueStateDescriptor<V> stateDesc) throws Exception {
- @SuppressWarnings("unchecked,rawtypes")
- StateTable<K, N, V> stateTable = (StateTable) stateTables.get(stateDesc.getName());
-
+ StateTable<K, N, V> stateTable = (StateTable<K, N, V>) stateTables.get(stateDesc.getName());
if (stateTable == null) {
stateTable = new StateTable<>(stateDesc.getSerializer(), namespaceSerializer, keyGroupRange);
@@ -121,10 +122,10 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
return new HeapValueState<>(this, stateDesc, stateTable, keySerializer, namespaceSerializer);
}
+ @SuppressWarnings("unchecked")
@Override
public <N, T> ListState<T> createListState(TypeSerializer<N> namespaceSerializer, ListStateDescriptor<T> stateDesc) throws Exception {
- @SuppressWarnings("unchecked,rawtypes")
- StateTable<K, N, ArrayList<T>> stateTable = (StateTable) stateTables.get(stateDesc.getName());
+ StateTable<K, N, ArrayList<T>> stateTable = (StateTable<K, N, ArrayList<T>>) stateTables.get(stateDesc.getName());
if (stateTable == null) {
stateTable = new StateTable<>(new ArrayListSerializer<>(stateDesc.getSerializer()), namespaceSerializer, keyGroupRange);
@@ -133,11 +134,10 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
return new HeapListState<>(this, stateDesc, stateTable, keySerializer, namespaceSerializer);
}
-
+ @SuppressWarnings("unchecked")
@Override
public <N, T> ReducingState<T> createReducingState(TypeSerializer<N> namespaceSerializer, ReducingStateDescriptor<T> stateDesc) throws Exception {
- @SuppressWarnings("unchecked,rawtypes")
- StateTable<K, N, T> stateTable = (StateTable) stateTables.get(stateDesc.getName());
+ StateTable<K, N, T> stateTable = (StateTable<K, N, T>) stateTables.get(stateDesc.getName());
if (stateTable == null) {
stateTable = new StateTable<>(stateDesc.getSerializer(), namespaceSerializer, keyGroupRange);
@@ -146,11 +146,10 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
return new HeapReducingState<>(this, stateDesc, stateTable, keySerializer, namespaceSerializer);
}
-
+ @SuppressWarnings("unchecked")
@Override
protected <N, T, ACC> FoldingState<T, ACC> createFoldingState(TypeSerializer<N> namespaceSerializer, FoldingStateDescriptor<T, ACC> stateDesc) throws Exception {
- @SuppressWarnings("unchecked,rawtypes")
- StateTable<K, N, ACC> stateTable = (StateTable) stateTables.get(stateDesc.getName());
+ StateTable<K, N, ACC> stateTable = (StateTable<K, N, ACC>) stateTables.get(stateDesc.getName());
if (stateTable == null) {
stateTable = new StateTable<>(stateDesc.getSerializer(), namespaceSerializer, keyGroupRange);
@@ -161,7 +160,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
}
@Override
- @SuppressWarnings("rawtypes,unchecked")
+ @SuppressWarnings("unchecked")
public RunnableFuture<KeyGroupsStateHandle> snapshot(
long checkpointId,
long timestamp,
@@ -188,8 +187,8 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
outView.writeUTF(kvState.getKey());
- TypeSerializer namespaceSerializer = kvState.getValue().getNamespaceSerializer();
- TypeSerializer stateSerializer = kvState.getValue().getStateSerializer();
+ TypeSerializer<?> namespaceSerializer = kvState.getValue().getNamespaceSerializer();
+ TypeSerializer<?> stateSerializer = kvState.getValue().getStateSerializer();
InstantiationUtil.serializeObject(stream, namespaceSerializer);
InstantiationUtil.serializeObject(stream, stateSerializer);
@@ -203,39 +202,10 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
for (int keyGroupIndex = keyGroupRange.getStartKeyGroup(); keyGroupIndex <= keyGroupRange.getEndKeyGroup(); keyGroupIndex++) {
keyGroupRangeOffsets[offsetCounter++] = stream.getPos();
outView.writeInt(keyGroupIndex);
-
for (Map.Entry<String, StateTable<K, ?, ?>> kvState : stateTables.entrySet()) {
-
outView.writeShort(kVStateToId.get(kvState.getKey()));
-
- TypeSerializer namespaceSerializer = kvState.getValue().getNamespaceSerializer();
- TypeSerializer stateSerializer = kvState.getValue().getStateSerializer();
-
- // Map<NamespaceT, Map<KeyT, StateT>>
- Map<?, ? extends Map<K, ?>> namespaceMap = kvState.getValue().get(keyGroupIndex);
- if (namespaceMap == null) {
- outView.writeByte(0);
- continue;
- }
-
- outView.writeByte(1);
-
- // number of namespaces
- outView.writeInt(namespaceMap.size());
- for (Map.Entry<?, ? extends Map<K, ?>> namespace : namespaceMap.entrySet()) {
- namespaceSerializer.serialize(namespace.getKey(), outView);
-
- Map<K, ?> entryMap = namespace.getValue();
-
- // number of entries
- outView.writeInt(entryMap.size());
- for (Map.Entry<K, ?> entry : entryMap.entrySet()) {
- keySerializer.serialize(entry.getKey(), outView);
- stateSerializer.serialize(entry.getValue(), outView);
- }
- }
+ writeStateTableForKeyGroup(outView, kvState.getValue(), keyGroupIndex);
}
- outView.flush();
}
StreamStateHandle streamStateHandle = stream.closeAndGetHandle();
@@ -246,8 +216,42 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
}
}
- @SuppressWarnings({"unchecked", "rawtypes"})
- public void restorePartitionedState(List<KeyGroupsStateHandle> state) throws Exception {
+ private <N, S> void writeStateTableForKeyGroup(
+ DataOutputView outView,
+ StateTable<K, N, S> stateTable,
+ int keyGroupIndex) throws IOException {
+
+ TypeSerializer<N> namespaceSerializer = stateTable.getNamespaceSerializer();
+ TypeSerializer<S> stateSerializer = stateTable.getStateSerializer();
+
+ Map<N, Map<K, S>> namespaceMap = stateTable.get(keyGroupIndex);
+ if (namespaceMap == null) {
+ outView.writeByte(0);
+ } else {
+ outView.writeByte(1);
+
+ // number of namespaces
+ outView.writeInt(namespaceMap.size());
+ for (Map.Entry<N, Map<K, S>> namespace : namespaceMap.entrySet()) {
+ namespaceSerializer.serialize(namespace.getKey(), outView);
+
+ Map<K, S> entryMap = namespace.getValue();
+
+ // number of entries
+ outView.writeInt(entryMap.size());
+ for (Map.Entry<K, S> entry : entryMap.entrySet()) {
+ keySerializer.serialize(entry.getKey(), outView);
+ stateSerializer.serialize(entry.getValue(), outView);
+ }
+ }
+ }
+ }
+
+ @SuppressWarnings({"unchecked"})
+ private void restorePartitionedState(List<KeyGroupsStateHandle> state) throws Exception {
+
+ int numRegisteredKvStates = 0;
+ Map<Integer, String> kvStatesById = new HashMap<>();
for (KeyGroupsStateHandle keyGroupsHandle : state) {
@@ -266,21 +270,23 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
int numKvStates = inView.readShort();
- Map<Integer, String> kvStatesById = new HashMap<>(numKvStates);
-
for (int i = 0; i < numKvStates; ++i) {
String stateName = inView.readUTF();
- TypeSerializer namespaceSerializer =
+ TypeSerializer<?> namespaceSerializer =
InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader);
- TypeSerializer stateSerializer =
+ TypeSerializer<?> stateSerializer =
InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader);
- StateTable<K, ?, ?> stateTable = new StateTable(stateSerializer,
- namespaceSerializer,
- keyGroupRange);
- stateTables.put(stateName, stateTable);
- kvStatesById.put(i, stateName);
+ StateTable<K, ?, ?> stateTable = stateTables.get(stateName);
+
+ //important: only create a new table if we did not already create it previously
+ if (null == stateTable) {
+ stateTable = new StateTable<>(stateSerializer, namespaceSerializer, keyGroupRange);
+ stateTables.put(stateName, stateTable);
+ kvStatesById.put(numRegisteredKvStates, stateName);
+ ++numRegisteredKvStates;
+ }
}
for (Tuple2<Integer, Long> groupOffset : keyGroupsHandle.getGroupRangeOffsets()) {
@@ -302,25 +308,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
StateTable<K, ?, ?> stateTable = stateTables.get(kvStatesById.get(kvStateId));
Preconditions.checkNotNull(stateTable);
- TypeSerializer namespaceSerializer = stateTable.getNamespaceSerializer();
- TypeSerializer stateSerializer = stateTable.getStateSerializer();
-
- Map namespaceMap = new HashMap<>();
- stateTable.set(keyGroupIndex, namespaceMap);
-
- int numNamespaces = inView.readInt();
- for (int k = 0; k < numNamespaces; k++) {
- Object namespace = namespaceSerializer.deserialize(inView);
- Map entryMap = new HashMap<>();
- namespaceMap.put(namespace, entryMap);
-
- int numEntries = inView.readInt();
- for (int l = 0; l < numEntries; l++) {
- Object key = keySerializer.deserialize(inView);
- Object value = stateSerializer.deserialize(inView);
- entryMap.put(key, value);
- }
- }
+ readStateTableForKeyGroup(inView, stateTable, keyGroupIndex);
}
}
} finally {
@@ -330,6 +318,32 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
}
}
+ private <N, S> void readStateTableForKeyGroup(
+ DataInputView inView,
+ StateTable<K, N, S> stateTable,
+ int keyGroupIndex) throws IOException {
+
+ TypeSerializer<N> namespaceSerializer = stateTable.getNamespaceSerializer();
+ TypeSerializer<S> stateSerializer = stateTable.getStateSerializer();
+
+ Map<N, Map<K, S>> namespaceMap = new HashMap<>();
+ stateTable.set(keyGroupIndex, namespaceMap);
+
+ int numNamespaces = inView.readInt();
+ for (int k = 0; k < numNamespaces; k++) {
+ N namespace = namespaceSerializer.deserialize(inView);
+ Map<K, S> entryMap = new HashMap<>();
+ namespaceMap.put(namespace, entryMap);
+
+ int numEntries = inView.readInt();
+ for (int l = 0; l < numEntries; l++) {
+ K key = keySerializer.deserialize(inView);
+ S state = stateSerializer.deserialize(inView);
+ entryMap.put(key, state);
+ }
+ }
+ }
+
@Override
public String toString() {
return "HeapKeyedStateBackend";
http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
index 7d8b6ce..42703f8 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
@@ -20,62 +20,49 @@ package org.apache.flink.runtime.state.memory;
import org.apache.flink.core.fs.FSDataInputStream;
import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.util.Preconditions;
import java.io.IOException;
-import java.io.Serializable;
-import java.util.Arrays;
/**
* A state handle that contains stream state in a byte array.
*/
public class ByteStreamStateHandle implements StreamStateHandle {
- private static final long serialVersionUID = -5280226231200217594L;
+ private static final long serialVersionUID = -5280226231202517594L;
/**
- * the state data
+ * The state data.
*/
protected final byte[] data;
/**
+ * A unique name by which this state handle is identified and compared. Like a filename, all
+ * {@link ByteStreamStateHandle} with the exact same name must also have the exact same content in data.
+ */
+ protected final String handleName;
+
+ /**
* Creates a new ByteStreamStateHandle containing the given data.
- *
- * @param data The state data.
*/
- public ByteStreamStateHandle(byte[] data) {
- this.data = data;
+ public ByteStreamStateHandle(String handleName, byte[] data) {
+ this.handleName = Preconditions.checkNotNull(handleName);
+ this.data = Preconditions.checkNotNull(data);
}
@Override
public FSDataInputStream openInputStream() throws IOException {
-
- return new FSDataInputStream() {
- int index = 0;
-
- @Override
- public void seek(long desired) throws IOException {
- Preconditions.checkArgument(desired >= 0 && desired < Integer.MAX_VALUE);
- index = (int) desired;
- }
-
- @Override
- public long getPos() throws IOException {
- return index;
- }
-
- @Override
- public int read() throws IOException {
- return index < data.length ? data[index++] & 0xFF : -1;
- }
- };
+ return new ByteStateHandleInputStream();
}
public byte[] getData() {
return data;
}
+ public String getHandleName() {
+ return handleName;
+ }
+
@Override
public void discardState() {
}
@@ -94,17 +81,41 @@ public class ByteStreamStateHandle implements StreamStateHandle {
return false;
}
- ByteStreamStateHandle that = (ByteStreamStateHandle) o;
- return Arrays.equals(data, that.data);
+ ByteStreamStateHandle that = (ByteStreamStateHandle) o;
+ return handleName.equals(that.handleName);
}
@Override
public int hashCode() {
- return Arrays.hashCode(data);
+ return 31 * handleName.hashCode();
}
- public static StreamStateHandle fromSerializable(Serializable value) throws IOException {
- return new ByteStreamStateHandle(InstantiationUtil.serializeObject(value));
+ /**
+ * An input stream view on a byte array.
+ */
+ private final class ByteStateHandleInputStream extends FSDataInputStream {
+
+ private int index;
+
+ public ByteStateHandleInputStream() {
+ this.index = 0;
+ }
+
+ @Override
+ public void seek(long desired) throws IOException {
+ Preconditions.checkArgument(desired >= 0 && desired < Integer.MAX_VALUE);
+ index = (int) desired;
+ }
+
+ @Override
+ public long getPos() throws IOException {
+ return index;
+ }
+
+ @Override
+ public int read() throws IOException {
+ return index < data.length ? data[index++] & 0xFF : -1;
+ }
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java
index 6f0a814..028f8c8 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java
@@ -23,6 +23,7 @@ import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.StreamStateHandle;
import java.io.IOException;
+import java.util.UUID;
/**
* {@link CheckpointStreamFactory} that produces streams that write to in-memory byte arrays.
@@ -118,7 +119,7 @@ public class MemCheckpointStreamFactory implements CheckpointStreamFactory {
if (isEmpty) {
return null;
}
- return new ByteStreamStateHandle(closeAndGetBytes());
+ return new ByteStreamStateHandle(String.valueOf(UUID.randomUUID()), closeAndGetBytes());
}
@Override
http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index 6289fcb..728c7d5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -45,6 +45,8 @@ import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.filesystem.FileStateHandle;
import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.runtime.testutils.CommonTestUtils;
+import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare;
import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.util.Preconditions;
import org.junit.Assert;
@@ -65,6 +67,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
+import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
@@ -2287,7 +2290,8 @@ public class CheckpointCoordinatorTest {
KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, serializedDataWithOffsets.f1.get(0));
- ByteStreamStateHandle allSerializedStatesHandle = new ByteStreamStateHandle(
+ ByteStreamStateHandle allSerializedStatesHandle = new TestByteStreamStateHandleDeepCompare(
+ String.valueOf(UUID.randomUUID()),
serializedDataWithOffsets.f0);
KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle(
keyGroupRangeOffsets,
@@ -2343,7 +2347,8 @@ public class CheckpointCoordinatorTest {
public static ChainedStateHandle<StreamStateHandle> generateChainedStateHandle(
Serializable value) throws IOException {
- return ChainedStateHandle.wrapSingleHandle(ByteStreamStateHandle.fromSerializable(value));
+ return ChainedStateHandle.wrapSingleHandle(
+ TestByteStreamStateHandleDeepCompare.fromSerializable(String.valueOf(UUID.randomUUID()), value));
}
public static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle(
@@ -2387,7 +2392,8 @@ public class CheckpointCoordinatorTest {
++idx;
}
- ByteStreamStateHandle streamStateHandle = new ByteStreamStateHandle(
+ ByteStreamStateHandle streamStateHandle = new TestByteStreamStateHandleDeepCompare(
+ String.valueOf(UUID.randomUUID()),
serializationWithOffsets.f0);
OperatorStateHandle operatorStateHandle =
@@ -2468,7 +2474,9 @@ public class CheckpointCoordinatorTest {
ChainedStateHandle<StreamStateHandle> expectNonPartitionedState = generateStateForVertex(jobVertexID, i);
ChainedStateHandle<StreamStateHandle> actualNonPartitionedState = executionJobVertex.
getTaskVertices()[i].getCurrentExecutionAttempt().getChainedStateHandle();
- assertEquals(expectNonPartitionedState.get(0), actualNonPartitionedState.get(0));
+ assertTrue(CommonTestUtils.isSteamContentEqual(
+ expectNonPartitionedState.get(0).openInputStream(),
+ actualNonPartitionedState.get(0).openInputStream()));
ChainedStateHandle<OperatorStateHandle> expectedPartitionableState =
generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8);
@@ -2476,7 +2484,9 @@ public class CheckpointCoordinatorTest {
List<Collection<OperatorStateHandle>> actualPartitionableState = executionJobVertex.
getTaskVertices()[i].getCurrentExecutionAttempt().getChainedPartitionableStateHandle();
- assertEquals(expectedPartitionableState.get(0), actualPartitionableState.get(0).iterator().next());
+ assertTrue(CommonTestUtils.isSteamContentEqual(
+ expectedPartitionableState.get(0).openInputStream(),
+ actualPartitionableState.get(0).iterator().next().openInputStream()));
List<KeyGroupsStateHandle> expectPartitionedKeyGroupState = generateKeyGroupState(
jobVertexID,
http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
index c82be18..e38e5fb 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
@@ -26,7 +26,7 @@ import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
import org.apache.flink.runtime.state.KeyGroupsStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare;
import org.junit.Test;
import java.io.IOException;
@@ -72,11 +72,11 @@ public class SavepointV1Test {
for (int i = 0; i < numTaskStates; i++) {
TaskState taskState = new TaskState(new JobVertexID(), numSubtaskStates, numSubtaskStates, 1);
for (int j = 0; j < numSubtaskStates; j++) {
- StreamStateHandle stateHandle = new ByteStreamStateHandle("Hello".getBytes());
+ StreamStateHandle stateHandle = new TestByteStreamStateHandleDeepCompare("a", "Hello".getBytes());
taskState.putState(i, new SubtaskState(
new ChainedStateHandle<>(Collections.singletonList(stateHandle)), 0));
- stateHandle = new ByteStreamStateHandle("Beautiful".getBytes());
+ stateHandle = new TestByteStreamStateHandleDeepCompare("b", "Beautiful".getBytes());
Map<String, long[]> offsetsMap = new HashMap<>();
offsetsMap.put("A", new long[]{0, 10, 20});
offsetsMap.put("B", new long[]{30, 40, 50});
@@ -93,7 +93,8 @@ public class SavepointV1Test {
taskState.putKeyedState(
0,
new KeyGroupsStateHandle(
- new KeyGroupRangeOffsets(1,1, new long[] {42}), new ByteStreamStateHandle("World".getBytes())));
+ new KeyGroupRangeOffsets(1, 1, new long[]{42}),
+ new TestByteStreamStateHandleDeepCompare("c", "World".getBytes())));
taskStates.add(taskState);
}
http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
index 6100856..b9c2bdf 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
@@ -69,6 +69,7 @@ import org.apache.flink.runtime.testingUtils.TestingMessages;
import org.apache.flink.runtime.testingUtils.TestingTaskManager;
import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages;
import org.apache.flink.runtime.testingUtils.TestingUtils;
+import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare;
import org.apache.flink.util.InstantiationUtil;
import org.junit.AfterClass;
import org.junit.BeforeClass;
@@ -461,7 +462,8 @@ public class JobManagerHARecoveryTest {
@Override
public boolean triggerCheckpoint(CheckpointMetaData checkpointMetaData) {
try {
- ByteStreamStateHandle byteStreamStateHandle = new ByteStreamStateHandle(
+ ByteStreamStateHandle byteStreamStateHandle = new TestByteStreamStateHandleDeepCompare(
+ String.valueOf(UUID.randomUUID()),
InstantiationUtil.serializeObject(checkpointMetaData.getCheckpointId()));
RetrievableStreamStateHandle<Long> state = new RetrievableStreamStateHandle<Long>(byteStreamStateHandle);
http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/ZooKeeperSubmittedJobGraphsStoreITCase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/ZooKeeperSubmittedJobGraphsStoreITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/ZooKeeperSubmittedJobGraphsStoreITCase.java
index 6ef184d..7d21cfd 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/ZooKeeperSubmittedJobGraphsStoreITCase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/ZooKeeperSubmittedJobGraphsStoreITCase.java
@@ -27,6 +27,7 @@ import org.apache.flink.runtime.jobmanager.SubmittedJobGraphStore.SubmittedJobGr
import org.apache.flink.runtime.state.RetrievableStateHandle;
import org.apache.flink.runtime.state.RetrievableStreamStateHandle;
import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare;
import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper;
import org.apache.flink.runtime.zookeeper.ZooKeeperTestEnvironment;
import org.apache.flink.util.InstantiationUtil;
@@ -40,6 +41,7 @@ import org.mockito.stubbing.Answer;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
+import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import static org.junit.Assert.assertEquals;
@@ -62,9 +64,10 @@ public class ZooKeeperSubmittedJobGraphsStoreITCase extends TestLogger {
private final static RetrievableStateStorageHelper<SubmittedJobGraph> localStateStorage = new RetrievableStateStorageHelper<SubmittedJobGraph>() {
@Override
public RetrievableStateHandle<SubmittedJobGraph> store(SubmittedJobGraph state) throws IOException {
- ByteStreamStateHandle byteStreamStateHandle = new ByteStreamStateHandle(
+ ByteStreamStateHandle byteStreamStateHandle = new TestByteStreamStateHandleDeepCompare(
+ String.valueOf(UUID.randomUUID()),
InstantiationUtil.serializeObject(state));
- return new RetrievableStreamStateHandle<SubmittedJobGraph>(byteStreamStateHandle);
+ return new RetrievableStreamStateHandle<>(byteStreamStateHandle);
}
};
http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/CommonTestUtils.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/CommonTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/CommonTestUtils.java
index 2a787f3..d857a19 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/CommonTestUtils.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/CommonTestUtils.java
@@ -20,6 +20,7 @@ package org.apache.flink.runtime.testutils;
import org.apache.flink.util.FileUtils;
+import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
@@ -193,4 +194,26 @@ public class CommonTestUtils {
}
}
}
+
+ /**
+  * Checks whether two streams deliver exactly the same sequence of bytes.
+  *
+  * <p>Both streams are read until one of them differs or the first one is exhausted.
+  * Streams that are not already {@link BufferedInputStream}s are wrapped in one before reading.
+  * Neither stream is closed by this method.
+  *
+  * @param input1 the first stream to read
+  * @param input2 the second stream to read
+  * @return {@code true} if both streams produce identical bytes and end at the same position
+  * @throws IOException if reading from either stream fails
+  */
+ public static boolean isSteamContentEqual(InputStream input1, InputStream input2) throws IOException {
+
+     final InputStream left =
+         (input1 instanceof BufferedInputStream) ? input1 : new BufferedInputStream(input1);
+     final InputStream right =
+         (input2 instanceof BufferedInputStream) ? input2 : new BufferedInputStream(input2);
+
+     for (int expected = left.read(); expected != -1; expected = left.read()) {
+         final int actual = right.read();
+         if (actual != expected) {
+             return false;
+         }
+     }
+
+     // the first stream is exhausted; the contents match only if the second one is exhausted, too
+     return right.read() == -1;
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/test/java/org/apache/flink/runtime/util/TestByteStreamStateHandleDeepCompare.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/util/TestByteStreamStateHandleDeepCompare.java b/flink-runtime/src/test/java/org/apache/flink/runtime/util/TestByteStreamStateHandleDeepCompare.java
new file mode 100644
index 0000000..7d8797b
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/util/TestByteStreamStateHandleDeepCompare.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.util;
+
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.util.InstantiationUtil;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Arrays;
+
+public class TestByteStreamStateHandleDeepCompare extends ByteStreamStateHandle {
+
+ private static final long serialVersionUID = -4946526195523509L;
+
+ public TestByteStreamStateHandleDeepCompare(String handleName, byte[] data) {
+ super(handleName, data);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (!super.equals(o)) {
+ return false;
+ }
+ ByteStreamStateHandle other = (ByteStreamStateHandle) o;
+ return Arrays.equals(getData(), other.getData());
+ }
+
+ @Override
+ public int hashCode() {
+ return 31 * super.hashCode() + Arrays.hashCode(getData());
+ }
+
+ public static StreamStateHandle fromSerializable(String handleName, Serializable value) throws IOException {
+ return new TestByteStreamStateHandleDeepCompare(handleName, InstantiationUtil.serializeObject(value));
+ }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
index 263bf79..848a579 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
@@ -39,6 +39,7 @@ import org.apache.flink.runtime.state.filesystem.FsStateBackendFactory;
import org.apache.flink.runtime.testingUtils.TestingCluster;
import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages;
import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
@@ -57,7 +58,9 @@ import scala.concurrent.duration.Deadline;
import scala.concurrent.duration.FiniteDuration;
import java.io.File;
+import java.util.ArrayList;
import java.util.HashSet;
+import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
@@ -67,6 +70,10 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
+/**
+ * TODO: parameterize to test all different state backends!
+ * TODO: reactivate ignored test as soon as savepoints work with deactivated checkpoints.
+ */
public class RescalingITCase extends TestLogger {
private static final int numTaskManagers = 2;
@@ -103,17 +110,26 @@ public class RescalingITCase extends TestLogger {
}
}
+ /**
+  * Runs the keyed-state savepoint rescaling scenario with {@code scaleOut = false}.
+  */
+ @Test
+ public void testSavepointRescalingInKeyedState() throws Exception {
+     final boolean scaleOut = false;
+     testSavepointRescalingKeyedState(scaleOut);
+ }
+
+ /**
+  * Runs the keyed-state savepoint rescaling scenario with {@code scaleOut = true}.
+  */
+ @Test
+ public void testSavepointRescalingOutKeyedState() throws Exception {
+     final boolean scaleOut = true;
+     testSavepointRescalingKeyedState(scaleOut);
+ }
+
/**
- * Tests that a a job with purely partitioned state can be restarted from a savepoint
+ * Tests that a job with purely keyed state can be restarted from a savepoint
* with a different parallelism.
*/
- @Test
- public void testSavepointRescalingWithPartitionedState() throws Exception {
+ public void testSavepointRescalingKeyedState(boolean scaleOut) throws Exception {
final int numberKeys = 42;
final int numberElements = 1000;
final int numberElements2 = 500;
- final int parallelism = numSlots / 2;
- final int parallelism2 = numSlots;
+ final int parallelism = scaleOut ? numSlots / 2 : numSlots;
+ final int parallelism2 = scaleOut ? numSlots : numSlots / 2;
final int maxParallelism = 13;
FiniteDuration timeout = new FiniteDuration(3, TimeUnit.MINUTES);
@@ -125,7 +141,7 @@ public class RescalingITCase extends TestLogger {
try {
jobManager = cluster.getLeaderGateway(deadline.timeLeft());
- JobGraph jobGraph = createPartitionedStateJobGraph(parallelism, maxParallelism, numberKeys, numberElements, false, 100);
+ JobGraph jobGraph = createJobGraphWithKeyedState(parallelism, maxParallelism, numberKeys, numberElements, false, 100);
jobID = jobGraph.getJobID();
@@ -168,7 +184,7 @@ public class RescalingITCase extends TestLogger {
jobID = null;
- JobGraph scaledJobGraph = createPartitionedStateJobGraph(parallelism2, maxParallelism, numberKeys, numberElements2, true, 100);
+ JobGraph scaledJobGraph = createJobGraphWithKeyedState(parallelism2, maxParallelism, numberKeys, numberElements2, true, 100);
scaledJobGraph.setSavepointPath(savepointPath);
@@ -189,6 +205,7 @@ public class RescalingITCase extends TestLogger {
assertEquals(expectedResult2, actualResult2);
+
} finally {
// clear the CollectionSink set for the restarted job
CollectionSink.clearElementsSet();
@@ -213,7 +230,7 @@ public class RescalingITCase extends TestLogger {
* @throws Exception
*/
@Test
- public void testSavepointRescalingFailureWithNonPartitionedState() throws Exception {
+ public void testSavepointRescalingNonPartitionedStateCausesException() throws Exception {
final int parallelism = numSlots / 2;
final int parallelism2 = numSlots;
final int maxParallelism = 13;
@@ -227,7 +244,7 @@ public class RescalingITCase extends TestLogger {
try {
jobManager = cluster.getLeaderGateway(deadline.timeLeft());
- JobGraph jobGraph = createNonPartitionedStateJobGraph(parallelism, maxParallelism, 500);
+ JobGraph jobGraph = createJobGraphWithOperatorState(parallelism, maxParallelism, false);
jobID = jobGraph.getJobID();
@@ -236,7 +253,7 @@ public class RescalingITCase extends TestLogger {
Object savepointResponse = null;
// wait until the operator is started
- NonPartitionedStateSource.workStartedLatch.await();
+ StateSourceBase.workStartedLatch.await();
while (deadline.hasTimeLeft()) {
@@ -266,7 +283,7 @@ public class RescalingITCase extends TestLogger {
// job successfully removed
jobID = null;
- JobGraph scaledJobGraph = createNonPartitionedStateJobGraph(parallelism2, maxParallelism, 500);
+ JobGraph scaledJobGraph = createJobGraphWithOperatorState(parallelism2, maxParallelism, false);
scaledJobGraph.setSavepointPath(savepointPath);
@@ -311,7 +328,7 @@ public class RescalingITCase extends TestLogger {
* @throws Exception
*/
@Test
- public void testSavepointRescalingWithPartiallyNonPartitionedState() throws Exception {
+ public void testSavepointRescalingWithKeyedAndNonPartitionedState() throws Exception {
int numberKeys = 42;
int numberElements = 1000;
int numberElements2 = 500;
@@ -328,7 +345,7 @@ public class RescalingITCase extends TestLogger {
try {
jobManager = cluster.getLeaderGateway(deadline.timeLeft());
- JobGraph jobGraph = createPartitionedNonPartitionedStateJobGraph(
+ JobGraph jobGraph = createJobGraphWithKeyedAndNonPartitionedOperatorState(
parallelism,
maxParallelism,
parallelism,
@@ -378,7 +395,7 @@ public class RescalingITCase extends TestLogger {
jobID = null;
- JobGraph scaledJobGraph = createPartitionedNonPartitionedStateJobGraph(
+ JobGraph scaledJobGraph = createJobGraphWithKeyedAndNonPartitionedOperatorState(
parallelism2,
maxParallelism,
parallelism,
@@ -423,23 +440,137 @@ public class RescalingITCase extends TestLogger {
}
}
- private static JobGraph createNonPartitionedStateJobGraph(int parallelism, int maxParallelism, long checkpointInterval) {
+ /**
+  * Runs the partitioned operator state savepoint rescaling scenario with {@code scaleOut = false}.
+  */
+ @Test
+ public void testSavepointRescalingInPartitionedOperatorState() throws Exception {
+     final boolean scaleOut = false;
+     testSavepointRescalingPartitionedOperatorState(scaleOut);
+ }
+
+ /**
+  * Runs the partitioned operator state savepoint rescaling scenario with {@code scaleOut = true}.
+  */
+ @Test
+ public void testSavepointRescalingOutPartitionedOperatorState() throws Exception {
+     final boolean scaleOut = true;
+     testSavepointRescalingPartitionedOperatorState(scaleOut);
+ }
+
+
+ /**
+  * Tests rescaling of partitioned operator state. More specifically, we test the mechanism with {@link ListCheckpointed}
+  * as it subsumes {@link org.apache.flink.streaming.api.checkpoint.CheckpointedFunction}.
+  *
+  * <p>The job is started with one parallelism, a savepoint is triggered, the job is cancelled and then
+  * resubmitted from that savepoint with the other parallelism. The test passes when the sum of the counters
+  * recorded in {@code CHECK_CORRECT_SNAPSHOT} equals the sum of the counters recorded in
+  * {@code CHECK_CORRECT_RESTORE}.
+  *
+  * @param scaleOut {@code true} to go from {@code numSlots / 2} to {@code numSlots} parallel subtasks,
+  *                 {@code false} to go from {@code numSlots} to {@code numSlots / 2}
+  * @throws Exception if the job cannot be submitted, savepointed, cancelled or restored
+  */
+ public void testSavepointRescalingPartitionedOperatorState(boolean scaleOut) throws Exception {
+     // scale-out starts small and restores large, scale-in is the reverse
+     // (same convention as testSavepointRescalingKeyedState)
+     final int parallelism = scaleOut ? numSlots / 2 : numSlots;
+     final int parallelism2 = scaleOut ? numSlots : numSlots / 2;
+     final int maxParallelism = 13;
+
+     FiniteDuration timeout = new FiniteDuration(3, TimeUnit.MINUTES);
+     Deadline deadline = timeout.fromNow();
+
+     JobID jobID = null;
+     ActorGateway jobManager = null;
+
+     // the bookkeeping arrays are indexed by subtask, so they must fit the larger parallelism
+     int counterSize = Math.max(parallelism, parallelism2);
+
+     PartitionedStateSource.CHECK_CORRECT_SNAPSHOT = new int[counterSize];
+     PartitionedStateSource.CHECK_CORRECT_RESTORE = new int[counterSize];
+
+     try {
+         jobManager = cluster.getLeaderGateway(deadline.timeLeft());
+
+         JobGraph jobGraph = createJobGraphWithOperatorState(parallelism, maxParallelism, true);
+
+         jobID = jobGraph.getJobID();
+
+         cluster.submitJobDetached(jobGraph);
+
+         Object savepointResponse = null;
+
+         // wait until the operator is started
+         StateSourceBase.workStartedLatch.await();
+
+         // retry triggering the savepoint until it succeeds or the deadline expires
+         while (deadline.hasTimeLeft()) {
+
+             Future<Object> savepointPathFuture = jobManager.ask(new JobManagerMessages.TriggerSavepoint(jobID), deadline.timeLeft());
+             FiniteDuration waitingTime = new FiniteDuration(10, TimeUnit.SECONDS);
+             savepointResponse = Await.result(savepointPathFuture, waitingTime);
+
+             if (savepointResponse instanceof JobManagerMessages.TriggerSavepointSuccess) {
+                 break;
+             }
+         }
+
+         assertTrue(savepointResponse instanceof JobManagerMessages.TriggerSavepointSuccess);
+
+         final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess)savepointResponse).savepointPath();
+
+         Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft());
+
+         Future<Object> cancellationResponseFuture = jobManager.ask(new JobManagerMessages.CancelJob(jobID), deadline.timeLeft());
+
+         Object cancellationResponse = Await.result(cancellationResponseFuture, deadline.timeLeft());
+
+         assertTrue(cancellationResponse instanceof JobManagerMessages.CancellationSuccess);
+
+         Await.ready(jobRemovedFuture, deadline.timeLeft());
+
+         // job successfully removed
+         jobID = null;
+
+         JobGraph scaledJobGraph = createJobGraphWithOperatorState(parallelism2, maxParallelism, true);
+
+         scaledJobGraph.setSavepointPath(savepointPath);
+
+         jobID = scaledJobGraph.getJobID();
+
+         cluster.submitJobAndWait(scaledJobGraph, false);
+
+         int sumExp = 0;
+         int sumAct = 0;
+
+         for (int c : PartitionedStateSource.CHECK_CORRECT_SNAPSHOT) {
+             sumExp += c;
+         }
+
+         for (int c : PartitionedStateSource.CHECK_CORRECT_RESTORE) {
+             sumAct += c;
+         }
+
+         // redistribution may move partitions between subtasks, but the total must be preserved
+         assertEquals(sumExp, sumAct);
+         jobID = null;
+
+     } finally {
+         // clear any left overs from a possibly failed job
+         if (jobID != null && jobManager != null) {
+             Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), timeout);
+
+             try {
+                 Await.ready(jobRemovedFuture, timeout);
+             } catch (TimeoutException te) {
+                 fail("Failed while cleaning up the cluster.");
+             } catch (InterruptedException ie) {
+                 // restore the interrupt flag so that the test runner still sees the interruption
+                 Thread.currentThread().interrupt();
+                 fail("Failed while cleaning up the cluster.");
+             }
+         }
+     }
+ }
+
+ //------------------------------------------------------------------------------------------------------------------
+
+ private static JobGraph createJobGraphWithOperatorState(
+ int parallelism, int maxParallelism, boolean partitionedOperatorState) {
+
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(parallelism);
env.getConfig().setMaxParallelism(maxParallelism);
- env.enableCheckpointing(checkpointInterval);
+ env.enableCheckpointing(Long.MAX_VALUE);
env.setRestartStrategy(RestartStrategies.noRestart());
- NonPartitionedStateSource.workStartedLatch = new CountDownLatch(1);
+ StateSourceBase.workStartedLatch = new CountDownLatch(1);
- DataStream<Integer> input = env.addSource(new NonPartitionedStateSource());
+ DataStream<Integer> input = env.addSource(
+ partitionedOperatorState ? new PartitionedStateSource() : new NonPartitionedStateSource());
input.addSink(new DiscardingSink<Integer>());
return env.getStreamGraph().getJobGraph();
}
- private static JobGraph createPartitionedStateJobGraph(
+ private static JobGraph createJobGraphWithKeyedState(
int parallelism,
int maxParallelism,
int numberKeys,
@@ -475,7 +606,7 @@ public class RescalingITCase extends TestLogger {
return env.getStreamGraph().getJobGraph();
}
- private static JobGraph createPartitionedNonPartitionedStateJobGraph(
+ private static JobGraph createJobGraphWithKeyedAndNonPartitionedOperatorState(
int parallelism,
int maxParallelism,
int fixedParallelism,
@@ -606,12 +737,13 @@ public class RescalingITCase extends TestLogger {
@Override
public void open(Configuration configuration) {
- counter = getRuntimeContext().getState(new ValueStateDescriptor<Integer>("counter", Integer.class, 0));
- sum = getRuntimeContext().getState(new ValueStateDescriptor<Integer>("sum", Integer.class, 0));
+ counter = getRuntimeContext().getState(new ValueStateDescriptor<>("counter", Integer.class, 0));
+ sum = getRuntimeContext().getState(new ValueStateDescriptor<>("sum", Integer.class, 0));
}
@Override
public void flatMap(Integer value, Collector<Tuple2<Integer, Integer>> out) throws Exception {
+
int count = counter.value() + 1;
counter.update(count);
@@ -645,14 +777,43 @@ public class RescalingITCase extends TestLogger {
}
}
- private static class NonPartitionedStateSource extends RichParallelSourceFunction<Integer> implements Checkpointed<Integer> {
-
- private static final long serialVersionUID = -8108185918123186841L;
+ private static class StateSourceBase extends RichParallelSourceFunction<Integer> {
private static volatile CountDownLatch workStartedLatch = new CountDownLatch(1);
- private volatile int counter = 0;
- private volatile boolean running = true;
+ protected volatile int counter = 0;
+ protected volatile boolean running = true;
+
+ /**
+  * Emits the value 1 roughly every two milliseconds while the source is running.
+  *
+  * <p>Each emission increments {@code counter} under the checkpoint lock. The shared
+  * {@code workStartedLatch} is counted down when the counter is exactly 10, and the loop ends
+  * once the counter has reached 500.
+  */
+ @Override
+ public void run(SourceContext<Integer> ctx) throws Exception {
+     final Object checkpointLock = ctx.getCheckpointLock();
+
+     while (running) {
+         // counter update and emission happen together under the checkpoint lock
+         synchronized (checkpointLock) {
+             counter++;
+             ctx.collect(1);
+         }
+
+         Thread.sleep(2);
+
+         // signal the test that the source has made some progress
+         if (counter == 10) {
+             workStartedLatch.countDown();
+         }
+
+         // enough elements emitted, finish the source
+         if (counter >= 500) {
+             return;
+         }
+     }
+ }
+
+ /**
+  * Clears the {@code running} flag that the emit loop in {@code run} checks on every iteration.
+  */
+ @Override
+ public void cancel() {
+     this.running = false;
+ }
+ }
+
+ private static class NonPartitionedStateSource extends StateSourceBase implements Checkpointed<Integer> {
+
+ private static final long serialVersionUID = -8108185918123186841L;
@Override
public Integer snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
@@ -663,28 +824,42 @@ public class RescalingITCase extends TestLogger {
public void restoreState(Integer state) throws Exception {
counter = state;
}
+ }
+
+ private static class PartitionedStateSource extends StateSourceBase implements ListCheckpointed<Integer> {
+
+ private static final long serialVersionUID = -359715965103593462L;
+ private static final int NUM_PARTITIONS = 7;
+
+ private static int[] CHECK_CORRECT_SNAPSHOT;
+ private static int[] CHECK_CORRECT_RESTORE;
@Override
- public void run(SourceContext<Integer> ctx) throws Exception {
- final Object lock = ctx.getCheckpointLock();
+ public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception {
- while (running) {
- synchronized (lock) {
- counter++;
+ CHECK_CORRECT_SNAPSHOT[getRuntimeContext().getIndexOfThisSubtask()] = counter;
- ctx.collect(counter * getRuntimeContext().getIndexOfThisSubtask());
- }
+ int div = counter / NUM_PARTITIONS;
+ int mod = counter % NUM_PARTITIONS;
- Thread.sleep(2);
- if(counter == 10) {
- workStartedLatch.countDown();
+ List<Integer> split = new ArrayList<>();
+ for (int i = 0; i < NUM_PARTITIONS; ++i) {
+ int partitionValue = div;
+ if (mod > 0) {
+ --mod;
+ ++partitionValue;
}
+ split.add(partitionValue);
}
+ return split;
}
@Override
- public void cancel() {
- running = false;
+ public void restoreState(List<Integer> state) throws Exception {
+ for (Integer v : state) {
+ counter += v;
+ }
+ CHECK_CORRECT_RESTORE[getRuntimeContext().getIndexOfThisSubtask()] = counter;
}
}
}