You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by al...@apache.org on 2016/08/31 17:28:39 UTC

[21/27] flink git commit: [FLINK-4380] Remove KeyGroupAssigner in favor of static method/Have default max. parallelism at 128

http://git-wip-us.apache.org/repos/asf/flink/blob/6d430618/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
index 277fab4..6259598 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
@@ -23,13 +23,11 @@ import java.util.Random;
 
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.functions.MapFunction;
-import org.apache.flink.api.common.state.KeyGroupAssigner;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobVertex;
-import org.apache.flink.runtime.state.HashKeyGroupAssigner;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
@@ -161,202 +159,4 @@ public class StreamingJobGraphGeneratorTest extends TestLogger {
 		assertEquals(1, jobGraph.getVerticesAsArray()[1].getParallelism());
 	}
 
-	/**
-	 * Tests that the KeyGroupAssigner is properly set in the {@link StreamConfig} if the max
-	 * parallelism is set for the whole job.
-	 */
-	@Test
-	public void testKeyGroupAssignerProperlySet() {
-		int maxParallelism = 42;
-
-		final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
-		env.getConfig().setMaxParallelism(maxParallelism);
-
-		DataStream<Integer> input = env.fromElements(1, 2, 3);
-
-		DataStream<Integer> keyedResult = input.keyBy(new KeySelector<Integer, Integer>() {
-			private static final long serialVersionUID = 350461576474507944L;
-
-			@Override
-			public Integer getKey(Integer value) throws Exception {
-				return value;
-			}
-		}).map(new NoOpIntMap());
-
-		keyedResult.addSink(new DiscardingSink<Integer>());
-
-		JobGraph jobGraph = env.getStreamGraph().getJobGraph();
-
-		List<JobVertex> jobVertices = jobGraph.getVerticesSortedTopologicallyFromSources();
-
-		assertEquals(maxParallelism, jobVertices.get(1).getMaxParallelism());
-
-		HashKeyGroupAssigner<Integer> hashKeyGroupAssigner = extractHashKeyGroupAssigner(jobVertices.get(1));
-
-		assertEquals(maxParallelism, hashKeyGroupAssigner.getNumberKeyGroups());
-	}
-
-	/**
-	 * Tests that the key group assigner for the keyed streams in the stream config is properly
-	 * initialized with the max parallelism value if there is no max parallelism defined for the
-	 * whole job.
-	 */
-	@Test
-	public void testKeyGroupAssignerProperlySetAutoMaxParallelism() {
-		int globalParallelism = 42;
-		int mapParallelism = 17;
-		int maxParallelism = 43;
-		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
-		env.setParallelism(globalParallelism);
-
-		DataStream<Integer> source = env.fromElements(1, 2, 3);
-
-		DataStream<Integer> keyedResult1 = source.keyBy(new KeySelector<Integer, Integer>() {
-			private static final long serialVersionUID = 9205556348021992189L;
-
-			@Override
-			public Integer getKey(Integer value) throws Exception {
-				return value;
-			}
-		}).map(new NoOpIntMap());
-
-		DataStream<Integer> keyedResult2 = keyedResult1.keyBy(new KeySelector<Integer, Integer>() {
-			private static final long serialVersionUID = 1250168178707154838L;
-
-			@Override
-			public Integer getKey(Integer value) throws Exception {
-				return value;
-			}
-		}).map(new NoOpIntMap()).setParallelism(mapParallelism);
-
-		DataStream<Integer> keyedResult3 = keyedResult2.keyBy(new KeySelector<Integer, Integer>() {
-			private static final long serialVersionUID = 1250168178707154838L;
-
-			@Override
-			public Integer getKey(Integer value) throws Exception {
-				return value;
-			}
-		}).map(new NoOpIntMap()).setMaxParallelism(maxParallelism);
-
-		DataStream<Integer> keyedResult4 = keyedResult3.keyBy(new KeySelector<Integer, Integer>() {
-			private static final long serialVersionUID = 1250168178707154838L;
-
-			@Override
-			public Integer getKey(Integer value) throws Exception {
-				return value;
-			}
-		}).map(new NoOpIntMap()).setMaxParallelism(maxParallelism).setParallelism(mapParallelism);
-
-		keyedResult4.addSink(new DiscardingSink<Integer>());
-
-		JobGraph jobGraph = env.getStreamGraph().getJobGraph();
-
-		List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();
-
-		JobVertex keyedResultJV1 = vertices.get(1);
-		JobVertex keyedResultJV2 = vertices.get(2);
-		JobVertex keyedResultJV3 = vertices.get(3);
-		JobVertex keyedResultJV4 = vertices.get(4);
-
-		HashKeyGroupAssigner<Integer> hashKeyGroupAssigner1 = extractHashKeyGroupAssigner(keyedResultJV1);
-		HashKeyGroupAssigner<Integer> hashKeyGroupAssigner2 = extractHashKeyGroupAssigner(keyedResultJV2);
-		HashKeyGroupAssigner<Integer> hashKeyGroupAssigner3 = extractHashKeyGroupAssigner(keyedResultJV3);
-		HashKeyGroupAssigner<Integer> hashKeyGroupAssigner4 = extractHashKeyGroupAssigner(keyedResultJV4);
-
-		assertEquals(globalParallelism, hashKeyGroupAssigner1.getNumberKeyGroups());
-		assertEquals(mapParallelism, hashKeyGroupAssigner2.getNumberKeyGroups());
-		assertEquals(maxParallelism, hashKeyGroupAssigner3.getNumberKeyGroups());
-		assertEquals(maxParallelism, hashKeyGroupAssigner4.getNumberKeyGroups());
-	}
-
-	/**
-	 * Tests that the {@link KeyGroupAssigner} is properly set in the {@link StreamConfig} for
-	 * connected streams.
-	 */
-	@Test
-	public void testMaxParallelismWithConnectedKeyedStream() {
-		int maxParallelism = 42;
-
-		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
-		DataStream<Integer> input1 = env.fromElements(1, 2, 3, 4).setMaxParallelism(128).name("input1");
-		DataStream<Integer> input2 = env.fromElements(1, 2, 3, 4).setMaxParallelism(129).name("input2");
-
-		env.getConfig().setMaxParallelism(maxParallelism);
-
-		DataStream<Integer> keyedResult = input1.connect(input2).keyBy(
-			new KeySelector<Integer, Integer>() {
-				private static final long serialVersionUID = -6908614081449363419L;
-
-				@Override
-				public Integer getKey(Integer value) throws Exception {
-					return value;
-				}
-			},
-			new KeySelector<Integer, Integer>() {
-				private static final long serialVersionUID = 3195683453223164931L;
-
-				@Override
-				public Integer getKey(Integer value) throws Exception {
-					return value;
-				}
-			}).map(new StreamGraphGeneratorTest.NoOpIntCoMap());
-
-		keyedResult.addSink(new DiscardingSink<Integer>());
-
-		JobGraph jobGraph = env.getStreamGraph().getJobGraph();
-
-		List<JobVertex> jobVertices = jobGraph.getVerticesSortedTopologicallyFromSources();
-
-		JobVertex input1JV = jobVertices.get(0);
-		JobVertex input2JV = jobVertices.get(1);
-		JobVertex connectedJV = jobVertices.get(2);
-
-		// disambiguate the partial order of the inputs
-		if (input1JV.getName().equals("Source: input1")) {
-			assertEquals(128, input1JV.getMaxParallelism());
-			assertEquals(129, input2JV.getMaxParallelism());
-		} else {
-			assertEquals(128, input2JV.getMaxParallelism());
-			assertEquals(129, input1JV.getMaxParallelism());
-		}
-
-		assertEquals(maxParallelism, connectedJV.getMaxParallelism());
-
-		HashKeyGroupAssigner<Integer> hashKeyGroupAssigner = extractHashKeyGroupAssigner(connectedJV);
-
-		assertEquals(maxParallelism, hashKeyGroupAssigner.getNumberKeyGroups());
-	}
-
-	/**
-	 * Tests that the {@link JobGraph} creation fails if the parallelism is greater than the max
-	 * parallelism.
-	 */
-	@Test(expected=IllegalStateException.class)
-	public void testFailureOfJobJobCreationIfParallelismGreaterThanMaxParallelism() {
-		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
-		env.getConfig().setMaxParallelism(42);
-
-		DataStream<Integer> input = env.fromElements(1, 2, 3, 4);
-
-		DataStream<Integer> result = input.map(new NoOpIntMap()).setParallelism(43);
-
-		result.addSink(new DiscardingSink<Integer>());
-
-		env.getStreamGraph().getJobGraph();
-
-		fail("The JobGraph should not have been created because the parallelism is greater than " +
-			"the max parallelism.");
-	}
-
-	private HashKeyGroupAssigner<Integer> extractHashKeyGroupAssigner(JobVertex jobVertex) {
-		Configuration config = jobVertex.getConfiguration();
-
-		StreamConfig streamConfig = new StreamConfig(config);
-
-		KeyGroupAssigner<Integer> keyGroupAssigner = streamConfig.getKeyGroupAssigner(getClass().getClassLoader());
-
-		assertTrue(keyGroupAssigner instanceof HashKeyGroupAssigner);
-
-		return (HashKeyGroupAssigner<Integer>) keyGroupAssigner;
-	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/6d430618/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
index d3b7ff9..fe09788 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
@@ -30,19 +30,16 @@ import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
-import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.query.KvStateRegistry;
-import org.apache.flink.runtime.state.HashKeyGroupAssigner;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
-import org.apache.flink.runtime.state.heap.HeapListState;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.junit.Test;
 import org.mockito.invocation.InvocationOnMock;
@@ -196,7 +193,7 @@ public class StreamingRuntimeContextTest {
 								new JobID(),
 								"test_op",
 								IntSerializer.INSTANCE,
-								new HashKeyGroupAssigner<Integer>(1),
+								1,
 								new KeyGroupRange(0, 0),
 								new KvStateRegistry().createTaskRegistry(new JobID(),
 										new JobVertexID()));

http://git-wip-us.apache.org/repos/asf/flink/blob/6d430618/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
index 4e7e4d0..59bfe6f 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
@@ -22,7 +22,6 @@ import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.accumulators.Accumulator;
 import org.apache.flink.api.common.functions.ReduceFunction;
 import org.apache.flink.api.common.functions.RichReduceFunction;
-import org.apache.flink.api.common.state.KeyGroupAssigner;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
@@ -37,9 +36,6 @@ import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
-import org.apache.flink.runtime.state.AbstractStateBackend;
-import org.apache.flink.runtime.state.memory.MemoryStateBackend;
-import org.apache.flink.runtime.state.HashKeyGroupAssigner;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
 import org.apache.flink.streaming.api.graph.StreamConfig;
@@ -52,7 +48,6 @@ import org.apache.flink.streaming.runtime.tasks.TestTimeServiceProvider;
 import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
 import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
 import org.junit.After;
-import org.junit.Ignore;
 import org.junit.Test;
 
 import org.mockito.invocation.InvocationOnMock;
@@ -212,7 +207,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 
 			op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector,
 					StringSerializer.INSTANCE, StringSerializer.INSTANCE, 5000, 1000);
-			op.setup(mockTask, createTaskConfig(mockKeySelector, StringSerializer.INSTANCE, new HashKeyGroupAssigner<Object>(10)), mockOut);
+			op.setup(mockTask, createTaskConfig(mockKeySelector, StringSerializer.INSTANCE, 10), mockOut);
 			op.open();
 			assertTrue(op.getNextSlideTime() % 1000 == 0);
 			assertTrue(op.getNextEvaluationTime() % 1000 == 0);
@@ -264,7 +259,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 			final Object lock = new Object();
 			final StreamTask<?, ?> mockTask = createMockTaskWithTimer(timerService, lock);
 
-			op.setup(mockTask, createTaskConfig(fieldOneSelector, IntSerializer.INSTANCE, new HashKeyGroupAssigner<Object>(10)), out);
+			op.setup(mockTask, createTaskConfig(fieldOneSelector, IntSerializer.INSTANCE, 10), out);
 			op.open();
 
 			final int numElements = 1000;
@@ -322,7 +317,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 							IntSerializer.INSTANCE, tupleSerializer,
 							windowSize, windowSize);
 
-			op.setup(mockTask, createTaskConfig(fieldOneSelector, IntSerializer.INSTANCE, new HashKeyGroupAssigner<Object>(10)), out);
+			op.setup(mockTask, createTaskConfig(fieldOneSelector, IntSerializer.INSTANCE, 10), out);
 			op.open();
 
 			final int numWindows = 10;
@@ -389,7 +384,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 							IntSerializer.INSTANCE, tupleSerializer,
 							150, 50);
 
-			op.setup(mockTask, createTaskConfig(fieldOneSelector, IntSerializer.INSTANCE, new HashKeyGroupAssigner<Object>(10)), out);
+			op.setup(mockTask, createTaskConfig(fieldOneSelector, IntSerializer.INSTANCE, 10), out);
 			op.open();
 
 			final int numElements = 1000;
@@ -458,7 +453,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 							sumFunction, fieldOneSelector,
 							IntSerializer.INSTANCE, tupleSerializer, 150, 50);
 
-			op.setup(mockTask, createTaskConfig(fieldOneSelector, IntSerializer.INSTANCE, new HashKeyGroupAssigner<Object>(10)), out);
+			op.setup(mockTask, createTaskConfig(fieldOneSelector, IntSerializer.INSTANCE, 10), out);
 			op.open();
 
 			synchronized (lock) {
@@ -520,7 +515,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 							IntSerializer.INSTANCE, tupleSerializer,
 							hundredYears, hundredYears);
 
-			op.setup(mockTask, createTaskConfig(fieldOneSelector, IntSerializer.INSTANCE, new HashKeyGroupAssigner<Object>(10)), out);
+			op.setup(mockTask, createTaskConfig(fieldOneSelector, IntSerializer.INSTANCE, 10), out);
 			op.open();
 
 			for (int i = 0; i < 100; i++) {
@@ -973,7 +968,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 		return mockTask;
 	}
 
-	private static StreamConfig createTaskConfig(KeySelector<?, ?> partitioner, TypeSerializer<?> keySerializer, KeyGroupAssigner<?> keyGroupAssigner) {
+	private static StreamConfig createTaskConfig(KeySelector<?, ?> partitioner, TypeSerializer<?> keySerializer, int numberOfKeGroups) {
 		StreamConfig cfg = new StreamConfig(new Configuration());
 		return cfg;
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/6d430618/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitionerTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitionerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitionerTest.java
index 6fbf35e..4ca7449 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitionerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitionerTest.java
@@ -23,7 +23,6 @@ import static org.junit.Assert.assertEquals;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.plugable.SerializationDelegate;
-import org.apache.flink.runtime.state.HashKeyGroupAssigner;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.util.TestLogger;
 import org.junit.Before;
@@ -48,7 +47,7 @@ public class KeyGroupStreamPartitionerTest extends TestLogger {
 				return value.getField(0);
 			}
 		},
-		new HashKeyGroupAssigner<String>(1024));
+		1024);
 	}
 
 	@Test

http://git-wip-us.apache.org/repos/asf/flink/blob/6d430618/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java
index 5f73e25..5573a53 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java
@@ -23,7 +23,6 @@ import org.apache.flink.api.java.ClosureCleaner;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.io.network.partition.consumer.StreamTestSingleInputGate;
-import org.apache.flink.runtime.state.HashKeyGroupAssigner;
 
 import java.io.IOException;
 
@@ -106,7 +105,7 @@ public class OneInputStreamTaskTestHarness<IN, OUT> extends StreamTaskTestHarnes
 		ClosureCleaner.clean(keySelector, false);
 		streamConfig.setStatePartitioner(0, keySelector);
 		streamConfig.setStateKeySerializer(keyType.createSerializer(executionConfig));
-		streamConfig.setKeyGroupAssigner(new HashKeyGroupAssigner<K>(10));
+		streamConfig.setNumberOfKeyGroups(10);
 	}
 }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/6d430618/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
index 5594193..03f50f9 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
@@ -19,14 +19,12 @@ package org.apache.flink.streaming.util;
 
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.state.KeyGroupAssigner;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.ClosureCleaner;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
-import org.apache.flink.runtime.state.HashKeyGroupAssigner;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateBackend;
@@ -70,7 +68,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 		ClosureCleaner.clean(keySelector, false);
 		config.setStatePartitioner(0, keySelector);
 		config.setStateKeySerializer(keyType.createSerializer(executionConfig));
-		config.setKeyGroupAssigner(new HashKeyGroupAssigner<K>(MAX_PARALLELISM));
+		config.setNumberOfKeyGroups(MAX_PARALLELISM);
 
 		setupMockTaskCreateKeyedBackend();
 	}
@@ -84,7 +82,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 		ClosureCleaner.clean(keySelector, false);
 		config.setStatePartitioner(0, keySelector);
 		config.setStateKeySerializer(keyType.createSerializer(executionConfig));
-		config.setKeyGroupAssigner(new HashKeyGroupAssigner<K>(MAX_PARALLELISM));
+		config.setNumberOfKeyGroups(MAX_PARALLELISM);
 
 		setupMockTaskCreateKeyedBackend();
 	}
@@ -99,7 +97,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 		ClosureCleaner.clean(keySelector, false);
 		config.setStatePartitioner(0, keySelector);
 		config.setStateKeySerializer(keyType.createSerializer(executionConfig));
-		config.setKeyGroupAssigner(new HashKeyGroupAssigner<K>(MAX_PARALLELISM));
+		config.setNumberOfKeyGroups(MAX_PARALLELISM);
 
 		setupMockTaskCreateKeyedBackend();
 	}
@@ -112,7 +110,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 				public KeyedStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable {
 
 					final TypeSerializer keySerializer = (TypeSerializer) invocationOnMock.getArguments()[0];
-					final KeyGroupAssigner keyGroupAssigner = (KeyGroupAssigner) invocationOnMock.getArguments()[1];
+					final int numberOfKeyGroups = (Integer) invocationOnMock.getArguments()[1];
 					final KeyGroupRange keyGroupRange = (KeyGroupRange) invocationOnMock.getArguments()[2];
 
 					if (restoredKeyedState == null) {
@@ -121,7 +119,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 								new JobID(),
 								"test_op",
 								keySerializer,
-								keyGroupAssigner,
+								numberOfKeyGroups,
 								keyGroupRange,
 								mockTask.getEnvironment().getTaskKvStateRegistry());
 						return keyedStateBackend;
@@ -131,7 +129,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 								new JobID(),
 								"test_op",
 								keySerializer,
-								keyGroupAssigner,
+								numberOfKeyGroups,
 								keyGroupRange,
 								Collections.singletonList(restoredKeyedState),
 								mockTask.getEnvironment().getTaskKvStateRegistry());
@@ -139,7 +137,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 						return keyedStateBackend;
 					}
 				}
-			}).when(mockTask).createKeyedStateBackend(any(TypeSerializer.class), any(KeyGroupAssigner.class), any(KeyGroupRange.class));
+			}).when(mockTask).createKeyedStateBackend(any(TypeSerializer.class), anyInt(), any(KeyGroupRange.class));
 		} catch (Exception e) {
 			throw new RuntimeException(e.getMessage(), e);
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/6d430618/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
index 78e05b7..15074a7 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
@@ -19,7 +19,6 @@ package org.apache.flink.streaming.util;
 
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.state.KeyGroupAssigner;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.configuration.Configuration;
@@ -29,8 +28,6 @@ import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
-import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.streaming.api.graph.StreamConfig;
@@ -47,7 +44,6 @@ import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
 import java.util.Collection;
-import java.util.Collections;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.Executors;
 

http://git-wip-us.apache.org/repos/asf/flink/blob/6d430618/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
index 39f3086..82dbd1f 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
@@ -34,8 +34,7 @@ import org.apache.flink.runtime.execution.SuppressRestartsException;
 import org.apache.flink.runtime.instance.ActorGateway;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.messages.JobManagerMessages;
-import org.apache.flink.runtime.state.HashKeyGroupAssigner;
-import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.filesystem.FsStateBackendFactory;
 import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages;
 import org.apache.flink.streaming.api.checkpoint.Checkpointed;
@@ -141,12 +140,10 @@ public class RescalingITCase extends TestLogger {
 
 			Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>();
 
-			HashKeyGroupAssigner<Integer> keyGroupAssigner = new HashKeyGroupAssigner<>(maxParallelism);
-
 			for (int key = 0; key < numberKeys; key++) {
-				int keyGroupIndex = keyGroupAssigner.getKeyGroupIndex(key);
+				int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
 
-				expectedResult.add(Tuple2.of(KeyGroupRange.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, keyGroupIndex), numberElements * key));
+				expectedResult.add(Tuple2.of(KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, keyGroupIndex), numberElements * key));
 			}
 
 			assertEquals(expectedResult, actualResult);
@@ -185,11 +182,9 @@ public class RescalingITCase extends TestLogger {
 
 			Set<Tuple2<Integer, Integer>> expectedResult2 = new HashSet<>();
 
-			HashKeyGroupAssigner<Integer> keyGroupAssigner2 = new HashKeyGroupAssigner<>(maxParallelism);
-
 			for (int key = 0; key < numberKeys; key++) {
-				int keyGroupIndex = keyGroupAssigner2.getKeyGroupIndex(key);
-				expectedResult2.add(Tuple2.of(KeyGroupRange.computeOperatorIndexForKeyGroup(maxParallelism, parallelism2, keyGroupIndex), key * (numberElements + numberElements2)));
+				int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
+				expectedResult2.add(Tuple2.of(KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, parallelism2, keyGroupIndex), key * (numberElements + numberElements2)));
 			}
 
 			assertEquals(expectedResult2, actualResult2);
@@ -347,12 +342,10 @@ public class RescalingITCase extends TestLogger {
 
 			Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>();
 
-			HashKeyGroupAssigner<Integer> keyGroupAssigner = new HashKeyGroupAssigner<>(maxParallelism);
-
 			for (int key = 0; key < numberKeys; key++) {
-				int keyGroupIndex = keyGroupAssigner.getKeyGroupIndex(key);
+				int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
 
-				expectedResult.add(Tuple2.of(KeyGroupRange.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, keyGroupIndex) , numberElements * key));
+				expectedResult.add(Tuple2.of(KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, keyGroupIndex) , numberElements * key));
 			}
 
 			assertEquals(expectedResult, actualResult);
@@ -398,11 +391,9 @@ public class RescalingITCase extends TestLogger {
 
 			Set<Tuple2<Integer, Integer>> expectedResult2 = new HashSet<>();
 
-			HashKeyGroupAssigner<Integer> keyGroupAssigner2 = new HashKeyGroupAssigner<>(maxParallelism);
-
 			for (int key = 0; key < numberKeys; key++) {
-				int keyGroupIndex = keyGroupAssigner2.getKeyGroupIndex(key);
-				expectedResult2.add(Tuple2.of(KeyGroupRange.computeOperatorIndexForKeyGroup(maxParallelism, parallelism2, keyGroupIndex), key * (numberElements + numberElements2)));
+				int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
+				expectedResult2.add(Tuple2.of(KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, parallelism2, keyGroupIndex), key * (numberElements + numberElements2)));
 			}
 
 			assertEquals(expectedResult2, actualResult2);

http://git-wip-us.apache.org/repos/asf/flink/blob/6d430618/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
index d4dd475..694f006 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
@@ -21,7 +21,6 @@ package org.apache.flink.test.streaming.runtime;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.restartstrategy.RestartStrategies;
-import org.apache.flink.api.common.state.KeyGroupAssigner;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
@@ -104,7 +103,7 @@ public class StateBackendITCase extends StreamingMultipleProgramsTestBase {
 				JobID jobID,
 				String operatorIdentifier,
 				TypeSerializer<K> keySerializer,
-				KeyGroupAssigner<K> keyGroupAssigner,
+				int numberOfKeyGroups,
 				KeyGroupRange keyGroupRange,
 				TaskKvStateRegistry kvStateRegistry) throws Exception {
 			throw new SuccessException();
@@ -115,7 +114,7 @@ public class StateBackendITCase extends StreamingMultipleProgramsTestBase {
 				JobID jobID,
 				String operatorIdentifier,
 				TypeSerializer<K> keySerializer,
-				KeyGroupAssigner<K> keyGroupAssigner,
+				int numberOfKeyGroups,
 				KeyGroupRange keyGroupRange,
 				List<KeyGroupsStateHandle> restoredState,
 				TaskKvStateRegistry kvStateRegistry) throws Exception {