You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by al...@apache.org on 2016/08/31 17:28:20 UTC
[02/27] flink git commit: [FLINK-4380] Introduce KeyGroupAssigner and
Max-Parallelism Parameter
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
index 5a86c5c..17bea68 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
@@ -52,7 +52,7 @@ import org.apache.flink.streaming.api.windowing.triggers.PurgingTrigger;
import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
import org.apache.flink.streaming.runtime.partitioner.CustomPartitionerWrapper;
-import org.apache.flink.streaming.runtime.partitioner.HashPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.KeyGroupStreamPartitioner;
import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
import org.apache.flink.streaming.runtime.partitioner.GlobalPartitioner;
import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner;
@@ -672,7 +672,7 @@ public class DataStreamTest {
assertNotNull(env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getStateKeySerializer());
assertNotNull(env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getStateKeySerializer());
assertEquals(key1, env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getStatePartitioner1());
- assertTrue(env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getInEdges().get(0).getPartitioner() instanceof HashPartitioner);
+ assertTrue(env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getInEdges().get(0).getPartitioner() instanceof KeyGroupStreamPartitioner);
KeySelector<Long, Long> key2 = new KeySelector<Long, Long>() {
@@ -688,7 +688,7 @@ public class DataStreamTest {
assertTrue(env.getStreamGraph().getStreamNode(sink3.getTransformation().getId()).getStatePartitioner1() != null);
assertEquals(key2, env.getStreamGraph().getStreamNode(sink3.getTransformation().getId()).getStatePartitioner1());
- assertTrue(env.getStreamGraph().getStreamNode(sink3.getTransformation().getId()).getInEdges().get(0).getPartitioner() instanceof HashPartitioner);
+ assertTrue(env.getStreamGraph().getStreamNode(sink3.getTransformation().getId()).getInEdges().get(0).getPartitioner() instanceof KeyGroupStreamPartitioner);
}
@Test
@@ -783,7 +783,7 @@ public class DataStreamTest {
private static boolean isPartitioned(List<StreamEdge> edges) {
boolean result = true;
for (StreamEdge edge: edges) {
- if (!(edge.getPartitioner() instanceof HashPartitioner)) {
+ if (!(edge.getPartitioner() instanceof KeyGroupStreamPartitioner)) {
result = false;
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/RestartStrategyTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/RestartStrategyTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/RestartStrategyTest.java
index c57bea7..d6fcd61 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/RestartStrategyTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/RestartStrategyTest.java
@@ -22,11 +22,11 @@ import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.graph.StreamGraph;
-
+import org.apache.flink.util.TestLogger;
import org.junit.Assert;
import org.junit.Test;
-public class RestartStrategyTest {
+public class RestartStrategyTest extends TestLogger {
/**
* Tests that in a streaming use case where checkpointing is enabled, a
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/SlotAllocationTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/SlotAllocationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/SlotAllocationTest.java
index bab43fa..d873771 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/SlotAllocationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/SlotAllocationTest.java
@@ -28,6 +28,7 @@ import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.CoMapFunction;
+import org.apache.flink.util.TestLogger;
import org.junit.Test;
/**
@@ -37,7 +38,7 @@ import org.junit.Test;
* resource groups/slot sharing groups.
*/
@SuppressWarnings("serial")
-public class SlotAllocationTest {
+public class SlotAllocationTest extends TestLogger {
@Test
public void testTwoPipelines() {
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java
index a4ee18e..06d381f 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java
@@ -19,13 +19,17 @@
package org.apache.flink.streaming.api.graph;
import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.state.KeyGroupAssigner;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.runtime.state.HashKeyGroupAssigner;
import org.apache.flink.streaming.api.datastream.ConnectedStreams;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
+import org.apache.flink.streaming.api.functions.co.CoMapFunction;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.Output;
@@ -34,8 +38,10 @@ import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
import org.apache.flink.streaming.runtime.partitioner.GlobalPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.KeyGroupStreamPartitioner;
import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner;
import org.apache.flink.streaming.runtime.partitioner.ShufflePartitioner;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.StreamTask;
import org.apache.flink.streaming.util.EvenOddOutputSelector;
@@ -236,6 +242,207 @@ public class StreamGraphGeneratorTest {
assertEquals(BasicTypeInfo.INT_TYPE_INFO, outputTypeConfigurableOperation.getTypeInformation());
}
+ /**
+ * Tests that the {@link KeyGroupStreamPartitioner} of a keyed edge is set up with a
+ * {@link HashKeyGroupAssigner} whose number of key groups equals the job-wide max parallelism.
+ */
+ @Test
+ public void testSetupOfKeyGroupPartitioner() {
+ final int expectedKeyGroups = 42;
+
+ KeySelector<Integer, Integer> identityKey = new KeySelector<Integer, Integer>() {
+ private static final long serialVersionUID = 9205556348021992189L;
+
+ @Override
+ public Integer getKey(Integer value) throws Exception {
+ return value;
+ }
+ };
+
+ StreamExecutionEnvironment environment = StreamExecutionEnvironment.getExecutionEnvironment();
+ environment.getConfig().setMaxParallelism(expectedKeyGroups);
+
+ DataStream<Integer> mapped = environment.fromElements(1, 2, 3)
+ .keyBy(identityKey)
+ .map(new NoOpIntMap());
+ mapped.addSink(new DiscardingSink<Integer>());
+
+ // the keyed map has a single input, so its partitioner sits on in-edge 0
+ StreamPartitioner<?> partitioner = environment.getStreamGraph()
+ .getStreamNode(mapped.getId())
+ .getInEdges()
+ .get(0)
+ .getPartitioner();
+
+ assertEquals(expectedKeyGroups, extractHashKeyGroupAssigner(partitioner).getNumberKeyGroups());
+ }
+
+ /**
+ * Tests that the job-wide max parallelism is applied to an operator without an explicit
+ * setting, while an operator-level max parallelism takes precedence over it.
+ */
+ @Test
+ public void testMaxParallelismForwarding() {
+ final int globalMaxParallelism = 42;
+ final int keyedResult2MaxParallelism = 17;
+
+ KeySelector<Integer, Integer> firstKey = new KeySelector<Integer, Integer>() {
+ private static final long serialVersionUID = 9205556348021992189L;
+
+ @Override
+ public Integer getKey(Integer value) throws Exception {
+ return value;
+ }
+ };
+ KeySelector<Integer, Integer> secondKey = new KeySelector<Integer, Integer>() {
+ private static final long serialVersionUID = 1250168178707154838L;
+
+ @Override
+ public Integer getKey(Integer value) throws Exception {
+ return value;
+ }
+ };
+
+ StreamExecutionEnvironment environment = StreamExecutionEnvironment.getExecutionEnvironment();
+ environment.getConfig().setMaxParallelism(globalMaxParallelism);
+
+ // no explicit max parallelism: should pick up the global value
+ DataStream<Integer> inheriting = environment.fromElements(1, 2, 3)
+ .keyBy(firstKey)
+ .map(new NoOpIntMap());
+
+ // explicit operator-level max parallelism
+ DataStream<Integer> overriding = inheriting
+ .keyBy(secondKey)
+ .map(new NoOpIntMap())
+ .setMaxParallelism(keyedResult2MaxParallelism);
+
+ overriding.addSink(new DiscardingSink<Integer>());
+
+ StreamGraph streamGraph = environment.getStreamGraph();
+
+ assertEquals(globalMaxParallelism, streamGraph.getStreamNode(inheriting.getId()).getMaxParallelism());
+ assertEquals(keyedResult2MaxParallelism, streamGraph.getStreamNode(overriding.getId()).getMaxParallelism());
+ }
+
+ /**
+ * Tests that the max parallelism of a stream node defaults to the node's parallelism if it
+ * has not been specified, and that an explicitly specified max parallelism is kept.
+ *
+ * <p>The explicit max parallelism must not be smaller than the parallelism of the operators
+ * it is set on; otherwise the tested stream graph could not be turned into a {@code JobGraph}.
+ */
+ @Test
+ public void testAutoMaxParallelism() {
+ int globalParallelism = 42;
+ int mapParallelism = 17;
+ // keyedResult3 runs with globalParallelism (42), so its explicit max parallelism has to be
+ // at least that large to describe a valid configuration
+ int maxParallelism = 43;
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+ env.setParallelism(globalParallelism);
+
+ DataStream<Integer> source = env.fromElements(1, 2, 3);
+
+ // neither parallelism nor max parallelism set explicitly
+ DataStream<Integer> keyedResult1 = source.keyBy(new KeySelector<Integer, Integer>() {
+ private static final long serialVersionUID = 9205556348021992189L;
+
+ @Override
+ public Integer getKey(Integer value) throws Exception {
+ return value;
+ }
+ }).map(new NoOpIntMap());
+
+ // only the parallelism set explicitly
+ DataStream<Integer> keyedResult2 = keyedResult1.keyBy(new KeySelector<Integer, Integer>() {
+ private static final long serialVersionUID = 1250168178707154838L;
+
+ @Override
+ public Integer getKey(Integer value) throws Exception {
+ return value;
+ }
+ }).map(new NoOpIntMap()).setParallelism(mapParallelism);
+
+ // only the max parallelism set explicitly
+ DataStream<Integer> keyedResult3 = keyedResult2.keyBy(new KeySelector<Integer, Integer>() {
+ private static final long serialVersionUID = 1250168178707154838L;
+
+ @Override
+ public Integer getKey(Integer value) throws Exception {
+ return value;
+ }
+ }).map(new NoOpIntMap()).setMaxParallelism(maxParallelism);
+
+ // both max parallelism and parallelism set explicitly
+ DataStream<Integer> keyedResult4 = keyedResult3.keyBy(new KeySelector<Integer, Integer>() {
+ private static final long serialVersionUID = 1250168178707154838L;
+
+ @Override
+ public Integer getKey(Integer value) throws Exception {
+ return value;
+ }
+ }).map(new NoOpIntMap()).setMaxParallelism(maxParallelism).setParallelism(mapParallelism);
+
+ keyedResult4.addSink(new DiscardingSink<Integer>());
+
+ StreamGraph graph = env.getStreamGraph();
+
+ StreamNode keyedResult1Node = graph.getStreamNode(keyedResult1.getId());
+ StreamNode keyedResult2Node = graph.getStreamNode(keyedResult2.getId());
+ StreamNode keyedResult3Node = graph.getStreamNode(keyedResult3.getId());
+ StreamNode keyedResult4Node = graph.getStreamNode(keyedResult4.getId());
+
+ assertEquals(globalParallelism, keyedResult1Node.getMaxParallelism());
+ assertEquals(mapParallelism, keyedResult2Node.getMaxParallelism());
+ assertEquals(maxParallelism, keyedResult3Node.getMaxParallelism());
+ assertEquals(maxParallelism, keyedResult4Node.getMaxParallelism());
+ }
+
+ /**
+ * Tests that both inputs of a keyed connected stream are partitioned by a hash key group
+ * assigner that uses the job-wide max parallelism rather than the max parallelism set on
+ * the individual sources.
+ */
+ @Test
+ public void testMaxParallelismWithConnectedKeyedStream() {
+ final int maxParallelism = 42;
+
+ StreamExecutionEnvironment environment = StreamExecutionEnvironment.getExecutionEnvironment();
+ DataStream<Integer> firstInput = environment.fromElements(1, 2, 3, 4).setMaxParallelism(128);
+ DataStream<Integer> secondInput = environment.fromElements(1, 2, 3, 4).setMaxParallelism(129);
+
+ environment.getConfig().setMaxParallelism(maxParallelism);
+
+ KeySelector<Integer, Integer> firstKey = new KeySelector<Integer, Integer>() {
+ private static final long serialVersionUID = -6908614081449363419L;
+
+ @Override
+ public Integer getKey(Integer value) throws Exception {
+ return value;
+ }
+ };
+ KeySelector<Integer, Integer> secondKey = new KeySelector<Integer, Integer>() {
+ private static final long serialVersionUID = 3195683453223164931L;
+
+ @Override
+ public Integer getKey(Integer value) throws Exception {
+ return value;
+ }
+ };
+
+ DataStream<Integer> connected = firstInput.connect(secondInput)
+ .keyBy(firstKey, secondKey)
+ .map(new NoOpIntCoMap());
+ connected.addSink(new DiscardingSink<Integer>());
+
+ StreamNode connectedNode = environment.getStreamGraph().getStreamNode(connected.getId());
+
+ // look up both partitioners first, then validate each of them in input order
+ StreamPartitioner<?>[] partitioners = {
+ connectedNode.getInEdges().get(0).getPartitioner(),
+ connectedNode.getInEdges().get(1).getPartitioner()
+ };
+
+ for (StreamPartitioner<?> partitioner : partitioners) {
+ assertEquals(maxParallelism, extractHashKeyGroupAssigner(partitioner).getNumberKeyGroups());
+ }
+ }
+
+ /**
+ * Asserts that the given partitioner is a {@link KeyGroupStreamPartitioner} backed by a
+ * {@link HashKeyGroupAssigner} and returns that assigner.
+ */
+ private HashKeyGroupAssigner<?> extractHashKeyGroupAssigner(StreamPartitioner<?> streamPartitioner) {
+ assertTrue(streamPartitioner instanceof KeyGroupStreamPartitioner);
+
+ KeyGroupAssigner<?> assigner =
+ ((KeyGroupStreamPartitioner<?, ?>) streamPartitioner).getKeyGroupAssigner();
+
+ assertTrue(assigner instanceof HashKeyGroupAssigner);
+ return (HashKeyGroupAssigner<?>) assigner;
+ }
+
private static class OutputTypeConfigurableOperationWithTwoInputs
extends AbstractStreamOperator<Integer>
implements TwoInputStreamOperator<Integer, Integer, Integer>, OutputTypeConfigurable<Integer> {
@@ -297,4 +504,17 @@ public class StreamGraphGeneratorTest {
}
}
+ /**
+ * Identity {@link CoMapFunction} that forwards the elements of both inputs unchanged.
+ *
+ * <p>Not private because StreamingJobGraphGeneratorTest reuses it.
+ */
+ static class NoOpIntCoMap implements CoMapFunction<Integer, Integer, Integer> {
+ private static final long serialVersionUID = 1886595528149124270L;
+
+ @Override
+ public Integer map1(Integer value) throws Exception {
+ return value;
+ }
+
+ @Override
+ public Integer map2(Integer value) throws Exception {
+ return value;
+ }
+ }
+
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
index 7f94aa0..277fab4 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
@@ -18,24 +18,33 @@
package org.apache.flink.streaming.api.graph;
import java.io.IOException;
+import java.util.List;
import java.util.Random;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.KeyGroupAssigner;
+import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.jobgraph.JobVertex;
+import org.apache.flink.runtime.state.HashKeyGroupAssigner;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.streaming.util.NoOpIntMap;
import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.util.SerializedValue;
+import org.apache.flink.util.TestLogger;
import org.junit.Test;
import static org.junit.Assert.*;
@SuppressWarnings("serial")
-public class StreamingJobGraphGeneratorTest {
+public class StreamingJobGraphGeneratorTest extends TestLogger {
@Test
public void testExecutionConfigSerialization() throws IOException, ClassNotFoundException {
@@ -114,6 +123,8 @@ public class StreamingJobGraphGeneratorTest {
DataStream<Tuple2<String, String>> input = env
.fromElements("a", "b", "c", "d", "e", "f")
.map(new MapFunction<String, Tuple2<String, String>>() {
+ private static final long serialVersionUID = 471891682418382583L;
+
@Override
public Tuple2<String, String> map(String value) {
return new Tuple2<>(value, value);
@@ -124,6 +135,8 @@ public class StreamingJobGraphGeneratorTest {
.keyBy(0)
.map(new MapFunction<Tuple2<String, String>, Tuple2<String, String>>() {
+ private static final long serialVersionUID = 3583760206245136188L;
+
@Override
public Tuple2<String, String> map(Tuple2<String, String> value) {
return value;
@@ -131,6 +144,8 @@ public class StreamingJobGraphGeneratorTest {
});
result.addSink(new SinkFunction<Tuple2<String, String>>() {
+ private static final long serialVersionUID = -5614849094269539342L;
+
@Override
public void invoke(Tuple2<String, String> value) {}
});
@@ -145,4 +160,203 @@ public class StreamingJobGraphGeneratorTest {
assertEquals(1, jobGraph.getVerticesAsArray()[0].getParallelism());
assertEquals(1, jobGraph.getVerticesAsArray()[1].getParallelism());
}
+
+ /**
+ * Tests that a job-wide max parallelism is propagated both to the {@link JobVertex} of a keyed
+ * operator and to the {@link KeyGroupAssigner} stored in that vertex's {@link StreamConfig}.
+ */
+ @Test
+ public void testKeyGroupAssignerProperlySet() {
+ final int expected = 42;
+
+ KeySelector<Integer, Integer> identityKey = new KeySelector<Integer, Integer>() {
+ private static final long serialVersionUID = 350461576474507944L;
+
+ @Override
+ public Integer getKey(Integer value) throws Exception {
+ return value;
+ }
+ };
+
+ final StreamExecutionEnvironment environment = StreamExecutionEnvironment.getExecutionEnvironment();
+ environment.getConfig().setMaxParallelism(expected);
+
+ DataStream<Integer> keyed = environment.fromElements(1, 2, 3)
+ .keyBy(identityKey)
+ .map(new NoOpIntMap());
+ keyed.addSink(new DiscardingSink<Integer>());
+
+ // vertices are ordered starting from the source, so position 1 holds the keyed map
+ JobVertex keyedVertex = environment.getStreamGraph()
+ .getJobGraph()
+ .getVerticesSortedTopologicallyFromSources()
+ .get(1);
+
+ assertEquals(expected, keyedVertex.getMaxParallelism());
+ assertEquals(expected, extractHashKeyGroupAssigner(keyedVertex).getNumberKeyGroups());
+ }
+
+ /**
+ * Tests that, without a job-wide max parallelism, the key group assigner in each keyed
+ * operator's stream config uses the operator's explicit max parallelism when one is set and
+ * the operator's parallelism otherwise.
+ */
+ @Test
+ public void testKeyGroupAssignerProperlySetAutoMaxParallelism() {
+ final int globalParallelism = 42;
+ final int mapParallelism = 17;
+ final int maxParallelism = 43;
+
+ StreamExecutionEnvironment environment = StreamExecutionEnvironment.getExecutionEnvironment();
+ environment.setParallelism(globalParallelism);
+
+ DataStream<Integer> source = environment.fromElements(1, 2, 3);
+
+ // (1) nothing set explicitly
+ DataStream<Integer> keyedResult1 = source.keyBy(new KeySelector<Integer, Integer>() {
+ private static final long serialVersionUID = 9205556348021992189L;
+
+ @Override
+ public Integer getKey(Integer value) throws Exception {
+ return value;
+ }
+ }).map(new NoOpIntMap());
+
+ // (2) explicit parallelism only
+ DataStream<Integer> keyedResult2 = keyedResult1.keyBy(new KeySelector<Integer, Integer>() {
+ private static final long serialVersionUID = 1250168178707154838L;
+
+ @Override
+ public Integer getKey(Integer value) throws Exception {
+ return value;
+ }
+ }).map(new NoOpIntMap()).setParallelism(mapParallelism);
+
+ // (3) explicit max parallelism only
+ DataStream<Integer> keyedResult3 = keyedResult2.keyBy(new KeySelector<Integer, Integer>() {
+ private static final long serialVersionUID = 1250168178707154838L;
+
+ @Override
+ public Integer getKey(Integer value) throws Exception {
+ return value;
+ }
+ }).map(new NoOpIntMap()).setMaxParallelism(maxParallelism);
+
+ // (4) explicit max parallelism and parallelism
+ DataStream<Integer> keyedResult4 = keyedResult3.keyBy(new KeySelector<Integer, Integer>() {
+ private static final long serialVersionUID = 1250168178707154838L;
+
+ @Override
+ public Integer getKey(Integer value) throws Exception {
+ return value;
+ }
+ }).map(new NoOpIntMap()).setMaxParallelism(maxParallelism).setParallelism(mapParallelism);
+
+ keyedResult4.addSink(new DiscardingSink<Integer>());
+
+ List<JobVertex> sorted = environment.getStreamGraph().getJobGraph().getVerticesSortedTopologicallyFromSources();
+
+ // expected key group counts for the vertices at topological positions 1..4
+ int[] expectedKeyGroups = {globalParallelism, mapParallelism, maxParallelism, maxParallelism};
+
+ JobVertex[] keyedVertices = new JobVertex[expectedKeyGroups.length];
+ for (int position = 0; position < keyedVertices.length; position++) {
+ keyedVertices[position] = sorted.get(position + 1);
+ }
+
+ int[] actualKeyGroups = new int[keyedVertices.length];
+ for (int position = 0; position < keyedVertices.length; position++) {
+ actualKeyGroups[position] = extractHashKeyGroupAssigner(keyedVertices[position]).getNumberKeyGroups();
+ }
+
+ for (int position = 0; position < expectedKeyGroups.length; position++) {
+ assertEquals(expectedKeyGroups[position], actualKeyGroups[position]);
+ }
+ }
+
+ /**
+ * Tests that for keyed connected streams the sources keep their own max parallelism while the
+ * connected operator's vertex and its {@link KeyGroupAssigner} in the {@link StreamConfig}
+ * use the job-wide max parallelism.
+ */
+ @Test
+ public void testMaxParallelismWithConnectedKeyedStream() {
+ final int maxParallelism = 42;
+
+ StreamExecutionEnvironment environment = StreamExecutionEnvironment.getExecutionEnvironment();
+ DataStream<Integer> firstInput = environment.fromElements(1, 2, 3, 4).setMaxParallelism(128).name("input1");
+ DataStream<Integer> secondInput = environment.fromElements(1, 2, 3, 4).setMaxParallelism(129).name("input2");
+
+ environment.getConfig().setMaxParallelism(maxParallelism);
+
+ KeySelector<Integer, Integer> firstKey = new KeySelector<Integer, Integer>() {
+ private static final long serialVersionUID = -6908614081449363419L;
+
+ @Override
+ public Integer getKey(Integer value) throws Exception {
+ return value;
+ }
+ };
+ KeySelector<Integer, Integer> secondKey = new KeySelector<Integer, Integer>() {
+ private static final long serialVersionUID = 3195683453223164931L;
+
+ @Override
+ public Integer getKey(Integer value) throws Exception {
+ return value;
+ }
+ };
+
+ DataStream<Integer> connected = firstInput.connect(secondInput)
+ .keyBy(firstKey, secondKey)
+ .map(new StreamGraphGeneratorTest.NoOpIntCoMap());
+ connected.addSink(new DiscardingSink<Integer>());
+
+ List<JobVertex> sorted = environment.getStreamGraph().getJobGraph().getVerticesSortedTopologicallyFromSources();
+
+ JobVertex firstVertex = sorted.get(0);
+ JobVertex secondVertex = sorted.get(1);
+ JobVertex connectedVertex = sorted.get(2);
+
+ // the two sources are unordered relative to each other; identify them by name
+ boolean firstIsInput1 = firstVertex.getName().equals("Source: input1");
+ JobVertex vertexWith128 = firstIsInput1 ? firstVertex : secondVertex;
+ JobVertex vertexWith129 = firstIsInput1 ? secondVertex : firstVertex;
+
+ assertEquals(128, vertexWith128.getMaxParallelism());
+ assertEquals(129, vertexWith129.getMaxParallelism());
+
+ assertEquals(maxParallelism, connectedVertex.getMaxParallelism());
+ assertEquals(maxParallelism, extractHashKeyGroupAssigner(connectedVertex).getNumberKeyGroups());
+ }
+
+ /**
+ * Tests that the {@link JobGraph} creation fails if the parallelism is greater than the max
+ * parallelism.
+ *
+ * <p>The expected {@link IllegalStateException} is caught around the JobGraph creation only.
+ * An exception thrown while the pipeline is assembled therefore cannot make the test pass.
+ */
+ @Test
+ public void testFailureOfJobJobCreationIfParallelismGreaterThanMaxParallelism() {
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+ env.getConfig().setMaxParallelism(42);
+
+ DataStream<Integer> input = env.fromElements(1, 2, 3, 4);
+
+ // parallelism 43 exceeds the configured max parallelism of 42
+ DataStream<Integer> result = input.map(new NoOpIntMap()).setParallelism(43);
+
+ result.addSink(new DiscardingSink<Integer>());
+
+ StreamGraph streamGraph = env.getStreamGraph();
+
+ try {
+ streamGraph.getJobGraph();
+ fail("The JobGraph should not have been created because the parallelism is greater than " +
+ "the max parallelism.");
+ } catch (IllegalStateException expected) {
+ // expected: the JobGraph generator rejects parallelism > max parallelism
+ }
+ }
+
+ /**
+ * Reads the {@link KeyGroupAssigner} from the {@link StreamConfig} of the given vertex. It
+ * asserts that the assigner is a {@link HashKeyGroupAssigner} and returns it.
+ */
+ private HashKeyGroupAssigner<Integer> extractHashKeyGroupAssigner(JobVertex jobVertex) {
+ ClassLoader testClassLoader = getClass().getClassLoader();
+ Configuration vertexConfig = jobVertex.getConfiguration();
+
+ KeyGroupAssigner<Integer> assigner = new StreamConfig(vertexConfig).getKeyGroupAssigner(testClassLoader);
+
+ assertTrue(assigner instanceof HashKeyGroupAssigner);
+ return (HashKeyGroupAssigner<Integer>) assigner;
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/StreamingJobGraphGeneratorNodeHashTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/StreamingJobGraphGeneratorNodeHashTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/StreamingJobGraphGeneratorNodeHashTest.java
index bcf621a..340981b 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/StreamingJobGraphGeneratorNodeHashTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/StreamingJobGraphGeneratorNodeHashTest.java
@@ -27,12 +27,11 @@ import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
import org.apache.flink.streaming.api.graph.StreamGraph;
import org.apache.flink.streaming.api.graph.StreamNode;
-
+import org.apache.flink.util.TestLogger;
import org.junit.Test;
import java.util.HashMap;
@@ -52,7 +51,7 @@ import static org.junit.Assert.assertTrue;
* {@link JobGraph} instances.
*/
@SuppressWarnings("serial")
-public class StreamingJobGraphGeneratorNodeHashTest {
+public class StreamingJobGraphGeneratorNodeHashTest extends TestLogger {
// ------------------------------------------------------------------------
// Deterministic hash assignment
@@ -126,53 +125,6 @@ public class StreamingJobGraphGeneratorNodeHashTest {
}
/**
- * Verifies that parallelism affects the node hash.
- */
- @Test
- public void testNodeHashParallelism() throws Exception {
- StreamExecutionEnvironment env = StreamExecutionEnvironment.createLocalEnvironment();
- env.disableOperatorChaining();
-
- env.addSource(new NoOpSourceFunction(), "src").setParallelism(4)
- .addSink(new DiscardingSink<String>()).name("sink").setParallelism(4);
-
- JobGraph jobGraph = env.getStreamGraph().getJobGraph();
-
- Map<JobVertexID, String> ids = rememberIds(jobGraph);
-
- // Change parallelism of source
- env = StreamExecutionEnvironment.createLocalEnvironment();
- env.disableOperatorChaining();
-
- env.addSource(new NoOpSourceFunction(), "src").setParallelism(8)
- .addSink(new DiscardingSink<String>()).name("sink").setParallelism(4);
-
- jobGraph = env.getStreamGraph().getJobGraph();
-
- verifyIdsNotEqual(jobGraph, ids);
-
- // Change parallelism of sink
- env = StreamExecutionEnvironment.createLocalEnvironment();
- env.disableOperatorChaining();
-
- env.addSource(new NoOpSourceFunction(), "src").setParallelism(4)
- .addSink(new DiscardingSink<String>()).name("sink").setParallelism(8);
-
- jobGraph = env.getStreamGraph().getJobGraph();
-
- // The source hash will should be the same
- JobVertex[] vertices = jobGraph.getVerticesAsArray();
- if (vertices[0].isInputVertex()) {
- assertTrue(ids.containsKey(vertices[0].getID()));
- assertFalse(ids.containsKey(vertices[1].getID()));
- }
- else {
- assertTrue(ids.containsKey(vertices[1].getID()));
- assertFalse(ids.containsKey(vertices[0].getID()));
- }
- }
-
- /**
* Tests that there are no collisions with two identical sources.
*
* <pre>
@@ -516,6 +468,8 @@ public class StreamingJobGraphGeneratorNodeHashTest {
private static class NoOpSourceFunction implements ParallelSourceFunction<String> {
+ private static final long serialVersionUID = -5459224792698512636L;
+
@Override
public void run(SourceContext<String> ctx) throws Exception {
}
@@ -527,6 +481,8 @@ public class StreamingJobGraphGeneratorNodeHashTest {
private static class NoOpSinkFunction implements SinkFunction<String> {
+ private static final long serialVersionUID = -5654199886203297279L;
+
@Override
public void invoke(String value) throws Exception {
}
@@ -534,6 +490,8 @@ public class StreamingJobGraphGeneratorNodeHashTest {
private static class NoOpMapFunction implements MapFunction<String, String> {
+ private static final long serialVersionUID = 6584823409744624276L;
+
@Override
public String map(String value) throws Exception {
return value;
@@ -542,6 +500,8 @@ public class StreamingJobGraphGeneratorNodeHashTest {
private static class NoOpFilterFunction implements FilterFunction<String> {
+ private static final long serialVersionUID = 500005424900187476L;
+
@Override
public boolean filter(String value) throws Exception {
return true;
@@ -550,6 +510,8 @@ public class StreamingJobGraphGeneratorNodeHashTest {
private static class NoOpKeySelector implements KeySelector<String, String> {
+ private static final long serialVersionUID = -96127515593422991L;
+
@Override
public String getKey(String value) throws Exception {
return value;
@@ -557,6 +519,8 @@ public class StreamingJobGraphGeneratorNodeHashTest {
}
private static class NoOpReduceFunction implements ReduceFunction<String> {
+ private static final long serialVersionUID = -8775747640749256372L;
+
@Override
public String reduce(String value1, String value2) throws Exception {
return value1;
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AllWindowTranslationTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AllWindowTranslationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AllWindowTranslationTest.java
index ebe6bea..7ac9e13 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AllWindowTranslationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AllWindowTranslationTest.java
@@ -253,6 +253,8 @@ public class AllWindowTranslationTest {
try {
windowedStream.fold("", new FoldFunction<String, String>() {
+ private static final long serialVersionUID = -8722899157560218917L;
+
@Override
public String fold(String accumulator, String value) throws Exception {
return accumulator;
@@ -278,6 +280,8 @@ public class AllWindowTranslationTest {
try {
windowedStream.trigger(new Trigger<String, TimeWindow>() {
+ private static final long serialVersionUID = 8360971631424870421L;
+
@Override
public TriggerResult onElement(String element,
long timestamp,
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowTranslationTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowTranslationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowTranslationTest.java
index 39d89cf..2707108 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowTranslationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowTranslationTest.java
@@ -76,6 +76,8 @@ public class WindowTranslationTest {
.keyBy(0)
.window(SlidingEventTimeWindows.of(Time.of(1, TimeUnit.SECONDS), Time.of(100, TimeUnit.MILLISECONDS)))
.reduce(new RichReduceFunction<Tuple2<String, Integer>>() {
+ private static final long serialVersionUID = -6448847205314995812L;
+
@Override
public Tuple2<String, Integer> reduce(Tuple2<String, Integer> value1,
Tuple2<String, Integer> value2) throws Exception {
@@ -242,6 +244,8 @@ public class WindowTranslationTest {
WindowedStream<String, String, TimeWindow> windowedStream = env.fromElements("Hello", "Ciao")
.keyBy(new KeySelector<String, String>() {
+ private static final long serialVersionUID = -3298887124448443076L;
+
@Override
public String getKey(String value) throws Exception {
return value;
@@ -251,6 +255,8 @@ public class WindowTranslationTest {
try {
windowedStream.fold("", new FoldFunction<String, String>() {
+ private static final long serialVersionUID = -4567902917104921706L;
+
@Override
public String fold(String accumulator, String value) throws Exception {
return accumulator;
@@ -273,6 +279,8 @@ public class WindowTranslationTest {
WindowedStream<String, String, TimeWindow> windowedStream = env.fromElements("Hello", "Ciao")
.keyBy(new KeySelector<String, String>() {
+ private static final long serialVersionUID = 598309916882894293L;
+
@Override
public String getKey(String value) throws Exception {
return value;
@@ -282,6 +290,8 @@ public class WindowTranslationTest {
try {
windowedStream.trigger(new Trigger<String, TimeWindow>() {
+ private static final long serialVersionUID = 6558046711583024443L;
+
@Override
public TriggerResult onElement(String element,
long timestamp,
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/HashPartitionerTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/HashPartitionerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/HashPartitionerTest.java
deleted file mode 100644
index 6dbf932..0000000
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/HashPartitionerTest.java
+++ /dev/null
@@ -1,71 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.streaming.runtime.partitioner;
-
-import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertEquals;
-
-import org.apache.flink.api.java.functions.KeySelector;
-import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.runtime.plugable.SerializationDelegate;
-import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.junit.Before;
-import org.junit.Test;
-
-public class HashPartitionerTest {
-
- private HashPartitioner<Tuple2<String, Integer>> hashPartitioner;
- private StreamRecord<Tuple2<String, Integer>> streamRecord1 = new StreamRecord<Tuple2<String, Integer>>(new Tuple2<String, Integer>("test", 0));
- private StreamRecord<Tuple2<String, Integer>> streamRecord2 = new StreamRecord<Tuple2<String, Integer>>(new Tuple2<String, Integer>("test", 42));
- private SerializationDelegate<StreamRecord<Tuple2<String, Integer>>> sd1 = new SerializationDelegate<StreamRecord<Tuple2<String, Integer>>>(null);
- private SerializationDelegate<StreamRecord<Tuple2<String, Integer>>> sd2 = new SerializationDelegate<StreamRecord<Tuple2<String, Integer>>>(null);
-
- @Before
- public void setPartitioner() {
- hashPartitioner = new HashPartitioner<Tuple2<String, Integer>>(new KeySelector<Tuple2<String, Integer>, String>() {
-
- private static final long serialVersionUID = 1L;
-
- @Override
- public String getKey(Tuple2<String, Integer> value) throws Exception {
- return value.getField(0);
- }
- });
- }
-
- @Test
- public void testSelectChannelsLength() {
- sd1.setInstance(streamRecord1);
- assertEquals(1, hashPartitioner.selectChannels(sd1, 1).length);
- assertEquals(1, hashPartitioner.selectChannels(sd1, 2).length);
- assertEquals(1, hashPartitioner.selectChannels(sd1, 1024).length);
- }
-
- @Test
- public void testSelectChannelsGrouping() {
- sd1.setInstance(streamRecord1);
- sd2.setInstance(streamRecord2);
-
- assertArrayEquals(hashPartitioner.selectChannels(sd1, 1),
- hashPartitioner.selectChannels(sd2, 1));
- assertArrayEquals(hashPartitioner.selectChannels(sd1, 2),
- hashPartitioner.selectChannels(sd2, 2));
- assertArrayEquals(hashPartitioner.selectChannels(sd1, 1024),
- hashPartitioner.selectChannels(sd2, 1024));
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitionerTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitionerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitionerTest.java
new file mode 100644
index 0000000..6fbf35e
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitionerTest.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.partitioner;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.plugable.SerializationDelegate;
+import org.apache.flink.runtime.state.HashKeyGroupAssigner;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.TestLogger;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Tests for {@link KeyGroupStreamPartitioner}: every record is routed to exactly one output
+ * channel, and records with the same key are routed to the same channel.
+ */
+public class KeyGroupStreamPartitionerTest extends TestLogger {
+
+ private KeyGroupStreamPartitioner<Tuple2<String, Integer>, String> keyGroupPartitioner;
+ // Two records that share the key "test" (field 0) but differ in field 1.
+ private StreamRecord<Tuple2<String, Integer>> streamRecord1 = new StreamRecord<Tuple2<String, Integer>>(new Tuple2<String, Integer>("test", 0));
+ private StreamRecord<Tuple2<String, Integer>> streamRecord2 = new StreamRecord<Tuple2<String, Integer>>(new Tuple2<String, Integer>("test", 42));
+ // Holders passed to selectChannels(); no serializer is supplied (null) since only the wrapped instance is set.
+ private SerializationDelegate<StreamRecord<Tuple2<String, Integer>>> sd1 = new SerializationDelegate<StreamRecord<Tuple2<String, Integer>>>(null);
+ private SerializationDelegate<StreamRecord<Tuple2<String, Integer>>> sd2 = new SerializationDelegate<StreamRecord<Tuple2<String, Integer>>>(null);
+
+ /** Creates a partitioner that keys records by the String in field 0. */
+ @Before
+ public void setPartitioner() {
+ keyGroupPartitioner = new KeyGroupStreamPartitioner<Tuple2<String, Integer>, String>(new KeySelector<Tuple2<String, Integer>, String>() {
+
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public String getKey(Tuple2<String, Integer> value) throws Exception {
+ return value.getField(0);
+ }
+ },
+ // NOTE(review): 1024 is presumably the number of key groups - confirm against HashKeyGroupAssigner.
+ new HashKeyGroupAssigner<String>(1024));
+ }
+
+ /** For 1, 2 and 1024 output channels, exactly one target channel is selected per record. */
+ @Test
+ public void testSelectChannelsLength() {
+ sd1.setInstance(streamRecord1);
+ assertEquals(1, keyGroupPartitioner.selectChannels(sd1, 1).length);
+ assertEquals(1, keyGroupPartitioner.selectChannels(sd1, 2).length);
+ assertEquals(1, keyGroupPartitioner.selectChannels(sd1, 1024).length);
+ }
+
+ /** Both records carry the key "test", so they must be routed to the same channel(s). */
+ @Test
+ public void testSelectChannelsGrouping() {
+ sd1.setInstance(streamRecord1);
+ sd2.setInstance(streamRecord2);
+
+ assertArrayEquals(keyGroupPartitioner.selectChannels(sd1, 1),
+ keyGroupPartitioner.selectChannels(sd2, 1));
+ assertArrayEquals(keyGroupPartitioner.selectChannels(sd1, 2),
+ keyGroupPartitioner.selectChannels(sd2, 2));
+ assertArrayEquals(keyGroupPartitioner.selectChannels(sd1, 1024),
+ keyGroupPartitioner.selectChannels(sd2, 1024));
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitionerTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitionerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitionerTest.java
index 8c7360a..37ea68a 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitionerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitionerTest.java
@@ -94,6 +94,8 @@ public class RescalePartitionerTest extends TestLogger {
// get input data
DataStream<String> text = env.addSource(new ParallelSourceFunction<String>() {
+ private static final long serialVersionUID = 7772338606389180774L;
+
@Override
public void run(SourceContext<String> ctx) throws Exception {
@@ -108,6 +110,8 @@ public class RescalePartitionerTest extends TestLogger {
DataStream<Tuple2<String, Integer>> counts = text
.rescale()
.flatMap(new FlatMapFunction<String, Tuple2<String, Integer>>() {
+ private static final long serialVersionUID = -5255930322161596829L;
+
@Override
public void flatMap(String value,
Collector<Tuple2<String, Integer>> out) throws Exception {
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java
index 145edc2..5f73e25 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.java.ClosureCleaner;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.runtime.event.AbstractEvent;
import org.apache.flink.runtime.io.network.partition.consumer.StreamTestSingleInputGate;
+import org.apache.flink.runtime.state.HashKeyGroupAssigner;
import java.io.IOException;
@@ -105,6 +106,7 @@ public class OneInputStreamTaskTestHarness<IN, OUT> extends StreamTaskTestHarnes
ClosureCleaner.clean(keySelector, false);
streamConfig.setStatePartitioner(0, keySelector);
streamConfig.setStateKeySerializer(keyType.createSerializer(executionConfig));
+ streamConfig.setKeyGroupAssigner(new HashKeyGroupAssigner<K>(10));
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
index bcd8a5f..3d9d50f 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
@@ -19,8 +19,11 @@
package org.apache.flink.streaming.runtime.tasks;
import akka.actor.ActorRef;
+
+import akka.dispatch.Futures;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.blob.BlobKey;
import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
@@ -40,9 +43,11 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
+import org.apache.flink.runtime.messages.TaskMessages;
import org.apache.flink.runtime.query.TaskKvStateRegistry;
import org.apache.flink.runtime.taskmanager.Task;
import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
+import org.apache.flink.streaming.api.TimeCharacteristic;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.operators.Output;
@@ -51,17 +56,27 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.SerializedValue;
import org.junit.Test;
+
+import scala.concurrent.Await;
import scala.concurrent.ExecutionContext;
import scala.concurrent.Future;
+import scala.concurrent.duration.Deadline;
import scala.concurrent.duration.FiniteDuration;
+import scala.concurrent.impl.Promise;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.net.URL;
import java.util.Collections;
+import java.util.Comparator;
+import java.util.PriorityQueue;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
@@ -72,55 +87,140 @@ import static org.mockito.Mockito.when;
public class StreamTaskTest {
- /**
+ /**
* This test checks that cancel calls that are issued before the operator is
* instantiated still lead to proper canceling.
*/
@Test
- public void testEarlyCanceling() {
- try {
- StreamConfig cfg = new StreamConfig(new Configuration());
- cfg.setStreamOperator(new SlowlyDeserializingOperator());
-
- Task task = createTask(SourceStreamTask.class, cfg);
- task.startTaskThread();
-
- // wait until the task thread reached state RUNNING
- while (task.getExecutionState() == ExecutionState.CREATED ||
- task.getExecutionState() == ExecutionState.DEPLOYING)
- {
- Thread.sleep(5);
- }
-
- // make sure the task is really running
- if (task.getExecutionState() != ExecutionState.RUNNING) {
- fail("Task entered state " + task.getExecutionState() + " with error "
- + ExceptionUtils.stringifyException(task.getFailureCause()));
- }
-
- // send a cancel. because the operator takes a long time to deserialize, this should
- // hit the task before the operator is deserialized
- task.cancelExecution();
-
- // the task should reach state canceled eventually
- assertTrue(task.getExecutionState() == ExecutionState.CANCELING ||
- task.getExecutionState() == ExecutionState.CANCELED);
-
- task.getExecutingThread().join(60000);
-
- assertFalse("Task did not cancel", task.getExecutingThread().isAlive());
- assertEquals(ExecutionState.CANCELED, task.getExecutionState());
- }
- catch (Exception e) {
- e.printStackTrace();
- fail(e.getMessage());
+ public void testEarlyCanceling() throws Exception {
+ // single time budget shared by every wait in this test
+ Deadline deadline = new FiniteDuration(2, TimeUnit.MINUTES).fromNow();
+ StreamConfig cfg = new StreamConfig(new Configuration());
+ cfg.setStreamOperator(new SlowlyDeserializingOperator());
+ cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+
+ Task task = createTask(SourceStreamTask.class, cfg);
+
+ // receives the task's state updates so the test can await them instead of polling;
+ // registered before the thread starts so no update is missed
+ ExecutionStateListener executionStateListener = new ExecutionStateListener();
+
+ task.registerExecutionListener(executionStateListener);
+ task.startTaskThread();
+
+ Future<ExecutionState> running = executionStateListener.notifyWhenExecutionState(ExecutionState.RUNNING);
+
+ // wait until the task thread reached state RUNNING
+ ExecutionState executionState = Await.result(running, deadline.timeLeft());
+
+ // make sure the task is really running
+ if (executionState != ExecutionState.RUNNING) {
+ fail("Task entered state " + task.getExecutionState() + " with error "
+ + ExceptionUtils.stringifyException(task.getFailureCause()));
}
+
+ // send a cancel. because the operator takes a long time to deserialize, this should
+ // hit the task before the operator is deserialized
+ task.cancelExecution();
+
+ Future<ExecutionState> canceling = executionStateListener.notifyWhenExecutionState(ExecutionState.CANCELING);
+
+ executionState = Await.result(canceling, deadline.timeLeft());
+
+ // the task should reach state canceled eventually
+ assertTrue(executionState == ExecutionState.CANCELING ||
+ executionState == ExecutionState.CANCELED);
+
+ // wait for the task thread to terminate, bounded by what is left of the deadline
+ task.getExecutingThread().join(deadline.timeLeft().toMillis());
+
+ assertFalse("Task did not cancel", task.getExecutingThread().isAlive());
+ assertEquals(ExecutionState.CANCELED, task.getExecutionState());
}
+
// ------------------------------------------------------------------------
// Test Utilities
// ------------------------------------------------------------------------
+ /**
+ * {@link ActorGateway} stub that tracks the {@link ExecutionState} updates a {@link Task}
+ * sends as {@code TaskMessages.UpdateTaskExecutionState} messages, and hands out futures
+ * that complete once a requested state has been reached.
+ *
+ * <p>A state counts as reached once the latest reported state has an ordinal greater than
+ * or equal to it. All other gateway operations are no-ops.
+ */
+ private static class ExecutionStateListener implements ActorGateway {
+
+ private static final long serialVersionUID = 8926442805035692182L;
+
+ /** Latest reported state, {@code null} until the first update; guarded by {@link #priorityQueue}. */
+ private ExecutionState executionState = null;
+
+ /** Pending waiters, lowest requested state first; also serves as the lock for all state. */
+ private final PriorityQueue<Tuple2<ExecutionState, Promise<ExecutionState>>> priorityQueue = new PriorityQueue<>(
+ 1,
+ new Comparator<Tuple2<ExecutionState, Promise<ExecutionState>>>() {
+ @Override
+ public int compare(Tuple2<ExecutionState, Promise<ExecutionState>> o1, Tuple2<ExecutionState, Promise<ExecutionState>> o2) {
+ return Integer.compare(o1.f0.ordinal(), o2.f0.ordinal());
+ }
+ });
+
+ /**
+ * Returns a future that completes once a state at or beyond {@code targetState} (by
+ * ordinal) has been reported.
+ *
+ * @param targetState the state to wait for
+ * @return a future completed with the state actually observed, which may be later
+ * than {@code targetState}
+ */
+ public Future<ExecutionState> notifyWhenExecutionState(ExecutionState targetState) {
+ synchronized (priorityQueue) {
+ if (executionState != null && executionState.ordinal() >= targetState.ordinal()) {
+ // Report the observed state, as tell() does for pending waiters. Echoing
+ // targetState here would make the caller's state checks pass vacuously.
+ return Futures.successful(executionState);
+ } else {
+ Promise<ExecutionState> promise = new Promise.DefaultPromise<ExecutionState>();
+
+ priorityQueue.offer(Tuple2.of(targetState, promise));
+
+ return promise.future();
+ }
+ }
+ }
+
+ @Override
+ public Future<Object> ask(Object message, FiniteDuration timeout) {
+ return null;
+ }
+
+ @Override
+ public void tell(Object message) {
+ this.tell(message, null);
+ }
+
+ /**
+ * Records the state carried by an {@code UpdateTaskExecutionState} message and completes
+ * every pending waiter whose requested state has now been reached. Other messages are ignored.
+ */
+ @Override
+ public void tell(Object message, ActorGateway sender) {
+ if (message instanceof TaskMessages.UpdateTaskExecutionState) {
+ TaskMessages.UpdateTaskExecutionState updateTaskExecutionState = (TaskMessages.UpdateTaskExecutionState) message;
+
+ synchronized (priorityQueue) {
+ executionState = updateTaskExecutionState.taskExecutionState().getExecutionState();
+
+ // the queue is ordered by requested state, so stop at the first waiter not yet satisfied
+ while (!priorityQueue.isEmpty() && priorityQueue.peek().f0.ordinal() <= executionState.ordinal()) {
+ Promise<ExecutionState> promise = priorityQueue.poll().f1;
+
+ promise.success(executionState);
+ }
+ }
+ }
+ }
+
+ @Override
+ public void forward(Object message, ActorGateway sender) {
+ // not needed by this test
+ }
+
+ @Override
+ public Future<Object> retry(Object message, int numberRetries, FiniteDuration timeout, ExecutionContext executionContext) {
+ return null;
+ }
+
+ @Override
+ public String path() {
+ return null;
+ }
+
+ @Override
+ public ActorRef actor() {
+ return null;
+ }
+
+ @Override
+ public UUID leaderSessionID() {
+ return null;
+ }
+ }
+
private Task createTask(Class<? extends AbstractInvokable> invokable, StreamConfig taskConfig) throws Exception {
LibraryCacheManager libCache = mock(LibraryCacheManager.class);
when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader());
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
index 00e95b9..cb10c5c 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
@@ -186,23 +186,58 @@ public class StreamTaskTestHarness<OUT> {
taskThread.start();
}
+ /**
+ * Waits, without a time limit, for the task thread to finish.
+ *
+ * @throws Exception if the task thread reported an error
+ */
public void waitForTaskCompletion() throws Exception {
+ waitForTaskCompletion(Long.MAX_VALUE);
+ }
+
+ /**
+ * Waits for the task thread to finish. If this does not happen within the timeout, then a
+ * {@link java.util.concurrent.TimeoutException} is thrown.
+ *
+ * @param timeout Timeout for the task completion, in milliseconds (as for {@link Thread#join(long)})
+ * @throws java.util.concurrent.TimeoutException if the task thread is still alive after the timeout
+ * @throws Exception if the task thread reported an error
+ */
+ public void waitForTaskCompletion(long timeout) throws Exception {
if (taskThread == null) {
throw new IllegalStateException("Task thread was not started.");
}
- taskThread.join();
+ taskThread.join(timeout);
+ // join(timeout) returns silently on expiry, so detect the timeout explicitly
+ if (taskThread.isAlive()) {
+ throw new java.util.concurrent.TimeoutException("Task did not complete within " + timeout + " ms.");
+ }
if (taskThread.getError() != null) {
throw new Exception("error in task", taskThread.getError());
}
}
+ /**
+ * Waits for the task to be running.
+ *
+ * @throws Exception
+ */
public void waitForTaskRunning() throws Exception {
+ waitForTaskRunning(Long.MAX_VALUE);
+ }
+
+ /**
+ * Waits for the task to be running. If this does not happen within the timeout, then a
+ * TimeoutException is thrown.
+ *
+ * @param timeout Timeout for the task to be running.
+ * @throws Exception
+ */
+ public void waitForTaskRunning(long timeout) throws Exception {
if (taskThread == null) {
throw new IllegalStateException("Task thread was not started.");
}
else {
if (taskThread.task instanceof StreamTask) {
+ long base = System.currentTimeMillis();
+ long now = 0;
+
StreamTask<?, ?> streamTask = (StreamTask<?, ?>) taskThread.task;
while (!streamTask.isRunning()) {
Thread.sleep(100);
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/resources/log4j-test.properties
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/resources/log4j-test.properties b/flink-streaming-java/src/test/resources/log4j-test.properties
index 0b686e5..881dc06 100644
--- a/flink-streaming-java/src/test/resources/log4j-test.properties
+++ b/flink-streaming-java/src/test/resources/log4j-test.properties
@@ -24,4 +24,4 @@ log4j.appender.A1=org.apache.log4j.ConsoleAppender
# A1 uses PatternLayout.
log4j.appender.A1.layout=org.apache.log4j.PatternLayout
-log4j.appender.A1.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n
\ No newline at end of file
+log4j.appender.A1.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala
----------------------------------------------------------------------
diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala
index 8693834..4fe73e9 100644
--- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala
+++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala
@@ -133,6 +133,17 @@ class DataStream[T](stream: JavaStream[T]) {
this
}
+ /**
+ * Sets the maximum parallelism of the underlying operator.
+ *
+ * @param maxParallelism the maximum parallelism to set on the wrapped operator
+ * @return this data stream, for call chaining
+ * @throws UnsupportedOperationException if the wrapped stream is not a
+ * [[SingleOutputStreamOperator]]
+ */
+ def setMaxParallelism(maxParallelism: Int): DataStream[T] = {
+ stream match {
+ case ds: SingleOutputStreamOperator[T] => ds.setMaxParallelism(maxParallelism)
+ case _ =>
+ throw new UnsupportedOperationException("Operator " + stream + " cannot set the maximum " +
+ "parallelism")
+ }
+
+ this
+ }
+
+
/**
* Gets the name of the current data stream. This name is
* used by the visualization and logging during runtime.
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/StreamExecutionEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/StreamExecutionEnvironment.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/StreamExecutionEnvironment.scala
index 9cb36a5..2e432ba 100644
--- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/StreamExecutionEnvironment.scala
+++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/StreamExecutionEnvironment.scala
@@ -59,12 +59,30 @@ class StreamExecutionEnvironment(javaEnv: JavaEnv) {
}
/**
+ * Sets the maximum degree of parallelism defined for the program.
+ * The maximum degree of parallelism specifies the upper limit for dynamic scaling. It also
+ * defines the number of key groups used for partitioned state.
+ *
+ * @param maxParallelism the maximum degree of parallelism, forwarded to the wrapped Java environment
+ */
+ def setMaxParallelism(maxParallelism: Int): Unit = {
+ javaEnv.setMaxParallelism(maxParallelism)
+ }
+
+ /**
* Returns the default parallelism for this execution environment. Note that this
* value can be overridden by individual operations using [[DataStream#setParallelism(int)]]
*/
def getParallelism = javaEnv.getParallelism
/**
+ * Returns the maximum degree of parallelism defined for the program.
+ *
+ * The maximum degree of parallelism specifies the upper limit for dynamic scaling. It also
+ * defines the number of key groups used for partitioned state.
+ *
+ * @return the maximum degree of parallelism, as reported by the wrapped Java environment
+ */
+ def getMaxParallelism = javaEnv.getMaxParallelism
+
+ /**
* Sets the maximum time frequency (milliseconds) for the flushing of the
* output buffers. By default the output buffers flush frequently to provide
* low latency and to aid smooth developer experience. Setting the parameter
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/DataStreamTest.scala
----------------------------------------------------------------------
diff --git a/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/DataStreamTest.scala b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/DataStreamTest.scala
index 16fcfc3..b73eae8 100644
--- a/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/DataStreamTest.scala
+++ b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/DataStreamTest.scala
@@ -512,7 +512,7 @@ class DataStreamTest extends StreamingMultipleProgramsTestBase {
+ /** True if every edge uses a [[KeyGroupStreamPartitioner]] (vacuously true for no edges). */
private def isPartitioned(edges: java.util.List[StreamEdge]): Boolean = {
import scala.collection.JavaConverters._
- edges.asScala.forall( _.getPartitioner.isInstanceOf[HashPartitioner[_]])
+ edges.asScala.forall( _.getPartitioner.isInstanceOf[KeyGroupStreamPartitioner[_, _]])
}
private def isCustomPartitioned(edges: java.util.List[StreamEdge]): Boolean = {
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-tests/src/test/java/org/apache/flink/test/checkpointing/EventTimeAllWindowCheckpointingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/EventTimeAllWindowCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/EventTimeAllWindowCheckpointingITCase.java
index 6faee45..163fb42 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/EventTimeAllWindowCheckpointingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/EventTimeAllWindowCheckpointingITCase.java
@@ -162,7 +162,7 @@ public class EventTimeAllWindowCheckpointingITCase extends TestLogger {
env.setParallelism(PARALLELISM);
env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime);
env.enableCheckpointing(100);
- env.setRestartStrategy(RestartStrategies.fixedDelayRestart(3, 0));
+ env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 0));
env.getConfig().disableSysoutLogging();
env
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
new file mode 100644
index 0000000..0de2a75
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
@@ -0,0 +1,683 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.checkpointing;
+
+import io.netty.util.internal.ConcurrentSet;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.ConfigConstants;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.checkpoint.savepoint.SavepointStoreFactory;
+import org.apache.flink.runtime.client.JobExecutionException;
+import org.apache.flink.runtime.execution.SuppressRestartsException;
+import org.apache.flink.runtime.instance.ActorGateway;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.messages.JobManagerMessages;
+import org.apache.flink.runtime.state.HashKeyGroupAssigner;
+import org.apache.flink.runtime.state.filesystem.FsStateBackendFactory;
+import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages;
+import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.test.util.ForkableFlinkMiniCluster;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.TestLogger;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.ClassRule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import scala.concurrent.Await;
+import scala.concurrent.Future;
+import scala.concurrent.duration.Deadline;
+import scala.concurrent.duration.FiniteDuration;
+
+import java.io.File;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/**
+ * Integration tests for restoring savepoints with a changed parallelism ("rescaling"). Keyed
+ * state has to be redistributed over the new set of subtasks, whereas rescaling an operator
+ * that holds non-partitioned state is expected to fail.
+ */
+public class RescalingITCase extends TestLogger {
+
+	private static final int numTaskManagers = 2;
+	private static final int slotsPerTaskManager = 2;
+	private static final int numSlots = numTaskManagers * slotsPerTaskManager;
+
+	private static ForkableFlinkMiniCluster cluster;
+
+	@ClassRule
+	public static final TemporaryFolder temporaryFolder = new TemporaryFolder();
+
+	/**
+	 * Starts a mini cluster with {@code numTaskManagers} task managers of
+	 * {@code slotsPerTaskManager} slots each. Checkpoints and savepoints are written to
+	 * temporary directories via the "filesystem" backends.
+	 */
+	@BeforeClass
+	public static void setup() throws Exception {
+		Configuration config = new Configuration();
+		config.setInteger(ConfigConstants.LOCAL_NUMBER_TASK_MANAGER, numTaskManagers);
+		config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, slotsPerTaskManager);
+
+		final File checkpointDir = temporaryFolder.newFolder();
+		final File savepointDir = temporaryFolder.newFolder();
+
+		config.setString(ConfigConstants.STATE_BACKEND, "filesystem");
+		config.setString(FsStateBackendFactory.CHECKPOINT_DIRECTORY_URI_CONF_KEY, checkpointDir.toURI().toString());
+		config.setString(SavepointStoreFactory.SAVEPOINT_BACKEND_KEY, "filesystem");
+		config.setString(SavepointStoreFactory.SAVEPOINT_DIRECTORY_KEY, savepointDir.toURI().toString());
+
+		cluster = new ForkableFlinkMiniCluster(config);
+		cluster.start();
+	}
+
+	/** Shuts down the mini cluster if it was started. */
+	@AfterClass
+	public static void teardown() {
+		if (cluster != null) {
+			cluster.shutdown();
+		}
+	}
+
+	/**
+	 * Tests that a job with purely partitioned state can be restarted from a savepoint
+	 * with a different parallelism.
+	 */
+	@Test
+	public void testSavepointRescalingWithPartitionedState() throws Exception {
+		int numberKeys = 42;
+		int numberElements = 1000;
+		int numberElements2 = 500;
+		int parallelism = numSlots / 2;
+		int parallelism2 = numSlots;
+		int maxParallelism = 13;
+
+		FiniteDuration timeout = new FiniteDuration(3, TimeUnit.MINUTES);
+		Deadline deadline = timeout.fromNow();
+
+		ActorGateway jobManager = null;
+		JobID jobID = null;
+
+		try {
+			jobManager = cluster.getLeaderGateway(deadline.timeLeft());
+
+			JobGraph jobGraph = createPartitionedStateJobGraph(parallelism, maxParallelism, numberKeys, numberElements, false, 100);
+
+			jobID = jobGraph.getJobID();
+
+			cluster.submitJobDetached(jobGraph);
+
+			// wait until the flat mappers have seen every key numberElements times
+			awaitWorkCompleted(deadline);
+
+			// verify the current state
+			Set<Tuple2<Integer, Integer>> actualResult = CollectionSink.getElementsSet();
+
+			assertEquals(
+				computeExpectedResult(numberKeys, maxParallelism, parallelism, numberElements),
+				actualResult);
+
+			// clear the CollectionSink set for the restarted job
+			CollectionSink.clearElementsSet();
+
+			final String savepointPath = triggerSavepointAndCancelJob(jobManager, jobID, deadline);
+
+			// job successfully removed
+			jobID = null;
+
+			JobGraph scaledJobGraph = createPartitionedStateJobGraph(parallelism2, maxParallelism, numberKeys, numberElements2, true, 100);
+
+			scaledJobGraph.setSavepointPath(savepointPath);
+
+			jobID = scaledJobGraph.getJobID();
+
+			cluster.submitJobAndWait(scaledJobGraph, false);
+
+			jobID = null;
+
+			// the keyed counters and sums are restored from the savepoint, so every key has now
+			// been processed numberElements + numberElements2 times in total
+			Set<Tuple2<Integer, Integer>> actualResult2 = CollectionSink.getElementsSet();
+
+			assertEquals(
+				computeExpectedResult(numberKeys, maxParallelism, parallelism2, numberElements + numberElements2),
+				actualResult2);
+
+		} finally {
+			// clear the CollectionSink set for subsequent tests
+			CollectionSink.clearElementsSet();
+
+			// clear any left overs from a possibly failed job
+			if (jobID != null && jobManager != null) {
+				cancelAndAwaitJobRemoval(jobManager, jobID, timeout);
+			}
+		}
+	}
+
+	/**
+	 * Tests that a job cannot be restarted from a savepoint with a different parallelism if the
+	 * rescaled operator has non-partitioned state.
+	 *
+	 * @throws Exception
+	 */
+	@Test
+	public void testSavepointRescalingFailureWithNonPartitionedState() throws Exception {
+		int parallelism = numSlots / 2;
+		int parallelism2 = numSlots;
+		int maxParallelism = 13;
+
+		FiniteDuration timeout = new FiniteDuration(3, TimeUnit.MINUTES);
+		Deadline deadline = timeout.fromNow();
+
+		JobID jobID = null;
+		ActorGateway jobManager = null;
+
+		try {
+			jobManager = cluster.getLeaderGateway(deadline.timeLeft());
+
+			JobGraph jobGraph = createNonPartitionedStateJobGraph(parallelism, maxParallelism, 500);
+
+			jobID = jobGraph.getJobID();
+
+			cluster.submitJobDetached(jobGraph);
+
+			Future<Object> allTasksRunning = jobManager.ask(new TestingJobManagerMessages.WaitForAllVerticesToBeRunning(jobID), deadline.timeLeft());
+
+			Await.ready(allTasksRunning, deadline.timeLeft());
+
+			final String savepointPath = triggerSavepointAndCancelJob(jobManager, jobID, deadline);
+
+			// job successfully removed
+			jobID = null;
+
+			JobGraph scaledJobGraph = createNonPartitionedStateJobGraph(parallelism2, maxParallelism, 500);
+
+			scaledJobGraph.setSavepointPath(savepointPath);
+
+			jobID = scaledJobGraph.getJobID();
+
+			cluster.submitJobAndWait(scaledJobGraph, false);
+
+			jobID = null;
+
+			// reaching this point means the rescaled restore was accepted, which must not happen
+			fail("Restoring non-partitioned state with a different parallelism should have failed.");
+
+		} catch (JobExecutionException exception) {
+			// we expect an IllegalStateException wrapped in a SuppressRestartsException wrapped
+			// in a JobExecutionException, because the job containing non-partitioned state
+			// is being rescaled
+			Throwable cause = exception.getCause();
+			boolean isExpectedFailure = cause instanceof SuppressRestartsException
+				&& cause.getCause() instanceof IllegalStateException;
+
+			if (!isExpectedFailure) {
+				throw exception;
+			}
+		} finally {
+			// clear any left overs from a possibly failed job
+			if (jobID != null && jobManager != null) {
+				cancelAndAwaitJobRemoval(jobManager, jobID, timeout);
+			}
+		}
+	}
+
+	/**
+	 * Tests that a job with non partitioned state can be restarted from a savepoint with a
+	 * different parallelism if the operators with non-partitioned state are not rescaled.
+	 *
+	 * @throws Exception
+	 */
+	@Test
+	public void testSavepointRescalingWithPartiallyNonPartitionedState() throws Exception {
+		int numberKeys = 42;
+		int numberElements = 1000;
+		int numberElements2 = 500;
+		int parallelism = numSlots / 2;
+		int parallelism2 = numSlots;
+		int maxParallelism = 13;
+
+		FiniteDuration timeout = new FiniteDuration(3, TimeUnit.MINUTES);
+		Deadline deadline = timeout.fromNow();
+
+		ActorGateway jobManager = null;
+		JobID jobID = null;
+
+		try {
+			jobManager = cluster.getLeaderGateway(deadline.timeLeft());
+
+			JobGraph jobGraph = createPartitionedNonPartitionedStateJobGraph(
+				parallelism,
+				maxParallelism,
+				parallelism,
+				numberKeys,
+				numberElements,
+				false,
+				100);
+
+			jobID = jobGraph.getJobID();
+
+			cluster.submitJobDetached(jobGraph);
+
+			// wait until the flat mappers have seen every key numberElements times
+			awaitWorkCompleted(deadline);
+
+			// verify the current state
+			Set<Tuple2<Integer, Integer>> actualResult = CollectionSink.getElementsSet();
+
+			assertEquals(
+				computeExpectedResult(numberKeys, maxParallelism, parallelism, numberElements),
+				actualResult);
+
+			// clear the CollectionSink set for the restarted job
+			CollectionSink.clearElementsSet();
+
+			final String savepointPath = triggerSavepointAndCancelJob(jobManager, jobID, deadline);
+
+			// job successfully removed
+			jobID = null;
+
+			// the source keeps its original parallelism, so its non-partitioned counter can be
+			// restored; it continues emitting up to numberElements + numberElements2 rounds
+			JobGraph scaledJobGraph = createPartitionedNonPartitionedStateJobGraph(
+				parallelism2,
+				maxParallelism,
+				parallelism,
+				numberKeys,
+				numberElements + numberElements2,
+				true,
+				100);
+
+			scaledJobGraph.setSavepointPath(savepointPath);
+
+			jobID = scaledJobGraph.getJobID();
+
+			cluster.submitJobAndWait(scaledJobGraph, false);
+
+			jobID = null;
+
+			Set<Tuple2<Integer, Integer>> actualResult2 = CollectionSink.getElementsSet();
+
+			assertEquals(
+				computeExpectedResult(numberKeys, maxParallelism, parallelism2, numberElements + numberElements2),
+				actualResult2);
+
+		} finally {
+			// clear the CollectionSink set for subsequent tests
+			CollectionSink.clearElementsSet();
+
+			// clear any left overs from a possibly failed job
+			if (jobID != null && jobManager != null) {
+				cancelAndAwaitJobRemoval(jobManager, jobID, timeout);
+			}
+		}
+	}
+
+	// ------------------------------------------------------------------------
+	//  Test helpers
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Blocks until {@link SubtaskIndexFlatMapper#workCompletedLatch} reaches zero and fails the
+	 * test if the deadline expires first.
+	 */
+	private static void awaitWorkCompleted(Deadline deadline) throws InterruptedException {
+		boolean completed = SubtaskIndexFlatMapper.workCompletedLatch.await(
+			deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+
+		assertTrue("Timed out while waiting for the flat mappers to process all elements.", completed);
+	}
+
+	/**
+	 * Builds the set of {@code (subtask index, sum)} tuples that {@link SubtaskIndexFlatMapper}
+	 * is expected to emit. Key {@code k} is expected on subtask
+	 * {@code keyGroupIndex(k) % parallelism} with the sum {@code k * elementsPerKey}.
+	 */
+	private static Set<Tuple2<Integer, Integer>> computeExpectedResult(
+			int numberKeys,
+			int maxParallelism,
+			int parallelism,
+			int elementsPerKey) {
+
+		HashKeyGroupAssigner<Integer> keyGroupAssigner = new HashKeyGroupAssigner<>(maxParallelism);
+		Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>();
+
+		for (int key = 0; key < numberKeys; key++) {
+			int keyGroupIndex = keyGroupAssigner.getKeyGroupIndex(key);
+
+			expectedResult.add(Tuple2.of(keyGroupIndex % parallelism, key * elementsPerKey));
+		}
+
+		return expectedResult;
+	}
+
+	/**
+	 * Triggers a savepoint for the given job, then cancels the job and waits until the
+	 * JobManager has removed it.
+	 *
+	 * @return the path of the successfully triggered savepoint
+	 */
+	private static String triggerSavepointAndCancelJob(
+			ActorGateway jobManager,
+			JobID jobID,
+			Deadline deadline) throws Exception {
+
+		Future<Object> savepointPathFuture = jobManager.ask(new JobManagerMessages.TriggerSavepoint(jobID), deadline.timeLeft());
+
+		Object savepointResponse = Await.result(savepointPathFuture, deadline.timeLeft());
+
+		assertTrue("Unexpected savepoint response: " + savepointResponse,
+			savepointResponse instanceof JobManagerMessages.TriggerSavepointSuccess);
+
+		final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess) savepointResponse).savepointPath();
+
+		// register for the removal notification before cancelling so it cannot be missed
+		Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft());
+
+		Future<Object> cancellationResponseFuture = jobManager.ask(new JobManagerMessages.CancelJob(jobID), deadline.timeLeft());
+
+		Object cancellationResponse = Await.result(cancellationResponseFuture, deadline.timeLeft());
+
+		assertTrue("Unexpected cancellation response: " + cancellationResponse,
+			cancellationResponse instanceof JobManagerMessages.CancellationSuccess);
+
+		Await.ready(jobRemovedFuture, deadline.timeLeft());
+
+		return savepointPath;
+	}
+
+	/**
+	 * Cleans up a job that may have been left behind by a failed test. The job is cancelled and
+	 * the method waits until the JobManager has removed it, so it does not keep occupying slots
+	 * or run until the timeout expires.
+	 */
+	private static void cancelAndAwaitJobRemoval(ActorGateway jobManager, JobID jobID, FiniteDuration timeout) {
+		Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), timeout);
+
+		// the response is deliberately ignored: if the job already terminated, the cancellation
+		// presumably just fails on the JobManager side
+		jobManager.ask(new JobManagerMessages.CancelJob(jobID), timeout);
+
+		try {
+			Await.ready(jobRemovedFuture, timeout);
+		} catch (InterruptedException e) {
+			Thread.currentThread().interrupt();
+			fail("Failed while cleaning up the cluster.");
+		} catch (TimeoutException e) {
+			fail("Failed while cleaning up the cluster.");
+		}
+	}
+
+	// ------------------------------------------------------------------------
+	//  Job graphs
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Creates a job of a {@link NonPartitionedStateSource} followed by a discarding sink. The job
+	 * does not restart on failure.
+	 */
+	private static JobGraph createNonPartitionedStateJobGraph(int parallelism, int maxParallelism, long checkpointInterval) {
+		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.setParallelism(parallelism);
+		env.getConfig().setMaxParallelism(maxParallelism);
+		env.enableCheckpointing(checkpointInterval);
+		env.setRestartStrategy(RestartStrategies.noRestart());
+
+		DataStream<Integer> input = env.addSource(new NonPartitionedStateSource());
+
+		input.addSink(new DiscardingSink<Integer>());
+
+		return env.getStreamGraph().getJobGraph();
+	}
+
+	/**
+	 * Creates a job {@code SubtaskIndexSource -> keyBy(identity) -> SubtaskIndexFlatMapper ->
+	 * CollectionSink} in which only the flat mapper holds (keyed) state. As a side effect, it
+	 * resets {@link SubtaskIndexFlatMapper#workCompletedLatch} to {@code numberKeys}.
+	 */
+	private static JobGraph createPartitionedStateJobGraph(
+			int parallelism,
+			int maxParallelism,
+			int numberKeys,
+			int numberElements,
+			boolean terminateAfterEmission,
+			int checkpointingInterval) {
+
+		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.setParallelism(parallelism);
+		env.getConfig().setMaxParallelism(maxParallelism);
+		env.enableCheckpointing(checkpointingInterval);
+		env.setRestartStrategy(RestartStrategies.noRestart());
+
+		DataStream<Integer> input = env.addSource(new SubtaskIndexSource(
+				numberKeys,
+				numberElements,
+				terminateAfterEmission))
+			.keyBy(new IdentityKeySelector());
+
+		SubtaskIndexFlatMapper.workCompletedLatch = new CountDownLatch(numberKeys);
+
+		DataStream<Tuple2<Integer, Integer>> result = input.flatMap(new SubtaskIndexFlatMapper(numberElements));
+
+		result.addSink(new CollectionSink<Tuple2<Integer, Integer>>());
+
+		return env.getStreamGraph().getJobGraph();
+	}
+
+	/**
+	 * Like {@link #createPartitionedStateJobGraph}, but the source additionally checkpoints its
+	 * counter as non-partitioned state. The source always runs with {@code fixedParallelism},
+	 * while the rest of the job uses {@code parallelism}. As a side effect, it resets
+	 * {@link SubtaskIndexFlatMapper#workCompletedLatch} to {@code numberKeys}.
+	 */
+	private static JobGraph createPartitionedNonPartitionedStateJobGraph(
+			int parallelism,
+			int maxParallelism,
+			int fixedParallelism,
+			int numberKeys,
+			int numberElements,
+			boolean terminateAfterEmission,
+			int checkpointingInterval) {
+
+		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.setParallelism(parallelism);
+		env.getConfig().setMaxParallelism(maxParallelism);
+		env.enableCheckpointing(checkpointingInterval);
+		env.setRestartStrategy(RestartStrategies.noRestart());
+
+		DataStream<Integer> input = env.addSource(new SubtaskIndexNonPartitionedStateSource(
+				numberKeys,
+				numberElements,
+				terminateAfterEmission))
+			.setParallelism(fixedParallelism)
+			.keyBy(new IdentityKeySelector());
+
+		SubtaskIndexFlatMapper.workCompletedLatch = new CountDownLatch(numberKeys);
+
+		DataStream<Tuple2<Integer, Integer>> result = input.flatMap(new SubtaskIndexFlatMapper(numberElements));
+
+		result.addSink(new CollectionSink<Tuple2<Integer, Integer>>());
+
+		return env.getStreamGraph().getJobGraph();
+	}
+
+	// ------------------------------------------------------------------------
+	//  User functions
+	// ------------------------------------------------------------------------
+
+	/** Uses the integer value itself as its key. */
+	private static class IdentityKeySelector implements KeySelector<Integer, Integer> {
+
+		private static final long serialVersionUID = -7952298871120320940L;
+
+		@Override
+		public Integer getKey(Integer value) throws Exception {
+			return value;
+		}
+	}
+
+	/**
+	 * Parallel source in which subtask {@code i} emits the values {@code i, i + p, i + 2p, ...}
+	 * below {@code numberKeys}, where {@code p} is the source parallelism. Each such round is
+	 * repeated {@code numberElements} times. Afterwards the source either finishes or idles
+	 * until it is cancelled.
+	 */
+	private static class SubtaskIndexSource
+		extends RichParallelSourceFunction<Integer> {
+
+		private static final long serialVersionUID = -400066323594122516L;
+
+		private final int numberKeys;
+		private final int numberElements;
+		private final boolean terminateAfterEmission;
+
+		/** Number of completed emission rounds; checkpointed by subclasses. */
+		protected int counter = 0;
+
+		/** Volatile because cancel() may be invoked from a different thread than run(). */
+		private volatile boolean running = true;
+
+		SubtaskIndexSource(
+				int numberKeys,
+				int numberElements,
+				boolean terminateAfterEmission) {
+
+			this.numberKeys = numberKeys;
+			this.numberElements = numberElements;
+			this.terminateAfterEmission = terminateAfterEmission;
+		}
+
+		@Override
+		public void run(SourceContext<Integer> ctx) throws Exception {
+			final Object lock = ctx.getCheckpointLock();
+			final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
+			final int numberSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
+
+			while (running) {
+				if (counter < numberElements) {
+					// emit a full round and advance the counter under the checkpoint lock,
+					// so a snapshot never observes a partially emitted round
+					synchronized (lock) {
+						for (int value = subtaskIndex; value < numberKeys; value += numberSubtasks) {
+							ctx.collect(value);
+						}
+
+						counter++;
+					}
+				} else if (terminateAfterEmission) {
+					running = false;
+				} else {
+					Thread.sleep(100);
+				}
+			}
+		}
+
+		@Override
+		public void cancel() {
+			running = false;
+		}
+	}
+
+	/**
+	 * {@link SubtaskIndexSource} that checkpoints its round counter as non-partitioned
+	 * ({@link Checkpointed}) state.
+	 */
+	private static class SubtaskIndexNonPartitionedStateSource extends SubtaskIndexSource implements Checkpointed<Integer> {
+
+		private static final long serialVersionUID = 8388073059042040203L;
+
+		SubtaskIndexNonPartitionedStateSource(int numberKeys, int numberElements, boolean terminateAfterEmission) {
+			super(numberKeys, numberElements, terminateAfterEmission);
+		}
+
+		@Override
+		public Integer snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
+			return counter;
+		}
+
+		@Override
+		public void restoreState(Integer state) throws Exception {
+			counter = state;
+		}
+	}
+
+	/**
+	 * Keeps a per-key count and sum in keyed {@link ValueState}. Each time the count reaches a
+	 * multiple of {@code numberElements}, it emits {@code (subtask index, sum)} and counts down
+	 * {@link #workCompletedLatch}.
+	 */
+	private static class SubtaskIndexFlatMapper extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>> {
+
+		private static final long serialVersionUID = 5273172591283191348L;
+
+		/** Shared with the test thread, which waits on it; reset by the job graph builders. */
+		private static volatile CountDownLatch workCompletedLatch = new CountDownLatch(1);
+
+		private transient ValueState<Integer> counter;
+		private transient ValueState<Integer> sum;
+
+		private final int numberElements;
+
+		SubtaskIndexFlatMapper(int numberElements) {
+			this.numberElements = numberElements;
+		}
+
+		@Override
+		public void open(Configuration configuration) {
+			counter = getRuntimeContext().getState(new ValueStateDescriptor<Integer>("counter", Integer.class, 0));
+			sum = getRuntimeContext().getState(new ValueStateDescriptor<Integer>("sum", Integer.class, 0));
+		}
+
+		@Override
+		public void flatMap(Integer value, Collector<Tuple2<Integer, Integer>> out) throws Exception {
+			int count = counter.value() + 1;
+			counter.update(count);
+
+			int s = sum.value() + value;
+			sum.update(s);
+
+			if (count % numberElements == 0) {
+				out.collect(Tuple2.of(getRuntimeContext().getIndexOfThisSubtask(), s));
+				workCompletedLatch.countDown();
+			}
+		}
+	}
+
+	/**
+	 * Sink that adds every record to a static, concurrently accessible set, so the test thread
+	 * can inspect the results.
+	 */
+	private static class CollectionSink<IN> implements SinkFunction<IN> {
+
+		// NOTE(review): ConcurrentSet is a Netty-internal class; a JDK alternative would be
+		// Collections.newSetFromMap(new ConcurrentHashMap<...>())
+		private static final ConcurrentSet<Object> elements = new ConcurrentSet<Object>();
+
+		private static final long serialVersionUID = -1652452958040267745L;
+
+		/** Returns the live (not copied) set of all collected elements. */
+		@SuppressWarnings("unchecked")
+		public static <IN> Set<IN> getElementsSet() {
+			return (Set<IN>) elements;
+		}
+
+		public static void clearElementsSet() {
+			elements.clear();
+		}
+
+		@Override
+		public void invoke(IN value) throws Exception {
+			elements.add(value);
+		}
+	}
+
+	/**
+	 * Infinite parallel source that keeps a counter as non-partitioned ({@link Checkpointed})
+	 * state. Every 100 ms it increments the counter and emits {@code counter * subtaskIndex}.
+	 */
+	private static class NonPartitionedStateSource extends RichParallelSourceFunction<Integer> implements Checkpointed<Integer> {
+
+		private static final long serialVersionUID = -8108185918123186841L;
+
+		private int counter = 0;
+
+		/** Volatile because cancel() may be invoked from a different thread than run(). */
+		private volatile boolean running = true;
+
+		@Override
+		public Integer snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
+			return counter;
+		}
+
+		@Override
+		public void restoreState(Integer state) throws Exception {
+			counter = state;
+		}
+
+		@Override
+		public void run(SourceContext<Integer> ctx) throws Exception {
+			final Object lock = ctx.getCheckpointLock();
+
+			while (running) {
+				synchronized (lock) {
+					counter++;
+
+					ctx.collect(counter * getRuntimeContext().getIndexOfThisSubtask());
+				}
+
+				Thread.sleep(100);
+			}
+		}
+
+		@Override
+		public void cancel() {
+			// stop the emission loop in run()
+			running = false;
+		}
+	}
+}