You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by al...@apache.org on 2017/01/13 20:37:29 UTC
[2/2] flink git commit: [FLINK-5265] Introduce state handle
replication mode for CheckpointCoordinator
[FLINK-5265] Introduce state handle replication mode for CheckpointCoordinator
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/1020ba2c
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/1020ba2c
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/1020ba2c
Branch: refs/heads/master
Commit: 1020ba2c9cfc1d01703e97c72e20a922bae0732d
Parents: 8b1b4a1
Author: Stefan Richter <s....@data-artisans.com>
Authored: Sat Dec 3 02:42:25 2016 +0100
Committer: Aljoscha Krettek <al...@gmail.com>
Committed: Fri Jan 13 21:29:19 2017 +0100
----------------------------------------------------------------------
.../api/common/state/OperatorStateStore.java | 37 +++-
.../RoundRobinOperatorStateRepartitioner.java | 133 ++++++++++++---
.../checkpoint/StateAssignmentOperation.java | 15 +-
.../savepoint/SavepointV1Serializer.java | 24 ++-
.../state/DefaultOperatorStateBackend.java | 160 +++++++++++------
.../OperatorBackendSerializationProxy.java | 51 +++++-
.../OperatorStateCheckpointOutputStream.java | 10 +-
.../runtime/state/OperatorStateHandle.java | 86 +++++++++-
.../state/StateInitializationContextImpl.java | 18 +-
.../checkpoint/CheckpointCoordinatorTest.java | 119 +++++++++----
.../checkpoint/savepoint/SavepointV1Test.java | 7 +-
.../runtime/state/OperatorStateBackendTest.java | 54 +++++-
.../runtime/state/OperatorStateHandleTest.java | 39 +++++
...OperatorStateOutputCheckpointStreamTest.java | 11 +-
.../runtime/state/SerializationProxiesTest.java | 63 ++++++-
.../StateInitializationContextImplTest.java | 6 +-
.../tasks/InterruptSensitiveRestoreTest.java | 6 +-
.../test/checkpointing/RescalingITCase.java | 171 +++++++++++--------
18 files changed, 787 insertions(+), 223 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
index 43dbe51..87a7759 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
@@ -30,8 +30,22 @@ import java.util.Set;
public interface OperatorStateStore {
/**
- * Creates a state descriptor of the given name that uses Java serialization to persist the
- * state.
+ * Creates (or restores) a list state. Each state is registered under a unique name.
+ * The provided serializer is used to de/serialize the state in case of checkpointing (snapshot/restore).
+ *
+ * The items in the list are repartitionable by the system in case of changed operator parallelism.
+ *
+ * @param stateDescriptor The descriptor for this state, providing a name and serializer.
+ * @param <S> The generic type of the state
+ *
+ * @return A list for all state partitions.
+ * @throws Exception
+ */
+ <S> ListState<S> getOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception;
+
+ /**
+ * Creates a state of the given name that uses Java serialization to persist the state. The items in the list
+ * are repartitionable by the system in case of changed operator parallelism.
*
* <p>This is a simple convenience method. For more flexibility on how state serialization
* should happen, use the {@link #getOperatorState(ListStateDescriptor)} method.
@@ -46,13 +60,28 @@ public interface OperatorStateStore {
* Creates (or restores) a list state. Each state is registered under a unique name.
* The provided serializer is used to de/serialize the state in case of checkpointing (snapshot/restore).
*
+ * On restore, all items in the list are broadcasted to all parallel operator instances.
+ *
* @param stateDescriptor The descriptor for this state, providing a name and serializer.
* @param <S> The generic type of the state
- *
+ *
* @return A list for all state partitions.
* @throws Exception
*/
- <S> ListState<S> getOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception;
+ <S> ListState<S> getBroadcastOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception;
+
+ /**
+ * Creates a state of the given name that uses Java serialization to persist the state. On restore, all items
+ * in the list are broadcasted to all parallel operator instances.
+ *
+ * <p>This is a simple convenience method. For more flexibility on how state serialization
+ * should happen, use the {@link #getBroadcastOperatorState(ListStateDescriptor)} method.
+ *
+ * @param stateName The name of state to create
+ * @return A list state using Java serialization to serialize state objects.
+ * @throws Exception
+ */
+ <T extends Serializable> ListState<T> getBroadcastSerializableListState(String stateName) throws Exception;
/**
* Returns a set with the names of all currently registered states.
http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
index 16a7e27..046096f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
@@ -26,6 +26,7 @@ import org.apache.flink.util.Preconditions;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
+import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -47,8 +48,7 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart
Preconditions.checkArgument(parallelism > 0);
// Reorganize: group by (State Name -> StreamStateHandle + Offsets)
- Map<String, List<Tuple2<StreamStateHandle, long[]>>> nameToState =
- groupByStateName(previousParallelSubtaskStates);
+ GroupByStateNameResults nameToStateByMode = groupByStateName(previousParallelSubtaskStates);
if (OPTIMIZE_MEMORY_USE) {
previousParallelSubtaskStates.clear(); // free for GC at the cost that old handles are no longer available
@@ -59,7 +59,7 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart
// Do the actual repartitioning for all named states
List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList =
- repartition(nameToState, parallelism);
+ repartition(nameToStateByMode, parallelism);
for (int i = 0; i < mergeMapList.size(); ++i) {
result.add(i, new ArrayList<>(mergeMapList.get(i).values()));
@@ -71,16 +71,33 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart
/**
* Group by the different named states.
*/
- private Map<String, List<Tuple2<StreamStateHandle, long[]>>> groupByStateName(
+ @SuppressWarnings("unchecked, rawtype")
+ private GroupByStateNameResults groupByStateName(
List<OperatorStateHandle> previousParallelSubtaskStates) {
- //Reorganize: group by (State Name -> StreamStateHandle + Offsets)
- Map<String, List<Tuple2<StreamStateHandle, long[]>>> nameToState = new HashMap<>();
+ //Reorganize: group by (State Name -> StreamStateHandle + StateMetaInfo)
+ EnumMap<OperatorStateHandle.Mode,
+ Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>> nameToStateByMode =
+ new EnumMap<>(OperatorStateHandle.Mode.class);
+
+ for (OperatorStateHandle.Mode mode : OperatorStateHandle.Mode.values()) {
+ Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> map = new HashMap<>();
+ nameToStateByMode.put(
+ mode,
+ new HashMap<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>());
+ }
+
for (OperatorStateHandle psh : previousParallelSubtaskStates) {
- for (Map.Entry<String, long[]> e : psh.getStateNameToPartitionOffsets().entrySet()) {
+ for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> e :
+ psh.getStateNameToPartitionOffsets().entrySet()) {
+ OperatorStateHandle.StateMetaInfo metaInfo = e.getValue();
+
+ Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> nameToState =
+ nameToStateByMode.get(metaInfo.getDistributionMode());
- List<Tuple2<StreamStateHandle, long[]>> stateLocations = nameToState.get(e.getKey());
+ List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> stateLocations =
+ nameToState.get(e.getKey());
if (stateLocations == null) {
stateLocations = new ArrayList<>();
@@ -90,32 +107,40 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart
stateLocations.add(new Tuple2<>(psh.getDelegateStateHandle(), e.getValue()));
}
}
- return nameToState;
+
+ return new GroupByStateNameResults(nameToStateByMode);
}
/**
* Repartition all named states.
*/
private List<Map<StreamStateHandle, OperatorStateHandle>> repartition(
- Map<String, List<Tuple2<StreamStateHandle, long[]>>> nameToState, int parallelism) {
+ GroupByStateNameResults nameToStateByMode,
+ int parallelism) {
// We will use this to merge w.r.t. StreamStateHandles for each parallel subtask inside the maps
List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList = new ArrayList<>(parallelism);
+
// Initialize
for (int i = 0; i < parallelism; ++i) {
mergeMapList.add(new HashMap<StreamStateHandle, OperatorStateHandle>());
}
- int startParallelOP = 0;
+ // Start with the state handles we distribute round robin by splitting by offsets
+ Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> distributeNameToState =
+ nameToStateByMode.getByMode(OperatorStateHandle.Mode.SPLIT_DISTRIBUTE);
+
+ int startParallelOp = 0;
// Iterate all named states and repartition one named state at a time per iteration
- for (Map.Entry<String, List<Tuple2<StreamStateHandle, long[]>>> e : nameToState.entrySet()) {
+ for (Map.Entry<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> e :
+ distributeNameToState.entrySet()) {
- List<Tuple2<StreamStateHandle, long[]>> current = e.getValue();
+ List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> current = e.getValue();
// Determine actual number of partitions for this named state
int totalPartitions = 0;
- for (Tuple2<StreamStateHandle, long[]> offsets : current) {
- totalPartitions += offsets.f1.length;
+ for (Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> offsets : current) {
+ totalPartitions += offsets.f1.getOffsets().length;
}
// Repartition the state across the parallel operator instances
@@ -124,12 +149,12 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart
int baseFraction = totalPartitions / parallelism;
int remainder = totalPartitions % parallelism;
- int newStartParallelOp = startParallelOP;
+ int newStartParallelOp = startParallelOp;
for (int i = 0; i < parallelism; ++i) {
// Preparation: calculate the actual index considering wrap around
- int parallelOpIdx = (i + startParallelOP) % parallelism;
+ int parallelOpIdx = (i + startParallelOp) % parallelism;
// Now calculate the number of partitions we will assign to the parallel instance in this round ...
int numberOfPartitionsToAssign = baseFraction;
@@ -146,11 +171,14 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart
}
// Now start collecting the partitions for the parallel instance into this list
- List<Tuple2<StreamStateHandle, long[]>> parallelOperatorState = new ArrayList<>();
+ List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> parallelOperatorState =
+ new ArrayList<>();
while (numberOfPartitionsToAssign > 0) {
- Tuple2<StreamStateHandle, long[]> handleWithOffsets = current.get(lstIdx);
- long[] offsets = handleWithOffsets.f1;
+ Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithOffsets =
+ current.get(lstIdx);
+
+ long[] offsets = handleWithOffsets.f1.getOffsets();
int remaining = offsets.length - offsetIdx;
// Repartition offsets
long[] offs;
@@ -166,25 +194,74 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart
++lstIdx;
}
- parallelOperatorState.add(
- new Tuple2<>(handleWithOffsets.f0, offs));
+ parallelOperatorState.add(new Tuple2<>(
+ handleWithOffsets.f0,
+ new OperatorStateHandle.StateMetaInfo(offs, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)));
numberOfPartitionsToAssign -= remaining;
// As a last step we merge partitions that use the same StreamStateHandle in a single
// OperatorStateHandle
Map<StreamStateHandle, OperatorStateHandle> mergeMap = mergeMapList.get(parallelOpIdx);
- OperatorStateHandle psh = mergeMap.get(handleWithOffsets.f0);
- if (psh == null) {
- psh = new OperatorStateHandle(new HashMap<String, long[]>(), handleWithOffsets.f0);
- mergeMap.put(handleWithOffsets.f0, psh);
+ OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithOffsets.f0);
+ if (operatorStateHandle == null) {
+ operatorStateHandle = new OperatorStateHandle(
+ new HashMap<String, OperatorStateHandle.StateMetaInfo>(),
+ handleWithOffsets.f0);
+
+ mergeMap.put(handleWithOffsets.f0, operatorStateHandle);
}
- psh.getStateNameToPartitionOffsets().put(e.getKey(), offs);
+ operatorStateHandle.getStateNameToPartitionOffsets().put(
+ e.getKey(),
+ new OperatorStateHandle.StateMetaInfo(offs, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
}
}
- startParallelOP = newStartParallelOp;
+ startParallelOp = newStartParallelOp;
e.setValue(null);
}
+
+ // Now we also add the state handles marked for broadcast to all parallel instances
+ Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> broadcastNameToState =
+ nameToStateByMode.getByMode(OperatorStateHandle.Mode.BROADCAST);
+
+ for (int i = 0; i < parallelism; ++i) {
+
+ Map<StreamStateHandle, OperatorStateHandle> mergeMap = mergeMapList.get(i);
+
+ for (Map.Entry<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> e :
+ broadcastNameToState.entrySet()) {
+
+ List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> current = e.getValue();
+
+ for (Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithMetaInfo : current) {
+ OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithMetaInfo.f0);
+ if (operatorStateHandle == null) {
+ operatorStateHandle = new OperatorStateHandle(
+ new HashMap<String, OperatorStateHandle.StateMetaInfo>(),
+ handleWithMetaInfo.f0);
+
+ mergeMap.put(handleWithMetaInfo.f0, operatorStateHandle);
+ }
+ operatorStateHandle.getStateNameToPartitionOffsets().put(e.getKey(), handleWithMetaInfo.f1);
+ }
+ }
+ }
return mergeMapList;
}
+
+ private static final class GroupByStateNameResults {
+ private final EnumMap<OperatorStateHandle.Mode,
+ Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>> byMode;
+
+ public GroupByStateNameResults(
+ EnumMap<OperatorStateHandle.Mode,
+ Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>> byMode) {
+ this.byMode = Preconditions.checkNotNull(byMode);
+ }
+
+ public Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> getByMode(
+ OperatorStateHandle.Mode mode) {
+ return byMode.get(mode);
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index 2e05a85..f11f69b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -338,9 +338,22 @@ public class StateAssignmentOperation {
chainOpParallelStates,
newParallelism);
} else {
-
List<Collection<OperatorStateHandle>> repackStream = new ArrayList<>(newParallelism);
for (OperatorStateHandle operatorStateHandle : chainOpParallelStates) {
+
+ Map<String, OperatorStateHandle.StateMetaInfo> partitionOffsets =
+ operatorStateHandle.getStateNameToPartitionOffsets();
+
+ for (OperatorStateHandle.StateMetaInfo metaInfo : partitionOffsets.values()) {
+
+ // if we find any broadcast state, we cannot take the shortcut and need to go through repartitioning
+ if (OperatorStateHandle.Mode.BROADCAST.equals(metaInfo.getDistributionMode())) {
+ return opStateRepartitioner.repartitionState(
+ chainOpParallelStates,
+ newParallelism);
+ }
+ }
+
repackStream.add(Collections.singletonList(operatorStateHandle));
}
return repackStream;
http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
index 48324ca..ba1949a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
@@ -250,11 +250,18 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
if (stateHandle != null) {
dos.writeByte(PARTITIONABLE_OPERATOR_STATE_HANDLE);
- Map<String, long[]> partitionOffsetsMap = stateHandle.getStateNameToPartitionOffsets();
+ Map<String, OperatorStateHandle.StateMetaInfo> partitionOffsetsMap =
+ stateHandle.getStateNameToPartitionOffsets();
dos.writeInt(partitionOffsetsMap.size());
- for (Map.Entry<String, long[]> entry : partitionOffsetsMap.entrySet()) {
+ for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry : partitionOffsetsMap.entrySet()) {
dos.writeUTF(entry.getKey());
- long[] offsets = entry.getValue();
+
+ OperatorStateHandle.StateMetaInfo stateMetaInfo = entry.getValue();
+
+ int mode = stateMetaInfo.getDistributionMode().ordinal();
+ dos.writeByte(mode);
+
+ long[] offsets = stateMetaInfo.getOffsets();
dos.writeInt(offsets.length);
for (long offset : offsets) {
dos.writeLong(offset);
@@ -274,14 +281,21 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
return null;
} else if (PARTITIONABLE_OPERATOR_STATE_HANDLE == type) {
int mapSize = dis.readInt();
- Map<String, long[]> offsetsMap = new HashMap<>(mapSize);
+ Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>(mapSize);
for (int i = 0; i < mapSize; ++i) {
String key = dis.readUTF();
+
+ int modeOrdinal = dis.readByte();
+ OperatorStateHandle.Mode mode = OperatorStateHandle.Mode.values()[modeOrdinal];
+
long[] offsets = new long[dis.readInt()];
for (int j = 0; j < offsets.length; ++j) {
offsets[j] = dis.readLong();
}
- offsetsMap.put(key, offsets);
+
+ OperatorStateHandle.StateMetaInfo metaInfo =
+ new OperatorStateHandle.StateMetaInfo(offsets, mode);
+ offsetsMap.put(key, metaInfo);
}
StreamStateHandle stateHandle = deserializeStreamStateHandle(dis);
return new OperatorStateHandle(offsetsMap, stateHandle);
http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
index 10bb409..6c65088 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
@@ -19,6 +19,7 @@
package org.apache.flink.runtime.state;
import org.apache.commons.io.IOUtils;
+import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
@@ -44,6 +45,7 @@ import java.util.concurrent.RunnableFuture;
/**
* Default implementation of OperatorStateStore that provides the ability to make snapshots.
*/
+@Internal
public class DefaultOperatorStateBackend implements OperatorStateBackend {
/** The default namespace for state in cases where no state name is provided */
@@ -62,14 +64,46 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
this.registeredStates = new HashMap<>();
}
+ @Override
+ public Set<String> getRegisteredStateNames() {
+ return registeredStates.keySet();
+ }
+
+ @Override
+ public void close() throws IOException {
+ closeStreamOnCancelRegistry.close();
+ }
+
+ @Override
+ public void dispose() {
+ registeredStates.clear();
+ }
+
@SuppressWarnings("unchecked")
@Override
public <T extends Serializable> ListState<T> getSerializableListState(String stateName) throws Exception {
return (ListState<T>) getOperatorState(new ListStateDescriptor<>(stateName, javaSerializer));
}
-
+
@Override
public <S> ListState<S> getOperatorState(ListStateDescriptor<S> stateDescriptor) throws IOException {
+ return getOperatorState(stateDescriptor, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE);
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public <T extends Serializable> ListState<T> getBroadcastSerializableListState(String stateName) throws Exception {
+ return (ListState<T>) getBroadcastOperatorState(new ListStateDescriptor<>(stateName, javaSerializer));
+ }
+
+ @Override
+ public <S> ListState<S> getBroadcastOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception {
+ return getOperatorState(stateDescriptor, OperatorStateHandle.Mode.BROADCAST);
+ }
+
+ private <S> ListState<S> getOperatorState(
+ ListStateDescriptor<S> stateDescriptor,
+ OperatorStateHandle.Mode mode) throws IOException {
Preconditions.checkNotNull(stateDescriptor);
@@ -81,10 +115,18 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
if (null == partitionableListState) {
- partitionableListState = new PartitionableListState<>(name, partitionStateSerializer);
+ partitionableListState = new PartitionableListState<>(
+ name,
+ partitionStateSerializer,
+ mode);
+
registeredStates.put(name, partitionableListState);
} else {
Preconditions.checkState(
+ partitionableListState.getAssignmentMode().equals(mode),
+ "Incompatible assignment mode. Provided: " + mode + ", expected: " +
+ partitionableListState.getAssignmentMode());
+ Preconditions.checkState(
partitionableListState.getPartitionStateSerializer().
isCompatibleWith(stateDescriptor.getSerializer()),
"Incompatible type serializers. Provided: " + stateDescriptor.getSerializer() +
@@ -97,16 +139,21 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
private static <S> void deserializeStateValues(
PartitionableListState<S> stateListForName,
FSDataInputStream in,
- long[] offsets) throws IOException {
-
- DataInputView div = new DataInputViewStreamWrapper(in);
- TypeSerializer<S> serializer = stateListForName.getPartitionStateSerializer();
- for (long offset : offsets) {
- in.seek(offset);
- stateListForName.add(serializer.deserialize(div));
+ OperatorStateHandle.StateMetaInfo metaInfo) throws IOException {
+
+ if (null != metaInfo) {
+ long[] offsets = metaInfo.getOffsets();
+ if (null != offsets) {
+ DataInputView div = new DataInputViewStreamWrapper(in);
+ TypeSerializer<S> serializer = stateListForName.getPartitionStateSerializer();
+ for (long offset : offsets) {
+ in.seek(offset);
+ stateListForName.add(serializer.deserialize(div));
+ }
+ }
}
}
-
+
@Override
public RunnableFuture<OperatorStateHandle> snapshot(
long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception {
@@ -123,11 +170,12 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
OperatorBackendSerializationProxy.StateMetaInfo<?> metaInfo =
new OperatorBackendSerializationProxy.StateMetaInfo<>(
state.getName(),
- state.getPartitionStateSerializer());
+ state.getPartitionStateSerializer(),
+ state.getAssignmentMode());
metaInfoList.add(metaInfo);
}
- Map<String, long[]> writtenStatesMetaData = new HashMap<>(registeredStates.size());
+ Map<String, OperatorStateHandle.StateMetaInfo> writtenStatesMetaData = new HashMap<>(registeredStates.size());
CheckpointStreamFactory.CheckpointStateOutputStream out = streamFactory.
createCheckpointStateOutputStream(checkpointId, timestamp);
@@ -145,8 +193,10 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
dov.writeInt(registeredStates.size());
for (Map.Entry<String, PartitionableListState<?>> entry : registeredStates.entrySet()) {
- long[] partitionOffsets = entry.getValue().write(out);
- writtenStatesMetaData.put(entry.getKey(), partitionOffsets);
+ PartitionableListState<?> value = entry.getValue();
+ long[] partitionOffsets = value.write(out);
+ OperatorStateHandle.Mode mode = value.getAssignmentMode();
+ writtenStatesMetaData.put(entry.getKey(), new OperatorStateHandle.StateMetaInfo(partitionOffsets, mode));
}
OperatorStateHandle handle = new OperatorStateHandle(writtenStatesMetaData, out.closeAndGetHandle());
@@ -193,7 +243,8 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
if (null == listState) {
listState = new PartitionableListState<>(
stateMetaInfo.getName(),
- stateMetaInfo.getStateSerializer());
+ stateMetaInfo.getStateSerializer(),
+ stateMetaInfo.getMode());
registeredStates.put(listState.getName(), listState);
} else {
@@ -205,7 +256,9 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
}
// Restore all the state in PartitionableListStates
- for (Map.Entry<String, long[]> nameToOffsets : stateHandle.getStateNameToPartitionOffsets().entrySet()) {
+ for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> nameToOffsets :
+ stateHandle.getStateNameToPartitionOffsets().entrySet()) {
+
PartitionableListState<?> stateListForName = registeredStates.get(nameToOffsets.getKey());
Preconditions.checkState(null != stateListForName, "Found state without " +
@@ -222,60 +275,40 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
}
}
- @Override
- public void dispose() {
- registeredStates.clear();
- }
-
- @Override
- public Set<String> getRegisteredStateNames() {
- return registeredStates.keySet();
- }
-
- @Override
- public void close() throws IOException {
- closeStreamOnCancelRegistry.close();
- }
-
static final class PartitionableListState<S> implements ListState<S> {
- private final List<S> internalList;
private final String name;
private final TypeSerializer<S> partitionStateSerializer;
+ private final OperatorStateHandle.Mode assignmentMode;
+ private final List<S> internalList;
- public PartitionableListState(String name, TypeSerializer<S> partitionStateSerializer) {
- this.internalList = new ArrayList<>();
- this.partitionStateSerializer = Preconditions.checkNotNull(partitionStateSerializer);
- this.name = Preconditions.checkNotNull(name);
- }
-
- public long[] write(FSDataOutputStream out) throws IOException {
-
- long[] partitionOffsets = new long[internalList.size()];
-
- DataOutputView dov = new DataOutputViewStreamWrapper(out);
-
- for (int i = 0; i < internalList.size(); ++i) {
- S element = internalList.get(i);
- partitionOffsets[i] = out.getPos();
- partitionStateSerializer.serialize(element, dov);
- }
-
- return partitionOffsets;
- }
+ public PartitionableListState(
+ String name,
+ TypeSerializer<S> partitionStateSerializer,
+ OperatorStateHandle.Mode assignmentMode) {
- public List<S> getInternalList() {
- return internalList;
+ this.name = Preconditions.checkNotNull(name);
+ this.partitionStateSerializer = Preconditions.checkNotNull(partitionStateSerializer);
+ this.assignmentMode = Preconditions.checkNotNull(assignmentMode);
+ this.internalList = new ArrayList<>();
}
public String getName() {
return name;
}
+ public OperatorStateHandle.Mode getAssignmentMode() {
+ return assignmentMode;
+ }
+
public TypeSerializer<S> getPartitionStateSerializer() {
return partitionStateSerializer;
}
+ public List<S> getInternalList() {
+ return internalList;
+ }
+
@Override
public void clear() {
internalList.clear();
@@ -294,8 +327,25 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
@Override
public String toString() {
return "PartitionableListState{" +
- "listState=" + internalList +
+ "name='" + name + '\'' +
+ ", assignmentMode=" + assignmentMode +
+ ", internalList=" + internalList +
'}';
}
+
+ public long[] write(FSDataOutputStream out) throws IOException {
+
+ long[] partitionOffsets = new long[internalList.size()];
+
+ DataOutputView dov = new DataOutputViewStreamWrapper(out);
+
+ for (int i = 0; i < internalList.size(); ++i) {
+ S element = internalList.get(i);
+ partitionOffsets[i] = out.getPos();
+ partitionStateSerializer.serialize(element, dov);
+ }
+
+ return partitionOffsets;
+ }
}
}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendSerializationProxy.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendSerializationProxy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendSerializationProxy.java
index 61df979..d571dcc 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendSerializationProxy.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendSerializationProxy.java
@@ -18,6 +18,7 @@
package org.apache.flink.runtime.state;
+import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.typeutils.runtime.DataInputViewStream;
import org.apache.flink.api.java.typeutils.runtime.DataOutputViewStream;
@@ -91,15 +92,19 @@ public class OperatorBackendSerializationProxy extends VersionedIOReadableWritab
private String name;
private TypeSerializer<S> stateSerializer;
+ private OperatorStateHandle.Mode mode;
+
private ClassLoader userClassLoader;
- private StateMetaInfo(ClassLoader userClassLoader) {
+ @VisibleForTesting
+ public StateMetaInfo(ClassLoader userClassLoader) {
this.userClassLoader = Preconditions.checkNotNull(userClassLoader);
}
- public StateMetaInfo(String name, TypeSerializer<S> stateSerializer) {
+ public StateMetaInfo(String name, TypeSerializer<S> stateSerializer, OperatorStateHandle.Mode mode) {
this.name = Preconditions.checkNotNull(name);
this.stateSerializer = Preconditions.checkNotNull(stateSerializer);
+ this.mode = Preconditions.checkNotNull(mode);
}
public String getName() {
@@ -118,9 +123,18 @@ public class OperatorBackendSerializationProxy extends VersionedIOReadableWritab
this.stateSerializer = stateSerializer;
}
+ public OperatorStateHandle.Mode getMode() {
+ return mode;
+ }
+
+ public void setMode(OperatorStateHandle.Mode mode) {
+ this.mode = mode;
+ }
+
@Override
public void write(DataOutputView out) throws IOException {
out.writeUTF(getName());
+ out.writeByte(getMode().ordinal());
DataOutputViewStream dos = new DataOutputViewStream(out);
InstantiationUtil.serializeObject(dos, getStateSerializer());
}
@@ -128,6 +142,7 @@ public class OperatorBackendSerializationProxy extends VersionedIOReadableWritab
@Override
public void read(DataInputView in) throws IOException {
setName(in.readUTF());
+ setMode(OperatorStateHandle.Mode.values()[in.readByte()]);
DataInputViewStream dis = new DataInputViewStream(in);
try {
TypeSerializer<S> stateSerializer = InstantiationUtil.deserializeObject(dis, userClassLoader);
@@ -136,5 +151,37 @@ public class OperatorBackendSerializationProxy extends VersionedIOReadableWritab
throw new IOException(exception);
}
}
+
+ @Override
+ public boolean equals(Object o) {
+
+ if (this == o) {
+ return true;
+ }
+
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ StateMetaInfo<?> metaInfo = (StateMetaInfo<?>) o;
+
+ if (!getName().equals(metaInfo.getName())) {
+ return false;
+ }
+
+ if (!getStateSerializer().equals(metaInfo.getStateSerializer())) {
+ return false;
+ }
+
+ return getMode() == metaInfo.getMode();
+ }
+
+ @Override
+ public int hashCode() {
+ int result = getName().hashCode();
+ result = 31 * result + getStateSerializer().hashCode();
+ result = 31 * result + getMode().hashCode();
+ return result;
+ }
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java
index eaa9fd9..036aed0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java
@@ -66,8 +66,14 @@ public final class OperatorStateCheckpointOutputStream
startNewPartition();
}
- Map<String, long[]> offsetsMap = new HashMap<>(1);
- offsetsMap.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, partitionOffsets.toArray());
+ Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>(1);
+
+ OperatorStateHandle.StateMetaInfo metaInfo =
+ new OperatorStateHandle.StateMetaInfo(
+ partitionOffsets.toArray(),
+ OperatorStateHandle.Mode.SPLIT_DISTRIBUTE);
+
+ offsetsMap.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, metaInfo);
return new OperatorStateHandle(offsetsMap, streamStateHandle);
}
http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java
index 3cd37c9..c59fbad 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java
@@ -22,6 +22,7 @@ import org.apache.flink.core.fs.FSDataInputStream;
import org.apache.flink.util.Preconditions;
import java.io.IOException;
+import java.io.Serializable;
import java.util.Arrays;
import java.util.Map;
@@ -31,21 +32,27 @@ import java.util.Map;
*/
public class OperatorStateHandle implements StreamStateHandle {
+ public enum Mode {
+ SPLIT_DISTRIBUTE, BROADCAST
+ }
+
private static final long serialVersionUID = 35876522969227335L;
- /** unique state name -> offsets for available partitions in the handle stream */
- private final Map<String, long[]> stateNameToPartitionOffsets;
+ /**
+ * unique state name -> offsets for available partitions in the handle stream
+ */
+ private final Map<String, StateMetaInfo> stateNameToPartitionOffsets;
private final StreamStateHandle delegateStateHandle;
public OperatorStateHandle(
- Map<String, long[]> stateNameToPartitionOffsets,
+ Map<String, StateMetaInfo> stateNameToPartitionOffsets,
StreamStateHandle delegateStateHandle) {
this.delegateStateHandle = Preconditions.checkNotNull(delegateStateHandle);
this.stateNameToPartitionOffsets = Preconditions.checkNotNull(stateNameToPartitionOffsets);
}
- public Map<String, long[]> getStateNameToPartitionOffsets() {
+ public Map<String, StateMetaInfo> getStateNameToPartitionOffsets() {
return stateNameToPartitionOffsets;
}
@@ -80,12 +87,12 @@ public class OperatorStateHandle implements StreamStateHandle {
OperatorStateHandle that = (OperatorStateHandle) o;
- if(stateNameToPartitionOffsets.size() != that.stateNameToPartitionOffsets.size()) {
+ if (stateNameToPartitionOffsets.size() != that.stateNameToPartitionOffsets.size()) {
return false;
}
- for (Map.Entry<String, long[]> entry : stateNameToPartitionOffsets.entrySet()) {
- if (!Arrays.equals(entry.getValue(), that.stateNameToPartitionOffsets.get(entry.getKey()))) {
+ for (Map.Entry<String, StateMetaInfo> entry : stateNameToPartitionOffsets.entrySet()) {
+ if (!entry.getValue().equals(that.stateNameToPartitionOffsets.get(entry.getKey()))) {
return false;
}
}
@@ -96,14 +103,75 @@ public class OperatorStateHandle implements StreamStateHandle {
@Override
public int hashCode() {
int result = delegateStateHandle.hashCode();
- for (Map.Entry<String, long[]> entry : stateNameToPartitionOffsets.entrySet()) {
+ for (Map.Entry<String, StateMetaInfo> entry : stateNameToPartitionOffsets.entrySet()) {
int entryHash = entry.getKey().hashCode();
if (entry.getValue() != null) {
- entryHash += Arrays.hashCode(entry.getValue());
+ entryHash += entry.getValue().hashCode();
}
result = 31 * result + entryHash;
}
return result;
}
+
+ @Override
+ public String toString() {
+ return "OperatorStateHandle{" +
+ "stateNameToPartitionOffsets=" + stateNameToPartitionOffsets +
+ ", delegateStateHandle=" + delegateStateHandle +
+ '}';
+ }
+
+ public static class StateMetaInfo implements Serializable {
+
+ private static final long serialVersionUID = 3593817615858941166L;
+
+ private final long[] offsets;
+ private final Mode distributionMode;
+
+ public StateMetaInfo(long[] offsets, Mode distributionMode) {
+ this.offsets = Preconditions.checkNotNull(offsets);
+ this.distributionMode = Preconditions.checkNotNull(distributionMode);
+ }
+
+ public long[] getOffsets() {
+ return offsets;
+ }
+
+ public Mode getDistributionMode() {
+ return distributionMode;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ StateMetaInfo that = (StateMetaInfo) o;
+
+ if (!Arrays.equals(getOffsets(), that.getOffsets())) {
+ return false;
+ }
+ return getDistributionMode() == that.getDistributionMode();
+ }
+
+ @Override
+ public int hashCode() {
+ int result = Arrays.hashCode(getOffsets());
+ result = 31 * result + getDistributionMode().hashCode();
+ return result;
+ }
+
+ @Override
+ public String toString() {
+ return "StateMetaInfo{" +
+ "offsets=" + Arrays.toString(offsets) +
+ ", distributionMode=" + distributionMode +
+ '}';
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
index 46445d2..886d214 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
@@ -220,13 +220,21 @@ public class StateInitializationContextImpl implements StateInitializationContex
while (stateHandleIterator.hasNext()) {
currentStateHandle = stateHandleIterator.next();
- long[] offsets = currentStateHandle.getStateNameToPartitionOffsets().get(stateName);
- if (null != offsets && offsets.length > 0) {
+ OperatorStateHandle.StateMetaInfo metaInfo =
+ currentStateHandle.getStateNameToPartitionOffsets().get(stateName);
- this.offsets = offsets;
- this.offPos = 0;
+ if (null != metaInfo) {
+ long[] metaOffsets = metaInfo.getOffsets();
+ if (null != metaOffsets && metaOffsets.length > 0) {
+ this.offsets = metaOffsets;
+ this.offPos = 0;
- return true;
+ closableRegistry.unregisterClosable(currentStream);
+ IOUtils.closeQuietly(currentStream);
+ currentStream = null;
+
+ return true;
+ }
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index daacbfb..ca9dbc2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -941,7 +941,7 @@ public class CheckpointCoordinatorTest {
}
@Test
- public void handleMessagesForNonExistingCheckpoints() {
+ public void testHandleMessagesForNonExistingCheckpoints() {
try {
final JobID jid = new JobID();
final long timestamp = System.currentTimeMillis();
@@ -1937,8 +1937,8 @@ public class CheckpointCoordinatorTest {
coord.restoreLatestCheckpointedState(tasks, true, false);
// verify the restored state
- verifiyStateRestore(jobVertexID1, jobVertex1, keyGroupPartitions1);
- verifiyStateRestore(jobVertexID2, jobVertex2, keyGroupPartitions2);
+ verifyStateRestore(jobVertexID1, jobVertex1, keyGroupPartitions1);
+ verifyStateRestore(jobVertexID2, jobVertex2, keyGroupPartitions2);
}
/**
@@ -2318,7 +2318,7 @@ public class CheckpointCoordinatorTest {
coord.restoreLatestCheckpointedState(tasks, true, false);
// verify the restored state
- verifiyStateRestore(jobVertexID1, newJobVertex1, keyGroupPartitions1);
+ verifyStateRestore(jobVertexID1, newJobVertex1, keyGroupPartitions1);
List<List<Collection<OperatorStateHandle>>> actualOpStatesBackend = new ArrayList<>(newJobVertex2.getParallelism());
List<List<Collection<OperatorStateHandle>>> actualOpStatesRaw = new ArrayList<>(newJobVertex2.getParallelism());
for (int i = 0; i < newJobVertex2.getParallelism(); i++) {
@@ -2390,6 +2390,49 @@ public class CheckpointCoordinatorTest {
}
}
+ @Test
+ public void testReplicateModeStateHandle() {
+ Map<String, OperatorStateHandle.StateMetaInfo> metaInfoMap = new HashMap<>(1);
+ metaInfoMap.put("t-1", new OperatorStateHandle.StateMetaInfo(new long[]{0, 23}, OperatorStateHandle.Mode.BROADCAST));
+ metaInfoMap.put("t-2", new OperatorStateHandle.StateMetaInfo(new long[]{42, 64}, OperatorStateHandle.Mode.BROADCAST));
+ metaInfoMap.put("t-3", new OperatorStateHandle.StateMetaInfo(new long[]{72, 83}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
+ OperatorStateHandle osh = new OperatorStateHandle(metaInfoMap, new ByteStreamStateHandle("test", new byte[100]));
+
+ OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
+ List<Collection<OperatorStateHandle>> repartitionedStates =
+ repartitioner.repartitionState(Collections.singletonList(osh), 3);
+
+ Map<String, Integer> checkCounts = new HashMap<>(3);
+
+ for (Collection<OperatorStateHandle> operatorStateHandles : repartitionedStates) {
+ for (OperatorStateHandle operatorStateHandle : operatorStateHandles) {
+ for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> stateNameToMetaInfo :
+ operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
+
+ String stateName = stateNameToMetaInfo.getKey();
+ Integer count = checkCounts.get(stateName);
+ if (null == count) {
+ checkCounts.put(stateName, 1);
+ } else {
+ checkCounts.put(stateName, 1 + count);
+ }
+
+ OperatorStateHandle.StateMetaInfo stateMetaInfo = stateNameToMetaInfo.getValue();
+ if (OperatorStateHandle.Mode.SPLIT_DISTRIBUTE.equals(stateMetaInfo.getDistributionMode())) {
+ Assert.assertEquals(1, stateNameToMetaInfo.getValue().getOffsets().length);
+ } else {
+ Assert.assertEquals(2, stateNameToMetaInfo.getValue().getOffsets().length);
+ }
+ }
+ }
+ }
+
+ Assert.assertEquals(3, checkCounts.size());
+ Assert.assertEquals(3, checkCounts.get("t-1").intValue());
+ Assert.assertEquals(3, checkCounts.get("t-2").intValue());
+ Assert.assertEquals(2, checkCounts.get("t-3").intValue());
+ }
+
// ------------------------------------------------------------------------
// Utilities
// ------------------------------------------------------------------------
@@ -2520,11 +2563,15 @@ public class CheckpointCoordinatorTest {
Tuple2<byte[], List<long[]>> serializationWithOffsets = serializeTogetherAndTrackOffsets(namedStateSerializables);
- Map<String, long[]> offsetsMap = new HashMap<>(states.size());
+ Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>(states.size());
int idx = 0;
for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) {
- offsetsMap.put(entry.getKey(), serializationWithOffsets.f1.get(idx));
+ offsetsMap.put(
+ entry.getKey(),
+ new OperatorStateHandle.StateMetaInfo(
+ serializationWithOffsets.f1.get(idx),
+ OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
++idx;
}
@@ -2601,7 +2648,7 @@ public class CheckpointCoordinatorTest {
return vertex;
}
- public static void verifiyStateRestore(
+ public static void verifyStateRestore(
JobVertexID jobVertexID, ExecutionJobVertex executionJobVertex,
List<KeyGroupRange> keyGroupPartitions) throws Exception {
@@ -2697,8 +2744,8 @@ public class CheckpointCoordinatorTest {
private static void collectResult(int opIdx, OperatorStateHandle operatorStateHandle, List<String> resultCollector) throws Exception {
try (FSDataInputStream in = operatorStateHandle.openInputStream()) {
- for (Map.Entry<String, long[]> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
- for (long offset : entry.getValue()) {
+ for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
+ for (long offset : entry.getValue().getOffsets()) {
in.seek(offset);
Integer state = InstantiationUtil.
deserializeObject(in, Thread.currentThread().getContextClassLoader());
@@ -2801,17 +2848,22 @@ public class CheckpointCoordinatorTest {
for (int i = 0; i < oldParallelism; ++i) {
Path fakePath = new Path("/fake-" + i);
- Map<String, long[]> namedStatesToOffsets = new HashMap<>();
+ Map<String, OperatorStateHandle.StateMetaInfo> namedStatesToOffsets = new HashMap<>();
int off = 0;
for (int s = 0; s < numNamedStates; ++s) {
long[] offs = new long[1 + r.nextInt(maxPartitionsPerState)];
- if (offs.length > 0) {
- for (int o = 0; o < offs.length; ++o) {
- offs[o] = off;
- ++off;
- }
- namedStatesToOffsets.put("State-" + s, offs);
+
+ for (int o = 0; o < offs.length; ++o) {
+ offs[o] = off;
+ ++off;
}
+
+ OperatorStateHandle.Mode mode = r.nextInt(10) == 0 ?
+ OperatorStateHandle.Mode.BROADCAST : OperatorStateHandle.Mode.SPLIT_DISTRIBUTE;
+ namedStatesToOffsets.put(
+ "State-" + s,
+ new OperatorStateHandle.StateMetaInfo(offs, mode));
+
}
previousParallelOpInstanceStates.add(
@@ -2822,14 +2874,21 @@ public class CheckpointCoordinatorTest {
int expectedTotalPartitions = 0;
for (OperatorStateHandle psh : previousParallelOpInstanceStates) {
- Map<String, long[]> offsMap = psh.getStateNameToPartitionOffsets();
+ Map<String, OperatorStateHandle.StateMetaInfo> offsMap = psh.getStateNameToPartitionOffsets();
Map<String, List<Long>> offsMapWithList = new HashMap<>(offsMap.size());
- for (Map.Entry<String, long[]> e : offsMap.entrySet()) {
- long[] offs = e.getValue();
- expectedTotalPartitions += offs.length;
+ for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> e : offsMap.entrySet()) {
+
+ long[] offs = e.getValue().getOffsets();
+ int replication = e.getValue().getDistributionMode().equals(OperatorStateHandle.Mode.BROADCAST) ?
+ newParallelism : 1;
+
+ expectedTotalPartitions += replication * offs.length;
List<Long> offsList = new ArrayList<>(offs.length);
+
for (int i = 0; i < offs.length; ++i) {
- offsList.add(i, offs[i]);
+ for(int p = 0; p < replication; ++p) {
+ offsList.add(offs[i]);
+ }
}
offsMapWithList.put(e.getKey(), offsList);
}
@@ -2851,25 +2910,25 @@ public class CheckpointCoordinatorTest {
Collection<OperatorStateHandle> pshc = pshs.get(p);
for (OperatorStateHandle sh : pshc) {
- for (Map.Entry<String, long[]> namedState : sh.getStateNameToPartitionOffsets().entrySet()) {
+ for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> namedState : sh.getStateNameToPartitionOffsets().entrySet()) {
- Map<String, List<Long>> x = actual.get(sh.getDelegateStateHandle());
- if (x == null) {
- x = new HashMap<>();
- actual.put(sh.getDelegateStateHandle(), x);
+ Map<String, List<Long>> stateToOffsets = actual.get(sh.getDelegateStateHandle());
+ if (stateToOffsets == null) {
+ stateToOffsets = new HashMap<>();
+ actual.put(sh.getDelegateStateHandle(), stateToOffsets);
}
- List<Long> actualOffs = x.get(namedState.getKey());
+ List<Long> actualOffs = stateToOffsets.get(namedState.getKey());
if (actualOffs == null) {
actualOffs = new ArrayList<>();
- x.put(namedState.getKey(), actualOffs);
+ stateToOffsets.put(namedState.getKey(), actualOffs);
}
- long[] add = namedState.getValue();
+ long[] add = namedState.getValue().getOffsets();
for (int i = 0; i < add.length; ++i) {
actualOffs.add(add[i]);
}
- partitionCount += namedState.getValue().length;
+ partitionCount += namedState.getValue().getOffsets().length;
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
index db5c35b..5184db8 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
@@ -99,9 +99,10 @@ public class SavepointV1Test {
new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes());
StreamStateHandle operatorStateStream =
new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes());
- Map<String, long[]> offsetsMap = new HashMap<>();
- offsetsMap.put("A", new long[]{0, 10, 20});
- offsetsMap.put("B", new long[]{30, 40, 50});
+ Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>();
+ offsetsMap.put("A", new OperatorStateHandle.StateMetaInfo(new long[]{0, 10, 20}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
+ offsetsMap.put("B", new OperatorStateHandle.StateMetaInfo(new long[]{30, 40, 50}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
+ offsetsMap.put("C", new OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, OperatorStateHandle.Mode.BROADCAST));
if (chainIdx != noNonPartitionableStateAtIndex) {
nonPartitionableStates.add(nonPartitionableState);
http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
index 515011f..cd0391f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
@@ -31,12 +31,13 @@ import java.util.Iterator;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class OperatorStateBackendTest {
- AbstractStateBackend abstractStateBackend = new MemoryStateBackend(1024);
+ AbstractStateBackend abstractStateBackend = new MemoryStateBackend(4096);
static Environment createMockEnvironment() {
Environment env = mock(Environment.class);
@@ -62,6 +63,7 @@ public class OperatorStateBackendTest {
OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend();
ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>());
ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>());
+ ListStateDescriptor<Serializable> stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer<>());
ListState<Serializable> listState1 = operatorStateBackend.getOperatorState(stateDescriptor1);
assertNotNull(listState1);
assertEquals(1, operatorStateBackend.getRegisteredStateNames().size());
@@ -89,6 +91,20 @@ public class OperatorStateBackendTest {
assertEquals(23, it.next());
assertTrue(!it.hasNext());
+ ListState<Serializable> listState3 = operatorStateBackend.getBroadcastOperatorState(stateDescriptor3);
+ assertNotNull(listState3);
+ assertEquals(3, operatorStateBackend.getRegisteredStateNames().size());
+ assertTrue(!it.hasNext());
+ listState3.add(17);
+ listState3.add(3);
+ listState3.add(123);
+
+ it = listState3.get().iterator();
+ assertEquals(17, it.next());
+ assertEquals(3, it.next());
+ assertEquals(123, it.next());
+ assertTrue(!it.hasNext());
+
ListState<Serializable> listState1b = operatorStateBackend.getOperatorState(stateDescriptor1);
assertNotNull(listState1b);
listState1b.add(123);
@@ -109,6 +125,20 @@ public class OperatorStateBackendTest {
assertEquals(4711, it.next());
assertEquals(123, it.next());
assertTrue(!it.hasNext());
+
+ try {
+ operatorStateBackend.getBroadcastOperatorState(stateDescriptor2);
+ fail("Did not detect changed mode");
+ } catch (IllegalStateException ignored) {
+
+ }
+
+ try {
+ operatorStateBackend.getOperatorState(stateDescriptor3);
+ fail("Did not detect changed mode");
+ } catch (IllegalStateException ignored) {
+
+ }
}
@Test
@@ -116,8 +146,10 @@ public class OperatorStateBackendTest {
OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend();
ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>());
ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>());
+ ListStateDescriptor<Serializable> stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer<>());
ListState<Serializable> listState1 = operatorStateBackend.getOperatorState(stateDescriptor1);
ListState<Serializable> listState2 = operatorStateBackend.getOperatorState(stateDescriptor2);
+ ListState<Serializable> listState3 = operatorStateBackend.getBroadcastOperatorState(stateDescriptor3);
listState1.add(42);
listState1.add(4711);
@@ -126,11 +158,17 @@ public class OperatorStateBackendTest {
listState2.add(13);
listState2.add(23);
+ listState3.add(17);
+ listState3.add(18);
+ listState3.add(19);
+ listState3.add(20);
+
CheckpointStreamFactory streamFactory = abstractStateBackend.createStreamFactory(new JobID(), "testOperator");
OperatorStateHandle stateHandle = operatorStateBackend.snapshot(1, 1, streamFactory).get();
try {
+ operatorStateBackend.close();
operatorStateBackend.dispose();
operatorStateBackend = abstractStateBackend.createOperatorStateBackend(
@@ -139,13 +177,13 @@ public class OperatorStateBackendTest {
operatorStateBackend.restore(Collections.singletonList(stateHandle));
- assertEquals(2, operatorStateBackend.getRegisteredStateNames().size());
+ assertEquals(3, operatorStateBackend.getRegisteredStateNames().size());
listState1 = operatorStateBackend.getOperatorState(stateDescriptor1);
listState2 = operatorStateBackend.getOperatorState(stateDescriptor2);
+ listState3 = operatorStateBackend.getBroadcastOperatorState(stateDescriptor3);
- assertEquals(2, operatorStateBackend.getRegisteredStateNames().size());
-
+ assertEquals(3, operatorStateBackend.getRegisteredStateNames().size());
Iterator<Serializable> it = listState1.get().iterator();
assertEquals(42, it.next());
@@ -158,6 +196,14 @@ public class OperatorStateBackendTest {
assertEquals(23, it.next());
assertTrue(!it.hasNext());
+ it = listState3.get().iterator();
+ assertEquals(17, it.next());
+ assertEquals(18, it.next());
+ assertEquals(19, it.next());
+ assertEquals(20, it.next());
+ assertTrue(!it.hasNext());
+
+ operatorStateBackend.close();
operatorStateBackend.dispose();
} finally {
stateHandle.discardState();
http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java
new file mode 100644
index 0000000..ab801b6
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class OperatorStateHandleTest {
+
+ @Test
+ public void testFixedEnumOrder() {
+
+ // Ensure the order / ordinal of all values of enum 'mode' are fixed, as this is used for serialization
+ Assert.assertEquals(0, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE.ordinal());
+ Assert.assertEquals(1, OperatorStateHandle.Mode.BROADCAST.ordinal());
+
+ // Ensure all enum values are registered and fixed forever by this test
+ Assert.assertEquals(2, OperatorStateHandle.Mode.values().length);
+
+ // Byte is used to encode enum value on serialization
+ Assert.assertTrue(OperatorStateHandle.Mode.values().length <= Byte.MAX_VALUE);
+ }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java
index c6ef0f0..7efcd0d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java
@@ -27,6 +27,7 @@ import org.junit.Assert;
import org.junit.Test;
import java.io.IOException;
+import java.util.Map;
public class OperatorStateOutputCheckpointStreamTest {
@@ -77,15 +78,23 @@ public class OperatorStateOutputCheckpointStreamTest {
OperatorStateHandle fullHandle = writeAllTestKeyGroups(stream, numPartitions);
Assert.assertNotNull(fullHandle);
+ Map<String, OperatorStateHandle.StateMetaInfo> stateNameToPartitionOffsets =
+ fullHandle.getStateNameToPartitionOffsets();
+ for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry : stateNameToPartitionOffsets.entrySet()) {
+
+ Assert.assertEquals(OperatorStateHandle.Mode.SPLIT_DISTRIBUTE, entry.getValue().getDistributionMode());
+ }
verifyRead(fullHandle, numPartitions);
}
private static void verifyRead(OperatorStateHandle fullHandle, int numPartitions) throws IOException {
int count = 0;
try (FSDataInputStream in = fullHandle.openInputStream()) {
- long[] offsets = fullHandle.getStateNameToPartitionOffsets().
+ OperatorStateHandle.StateMetaInfo metaInfo = fullHandle.getStateNameToPartitionOffsets().
get(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME);
+ long[] offsets = metaInfo.getOffsets();
+
Assert.assertNotNull(offsets);
DataInputView div = new DataInputViewStreamWrapper(in);
http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
index 832b022..2448540 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
@@ -36,7 +36,7 @@ import java.util.List;
public class SerializationProxiesTest {
@Test
- public void testSerializationRoundtrip() throws Exception {
+ public void testKeyedBackendSerializationProxyRoundtrip() throws Exception {
TypeSerializer<?> keySerializer = IntSerializer.INSTANCE;
TypeSerializer<?> namespaceSerializer = LongSerializer.INSTANCE;
@@ -67,13 +67,12 @@ public class SerializationProxiesTest {
serializationProxy.read(new DataInputViewStreamWrapper(in));
}
-
Assert.assertEquals(keySerializer, serializationProxy.getKeySerializerProxy().getTypeSerializer());
Assert.assertEquals(stateMetaInfoList, serializationProxy.getNamedStateSerializationProxies());
}
@Test
- public void testMetaInfoSerialization() throws Exception {
+ public void testKeyedStateMetaInfoSerialization() throws Exception {
String name = "test";
TypeSerializer<?> namespaceSerializer = LongSerializer.INSTANCE;
@@ -97,6 +96,64 @@ public class SerializationProxiesTest {
Assert.assertEquals(name, metaInfo.getStateName());
}
+
+ @Test
+ public void testOperatorBackendSerializationProxyRoundtrip() throws Exception {
+
+ TypeSerializer<?> stateSerializer = DoubleSerializer.INSTANCE;
+
+ List<OperatorBackendSerializationProxy.StateMetaInfo<?>> stateMetaInfoList = new ArrayList<>();
+
+ stateMetaInfoList.add(
+ new OperatorBackendSerializationProxy.StateMetaInfo<>("a", stateSerializer, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
+ stateMetaInfoList.add(
+ new OperatorBackendSerializationProxy.StateMetaInfo<>("b", stateSerializer, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
+ stateMetaInfoList.add(
+ new OperatorBackendSerializationProxy.StateMetaInfo<>("c", stateSerializer, OperatorStateHandle.Mode.BROADCAST));
+
+ OperatorBackendSerializationProxy serializationProxy =
+ new OperatorBackendSerializationProxy(stateMetaInfoList);
+
+ byte[] serialized;
+ try (ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos()) {
+ serializationProxy.write(new DataOutputViewStreamWrapper(out));
+ serialized = out.toByteArray();
+ }
+
+ serializationProxy =
+ new OperatorBackendSerializationProxy(Thread.currentThread().getContextClassLoader());
+
+ try (ByteArrayInputStreamWithPos in = new ByteArrayInputStreamWithPos(serialized)) {
+ serializationProxy.read(new DataInputViewStreamWrapper(in));
+ }
+
+ Assert.assertEquals(stateMetaInfoList, serializationProxy.getNamedStateSerializationProxies());
+ }
+
+ @Test
+ public void testOperatorStateMetaInfoSerialization() throws Exception {
+
+ String name = "test";
+ TypeSerializer<?> stateSerializer = DoubleSerializer.INSTANCE;
+
+ OperatorBackendSerializationProxy.StateMetaInfo<?> metaInfo =
+ new OperatorBackendSerializationProxy.StateMetaInfo<>(name, stateSerializer, OperatorStateHandle.Mode.BROADCAST);
+
+ byte[] serialized;
+ try (ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos()) {
+ metaInfo.write(new DataOutputViewStreamWrapper(out));
+ serialized = out.toByteArray();
+ }
+
+ metaInfo = new OperatorBackendSerializationProxy.StateMetaInfo<>(Thread.currentThread().getContextClassLoader());
+
+ try (ByteArrayInputStreamWithPos in = new ByteArrayInputStreamWithPos(serialized)) {
+ metaInfo.read(new DataInputViewStreamWrapper(in));
+ }
+
+ Assert.assertEquals(name, metaInfo.getName());
+ }
+
/**
* This test fixes the order of elements in the enum which is important for serialization. Do not modify this test
* except if you are entirely sure what you are doing.
http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
index 39dc5d6..963c42c 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
@@ -111,8 +111,10 @@ public class StateInitializationContextImplTest {
writtenOperatorStates.add(val);
}
- Map<String, long[]> offsetsMap = new HashMap<>();
- offsetsMap.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, offsets.toArray());
+ Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>();
+ offsetsMap.put(
+ DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME,
+ new OperatorStateHandle.StateMetaInfo(offsets.toArray(), OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
OperatorStateHandle operatorStateHandle =
new OperatorStateHandle(offsetsMap, new ByteStateHandleCloseChecking("os-" + i, out.toByteArray()));
operatorStateHandles.add(operatorStateHandle);
http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
index 0206cf5..58cfefd 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
@@ -191,8 +191,10 @@ public class InterruptSensitiveRestoreTest {
List<Collection<OperatorStateHandle>> operatorStateBackend = Collections.emptyList();
List<Collection<OperatorStateHandle>> operatorStateStream = Collections.emptyList();
- Map<String, long[]> operatorStateMetadata = new HashMap<>(1);
- operatorStateMetadata.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, new long[]{0});
+ Map<String, OperatorStateHandle.StateMetaInfo> operatorStateMetadata = new HashMap<>(1);
+ OperatorStateHandle.StateMetaInfo metaInfo =
+ new OperatorStateHandle.StateMetaInfo(new long[]{0}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE);
+ operatorStateMetadata.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, metaInfo);
KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(new KeyGroupRange(0,0));
http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
index da4a01b..45fcc25 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
@@ -86,7 +86,7 @@ public class RescalingITCase extends TestLogger {
private static final int numSlots = numTaskManagers * slotsPerTaskManager;
enum OperatorCheckpointMethod {
- NON_PARTITIONED, CHECKPOINTED_FUNCTION, LIST_CHECKPOINTED
+ NON_PARTITIONED, CHECKPOINTED_FUNCTION, CHECKPOINTED_FUNCTION_BROADCAST, LIST_CHECKPOINTED
}
private static TestingCluster cluster;
@@ -179,7 +179,7 @@ public class RescalingITCase extends TestLogger {
Future<Object> savepointPathFuture = jobManager.ask(new JobManagerMessages.TriggerSavepoint(jobID, Option.<String>empty()), deadline.timeLeft());
final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess)
- Await.result(savepointPathFuture, deadline.timeLeft())).savepointPath();
+ Await.result(savepointPathFuture, deadline.timeLeft())).savepointPath();
Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft());
@@ -270,7 +270,7 @@ public class RescalingITCase extends TestLogger {
assertTrue(String.valueOf(savepointResponse), savepointResponse instanceof JobManagerMessages.TriggerSavepointSuccess);
- final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess)savepointResponse).savepointPath();
+ final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess) savepointResponse).savepointPath();
Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft());
@@ -339,16 +339,16 @@ public class RescalingITCase extends TestLogger {
JobID jobID = null;
try {
- jobManager = cluster.getLeaderGateway(deadline.timeLeft());
+ jobManager = cluster.getLeaderGateway(deadline.timeLeft());
JobGraph jobGraph = createJobGraphWithKeyedAndNonPartitionedOperatorState(
- parallelism,
- maxParallelism,
- parallelism,
- numberKeys,
- numberElements,
- false,
- 100);
+ parallelism,
+ maxParallelism,
+ parallelism,
+ numberKeys,
+ numberElements,
+ false,
+ 100);
jobID = jobGraph.getJobID();
@@ -366,7 +366,7 @@ public class RescalingITCase extends TestLogger {
for (int key = 0; key < numberKeys; key++) {
int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
- expectedResult.add(Tuple2.of(KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, keyGroupIndex) , numberElements * key));
+ expectedResult.add(Tuple2.of(KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, keyGroupIndex), numberElements * key));
}
assertEquals(expectedResult, actualResult);
@@ -377,7 +377,7 @@ public class RescalingITCase extends TestLogger {
Future<Object> savepointPathFuture = jobManager.ask(new JobManagerMessages.TriggerSavepoint(jobID, Option.<String>empty()), deadline.timeLeft());
final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess)
- Await.result(savepointPathFuture, deadline.timeLeft())).savepointPath();
+ Await.result(savepointPathFuture, deadline.timeLeft())).savepointPath();
Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft());
@@ -392,13 +392,13 @@ public class RescalingITCase extends TestLogger {
jobID = null;
JobGraph scaledJobGraph = createJobGraphWithKeyedAndNonPartitionedOperatorState(
- parallelism2,
- maxParallelism,
- parallelism,
- numberKeys,
- numberElements + numberElements2,
- true,
- 100);
+ parallelism2,
+ maxParallelism,
+ parallelism,
+ numberKeys,
+ numberElements + numberElements2,
+ true,
+ 100);
scaledJobGraph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath(savepointPath));
@@ -447,6 +447,16 @@ public class RescalingITCase extends TestLogger {
}
@Test
+ public void testSavepointRescalingInBroadcastOperatorState() throws Exception {
+ testSavepointRescalingPartitionedOperatorState(false, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST);
+ }
+
+ @Test
+ public void testSavepointRescalingOutBroadcastOperatorState() throws Exception {
+ testSavepointRescalingPartitionedOperatorState(true, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST);
+ }
+
+ @Test
public void testSavepointRescalingInPartitionedOperatorStateList() throws Exception {
testSavepointRescalingPartitionedOperatorState(false, OperatorCheckpointMethod.LIST_CHECKPOINTED);
}
@@ -474,7 +484,8 @@ public class RescalingITCase extends TestLogger {
int counterSize = Math.max(parallelism, parallelism2);
- if(checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION) {
+ if (checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION ||
+ checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST) {
PartitionedStateSource.CHECK_CORRECT_SNAPSHOT = new int[counterSize];
PartitionedStateSource.CHECK_CORRECT_RESTORE = new int[counterSize];
} else {
@@ -505,11 +516,12 @@ public class RescalingITCase extends TestLogger {
if (savepointResponse instanceof JobManagerMessages.TriggerSavepointSuccess) {
break;
}
+ System.out.println(savepointResponse);
}
assertTrue(savepointResponse instanceof JobManagerMessages.TriggerSavepointSuccess);
- final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess)savepointResponse).savepointPath();
+ final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess) savepointResponse).savepointPath();
Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft());
@@ -543,6 +555,16 @@ public class RescalingITCase extends TestLogger {
for (int c : PartitionedStateSource.CHECK_CORRECT_RESTORE) {
sumAct += c;
}
+ } else if (checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST) {
+ for (int c : PartitionedStateSource.CHECK_CORRECT_SNAPSHOT) {
+ sumExp += c;
+ }
+
+ for (int c : PartitionedStateSource.CHECK_CORRECT_RESTORE) {
+ sumAct += c;
+ }
+
+ sumExp *= parallelism2;
} else {
for (int c : PartitionedStateSourceListCheckpointed.CHECK_CORRECT_SNAPSHOT) {
sumExp += c;
@@ -587,7 +609,10 @@ public class RescalingITCase extends TestLogger {
switch (checkpointMethod) {
case CHECKPOINTED_FUNCTION:
- src = new PartitionedStateSource();
+ src = new PartitionedStateSource(false);
+ break;
+ case CHECKPOINTED_FUNCTION_BROADCAST:
+ src = new PartitionedStateSource(true);
break;
case LIST_CHECKPOINTED:
src = new PartitionedStateSourceListCheckpointed();
@@ -607,12 +632,12 @@ public class RescalingITCase extends TestLogger {
}
private static JobGraph createJobGraphWithKeyedState(
- int parallelism,
- int maxParallelism,
- int numberKeys,
- int numberElements,
- boolean terminateAfterEmission,
- int checkpointingInterval) {
+ int parallelism,
+ int maxParallelism,
+ int numberKeys,
+ int numberElements,
+ boolean terminateAfterEmission,
+ int checkpointingInterval) {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(parallelism);
@@ -621,17 +646,17 @@ public class RescalingITCase extends TestLogger {
env.setRestartStrategy(RestartStrategies.noRestart());
DataStream<Integer> input = env.addSource(new SubtaskIndexSource(
- numberKeys,
- numberElements,
- terminateAfterEmission))
- .keyBy(new KeySelector<Integer, Integer>() {
- private static final long serialVersionUID = -7952298871120320940L;
-
- @Override
- public Integer getKey(Integer value) throws Exception {
- return value;
- }
- });
+ numberKeys,
+ numberElements,
+ terminateAfterEmission))
+ .keyBy(new KeySelector<Integer, Integer>() {
+ private static final long serialVersionUID = -7952298871120320940L;
+
+ @Override
+ public Integer getKey(Integer value) throws Exception {
+ return value;
+ }
+ });
SubtaskIndexFlatMapper.workCompletedLatch = new CountDownLatch(numberKeys);
@@ -643,13 +668,13 @@ public class RescalingITCase extends TestLogger {
}
private static JobGraph createJobGraphWithKeyedAndNonPartitionedOperatorState(
- int parallelism,
- int maxParallelism,
- int fixedParallelism,
- int numberKeys,
- int numberElements,
- boolean terminateAfterEmission,
- int checkpointingInterval) {
+ int parallelism,
+ int maxParallelism,
+ int fixedParallelism,
+ int numberKeys,
+ int numberElements,
+ boolean terminateAfterEmission,
+ int checkpointingInterval) {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(parallelism);
@@ -658,18 +683,18 @@ public class RescalingITCase extends TestLogger {
env.setRestartStrategy(RestartStrategies.noRestart());
DataStream<Integer> input = env.addSource(new SubtaskIndexNonPartitionedStateSource(
- numberKeys,
- numberElements,
- terminateAfterEmission))
- .setParallelism(fixedParallelism)
- .keyBy(new KeySelector<Integer, Integer>() {
- private static final long serialVersionUID = -7952298871120320940L;
-
- @Override
- public Integer getKey(Integer value) throws Exception {
- return value;
- }
- });
+ numberKeys,
+ numberElements,
+ terminateAfterEmission))
+ .setParallelism(fixedParallelism)
+ .keyBy(new KeySelector<Integer, Integer>() {
+ private static final long serialVersionUID = -7952298871120320940L;
+
+ @Override
+ public Integer getKey(Integer value) throws Exception {
+ return value;
+ }
+ });
SubtaskIndexFlatMapper.workCompletedLatch = new CountDownLatch(numberKeys);
@@ -681,7 +706,7 @@ public class RescalingITCase extends TestLogger {
}
private static class SubtaskIndexSource
- extends RichParallelSourceFunction<Integer> {
+ extends RichParallelSourceFunction<Integer> {
private static final long serialVersionUID = -400066323594122516L;
@@ -694,9 +719,9 @@ public class RescalingITCase extends TestLogger {
private boolean running = true;
SubtaskIndexSource(
- int numberKeys,
- int numberElements,
- boolean terminateAfterEmission) {
+ int numberKeys,
+ int numberElements,
+ boolean terminateAfterEmission) {
this.numberKeys = numberKeys;
this.numberElements = numberElements;
@@ -713,8 +738,8 @@ public class RescalingITCase extends TestLogger {
if (counter < numberElements) {
synchronized (lock) {
for (int value = subtaskIndex;
- value < numberKeys;
- value += getRuntimeContext().getNumberOfParallelSubtasks()) {
+ value < numberKeys;
+ value += getRuntimeContext().getNumberOfParallelSubtasks()) {
ctx.collect(value);
}
@@ -836,6 +861,7 @@ public class RescalingITCase extends TestLogger {
}
Thread.sleep(2);
+
if (counter == 10) {
workStartedLatch.countDown();
}
@@ -910,10 +936,14 @@ public class RescalingITCase extends TestLogger {
private static final int NUM_PARTITIONS = 7;
private ListState<Integer> counterPartitions;
+ private boolean broadcast;
private static int[] CHECK_CORRECT_SNAPSHOT;
private static int[] CHECK_CORRECT_RESTORE;
+ public PartitionedStateSource(boolean broadcast) {
+ this.broadcast = broadcast;
+ }
@Override
public void snapshotState(FunctionSnapshotContext context) throws Exception {
@@ -937,8 +967,15 @@ public class RescalingITCase extends TestLogger {
@Override
public void initializeState(FunctionInitializationContext context) throws Exception {
- this.counterPartitions =
- context.getOperatorStateStore().getSerializableListState("counter_partitions");
+
+ if (broadcast) {
+ this.counterPartitions =
+ context.getOperatorStateStore().getBroadcastSerializableListState("counter_partitions");
+ } else {
+ this.counterPartitions =
+ context.getOperatorStateStore().getSerializableListState("counter_partitions");
+ }
+
if (context.isRestored()) {
for (int v : counterPartitions.get()) {
counter += v;