You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by gy...@apache.org on 2015/06/25 19:21:45 UTC

[12/12] flink git commit: [streaming] Add KeyedDataStream abstraction and integrate it with the rest of the refactoring

[streaming] Add KeyedDataStream abstraction and integrate it with the rest of the refactoring

Closes #747


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/cad85103
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/cad85103
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/cad85103

Branch: refs/heads/master
Commit: cad85103e8f742c2567a8d63893ddb439705b085
Parents: 0ae1758
Author: Paris Carbone <pa...@kth.se>
Authored: Mon Jun 22 17:14:33 2015 +0200
Committer: Gyula Fora <gy...@apache.org>
Committed: Thu Jun 25 19:20:00 2015 +0200

----------------------------------------------------------------------
 docs/apis/streaming_guide.md                    |  4 +-
 .../streaming/api/datastream/DataStream.java    | 59 +++++++++++++--
 .../api/datastream/GroupedDataStream.java       | 19 +----
 .../api/datastream/KeyedDataStream.java         | 77 ++++++++++++++++++++
 .../flink/streaming/api/graph/StreamGraph.java  |  5 ++
 .../flink/streaming/api/graph/StreamNode.java   |  9 +++
 .../api/graph/StreamingJobGraphGenerator.java   |  3 +
 .../runtime/tasks/OneInputStreamTask.java       |  2 +-
 .../api/state/StatefulOperatorTest.java         | 38 ++++++++++
 .../flink/streaming/api/scala/DataStream.scala  | 30 +++++++-
 10 files changed, 221 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/cad85103/docs/apis/streaming_guide.md
----------------------------------------------------------------------
diff --git a/docs/apis/streaming_guide.md b/docs/apis/streaming_guide.md
index 2713d6e..9218c98 100644
--- a/docs/apis/streaming_guide.md
+++ b/docs/apis/streaming_guide.md
@@ -1202,6 +1202,8 @@ Checkpointing of the states needs to be enabled from the `StreamExecutionEnviron
 
 Operator states can be accessed from the `RuntimeContext` using the `getOperatorState(“name”, defaultValue, partitioned)` method so it is only accessible in `RichFunction`s. A recommended usage pattern is to retrieve the operator state in the `open(…)` method of the operator and set it as a field in the operator instance for runtime usage. Multiple `OperatorState`s can be used simultaneously by the same operator by using different names to identify them.
 
+Partitioned operator state works only on a `KeyedDataStream`. A `KeyedDataStream` can be created from a `DataStream` using the `keyBy` or `groupBy` methods; `keyBy` accepts field positions, field expressions, or a `KeySelector` to derive the key by which the operator state will be partitioned. Both methods also hash-partition the records of the `DataStream` by that key, so that all records with the same key are processed by the same parallel instance. The `groupBy` method additionally returns a `GroupedDataStream`, a subtype of `KeyedDataStream` that supports grouped operations. Note that a `KeyedDataStream` does not support repartitioning (e.g. `shuffle()`, `forward()`, `groupBy(...)`).
+
 By default operator states are checkpointed using default java serialization thus they need to be `Serializable`. The user can gain more control over the state checkpoint mechanism by passing a `StateCheckpointer` instance when retrieving the `OperatorState` from the `RuntimeContext`. The `StateCheckpointer` allows custom implementations for the checkpointing logic for increased efficiency and to store arbitrary non-serializable states.
 
 By default state checkpoints will be stored in-memory at the JobManager. Flink also supports storing the checkpoints on any flink-supported file system (such as HDFS or Tachyon) which can be set in the flink-conf.yaml. Note that the state backend must be accessible from the JobManager, use `file://` only for local setups.
@@ -1222,7 +1224,7 @@ public class CounterSum implements RichReduceFunction<Long> {
 
     @Override
     public void open(Configuration config) {
-        counter = getRuntimeContext().getOperatorState(“counter”, 0L);
+        counter = getRuntimeContext().getOperatorState("counter", 0L, false);
     }
 }
 {% endhighlight %} 

http://git-wip-us.apache.org/repos/asf/flink/blob/cad85103/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
index b065950..fc16264 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
@@ -301,16 +301,63 @@ public class DataStream<OUT> {
 	}
 
 	/**
-	 * Groups the elements of a {@link DataStream} by the given key positions to
-	 * be used with grouped operators like
-	 * {@link GroupedDataStream#reduce(ReduceFunction)}</p> This operator also
-	 * affects the partitioning of the stream, by forcing values with the same
-	 * key to go to the same processing instance.
 	 * 
+	 * Creates a new {@link KeyedDataStream} whose operator state is partitioned by the key
+	 * extracted with the given {@link KeySelector}; the records are also hash-partitioned by it.
+	 *
+	 * @param key
+	 *            The KeySelector to be used for extracting the key for partitioning
+	 * @return The {@link DataStream} with partitioned state (i.e. KeyedDataStream)
+	 */
+	public KeyedDataStream<OUT> keyBy(KeySelector<OUT,?> key){
+		return new KeyedDataStream<OUT>(this, clean(key));
+	}
+
+	/**
+	 * Partitions the operator state of a {@link DataStream} by the given key positions.
+	 * Array-typed streams are keyed through an ArrayKeySelector, all other types through ExpressionKeys.
 	 * @param fields
 	 *            The position of the fields on which the {@link DataStream}
 	 *            will be grouped.
-	 * @return The grouped {@link DataStream}
+	 * @return The {@link DataStream} with partitioned state (i.e. KeyedDataStream)
+	 */
+	public KeyedDataStream<OUT> keyBy(int... fields) {
+		if (getType() instanceof BasicArrayTypeInfo || getType() instanceof PrimitiveArrayTypeInfo) {
+			return keyBy(new KeySelectorUtil.ArrayKeySelector<OUT>(fields));
+		} else {
+			return keyBy(new Keys.ExpressionKeys<OUT>(fields, getType()));
+		}
+	}
+
+	/**
+	 * Partitions the operator state of a {@link DataStream} using field expressions.
+	 * A field expression is either the name of a public field or a getter method with parentheses
+	 * of the {@link DataStream}'s underlying type. A dot can be used to drill
+	 * down into objects, as in {@code "field1.getInnerField2()" }.
+	 *
+	 * @param fields
+	 *            One or more field expressions on which the state of the {@link DataStream} operators will be
+	 *            partitioned.
+	 * @return The {@link DataStream} with partitioned state (i.e. KeyedDataStream)
+	 */
+	public KeyedDataStream<OUT> keyBy(String... fields) {
+		return keyBy(new Keys.ExpressionKeys<OUT>(fields, getType()));
+	}
+
+	private KeyedDataStream<OUT> keyBy(Keys<OUT> keys) { // shared by keyBy(String...) and non-array keyBy(int...): turns Keys into a KeySelector
+		return new KeyedDataStream<OUT>(this, clean(KeySelectorUtil.getSelectorForKeys(keys,
+				getType(), getExecutionConfig())));
+	}
+	
+	/**
+	 * Groups the elements of a {@link DataStream} by the given key positions to be used
+	 * with grouped operators like {@link GroupedDataStream#reduce(ReduceFunction)}. This also
+	 * affects the partitioning of the stream, by forcing values with the same key to go to
+	 * the same processing instance, and partitions operator state by the same key.
+	 * 
+	 * @param fields
+	 *            The position of the fields on which the {@link DataStream} will be grouped.
+	 * @return The grouped {@link DataStream} (a {@link GroupedDataStream}, which is also a KeyedDataStream)
 	 */
 	public GroupedDataStream<OUT> groupBy(int... fields) {
 		if (getType() instanceof BasicArrayTypeInfo || getType() instanceof PrimitiveArrayTypeInfo) {

http://git-wip-us.apache.org/repos/asf/flink/blob/cad85103/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/GroupedDataStream.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/GroupedDataStream.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/GroupedDataStream.java
index 2d6829d..87720fa 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/GroupedDataStream.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/GroupedDataStream.java
@@ -26,7 +26,6 @@ import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.streaming.api.functions.aggregation.AggregationFunction;
 import org.apache.flink.streaming.api.operators.StreamGroupedFold;
 import org.apache.flink.streaming.api.operators.StreamGroupedReduce;
-import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 
 /**
  * A GroupedDataStream represents a {@link DataStream} which has been
@@ -37,9 +36,7 @@ import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
  * @param <OUT>
  *            The output type of the {@link GroupedDataStream}.
  */
-public class GroupedDataStream<OUT> extends DataStream<OUT> {
-
-	KeySelector<OUT, ?> keySelector;
+public class GroupedDataStream<OUT> extends KeyedDataStream<OUT> {
 
 	/**
 	 * Creates a new {@link GroupedDataStream}, group inclusion is determined using
@@ -49,17 +46,11 @@ public class GroupedDataStream<OUT> extends DataStream<OUT> {
 	 * @param keySelector Function for determining group inclusion
 	 */
 	public GroupedDataStream(DataStream<OUT> dataStream, KeySelector<OUT, ?> keySelector) {
-		super(dataStream.partitionByHash(keySelector));
-		this.keySelector = keySelector;
+		super(dataStream, keySelector); // KeyedDataStream's constructor hash-partitions the stream and stores the selector
 	}
 
 	protected GroupedDataStream(GroupedDataStream<OUT> dataStream) {
 		super(dataStream);
-		this.keySelector = dataStream.keySelector;
-	}
-
-	public KeySelector<OUT, ?> getKeySelector() {
-		return this.keySelector;
 	}
 
 	/**
@@ -225,12 +216,8 @@ public class GroupedDataStream<OUT> extends DataStream<OUT> {
 	}
 
 	@Override
-	protected DataStream<OUT> setConnectionType(StreamPartitioner<OUT> partitioner) {
-		return super.setConnectionType(partitioner);
-	}
-
-	@Override
 	public GroupedDataStream<OUT> copy() {
 		return new GroupedDataStream<OUT>(this);
 	}
+
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cad85103/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/KeyedDataStream.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/KeyedDataStream.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/KeyedDataStream.java
new file mode 100644
index 0000000..b944302
--- /dev/null
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/KeyedDataStream.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.datastream;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+
+/**
+ * A KeyedDataStream represents a {@link DataStream} on which operator state is
+ * partitioned by key using a provided {@link KeySelector}. Typical operations supported by a {@link DataStream}
+ * are also possible on a KeyedDataStream, with the exception of partitioning methods such as shuffle, forward and groupBy.
+ * The constructor hash-partitions the records of the base stream by the same key, and any later attempt
+ * to change the partitioning goes through setConnectionType, which throws {@link UnsupportedOperationException}.
+ * @param <OUT>
+ *            The output type of the {@link KeyedDataStream}.
+ */
+public class KeyedDataStream<OUT> extends DataStream<OUT> {
+	KeySelector<OUT, ?> keySelector; // used both for hash partitioning and for partitioning operator state
+
+	/**
+	 * Creates a new {@link KeyedDataStream} that hash-partitions the given stream by the key extracted
+	 * with the given {@link KeySelector} and uses the same key to partition operator state.
+	 * 
+	 * @param dataStream
+	 *            Base stream of data
+	 * @param keySelector
+	 *            Function for determining state partitions
+	 */
+	public KeyedDataStream(DataStream<OUT> dataStream, KeySelector<OUT, ?> keySelector) {
+		super(dataStream.partitionByHash(keySelector));
+		this.keySelector = keySelector;
+	}
+
+	protected KeyedDataStream(KeyedDataStream<OUT> dataStream) { // copy constructor backing copy()
+		super(dataStream);
+		this.keySelector = dataStream.keySelector;
+	}
+
+	public KeySelector<OUT, ?> getKeySelector() {
+		return this.keySelector;
+	}
+
+	@Override
+	protected DataStream<OUT> setConnectionType(StreamPartitioner<OUT> partitioner) { // rejects replacing the key-based partitioning
+		throw new UnsupportedOperationException("Cannot override partitioning for KeyedDataStream.");
+	}
+
+	@Override
+	public KeyedDataStream<OUT> copy() {
+		return new KeyedDataStream<OUT>(this);
+	}
+
+	@Override
+	public <R> SingleOutputStreamOperator<R, ?> transform(String operatorName,
+			TypeInformation<R> outTypeInfo, OneInputStreamOperator<OUT, R> operator) {
+		SingleOutputStreamOperator<R, ?> returnStream = super.transform(operatorName, outTypeInfo,operator);
+		streamGraph.setKey(returnStream.getId(), keySelector); // register the key as the new node's state partitioner
+		return returnStream;
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/cad85103/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java
index 8ef4ca0..e07d881 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java
@@ -33,6 +33,7 @@ import java.util.Set;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.io.InputFormat;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.MissingTypeInfo;
 import org.apache.flink.optimizer.plan.StreamingPlan;
@@ -278,6 +279,10 @@ public class StreamGraph extends StreamingPlan {
 		getStreamNode(vertexID).setParallelism(parallelism);
 	}
 
+	public void setKey(Integer vertexID, KeySelector<?,?> key) { // later copied into the task config as the state partitioner
+		getStreamNode(vertexID).setStatePartitioner(key);
+	}
+
 	public void setBufferTimeout(Integer vertexID, long bufferTimeout) {
 		getStreamNode(vertexID).setBufferTimeout(bufferTimeout);
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/cad85103/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java
index ccca2f1..0b909bd 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java
@@ -22,6 +22,7 @@ import java.util.ArrayList;
 import java.util.List;
 
 import org.apache.flink.api.common.io.InputFormat;
+import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.streaming.api.collector.selector.OutputSelector;
 import org.apache.flink.streaming.api.collector.selector.OutputSelectorWrapper;
@@ -48,6 +49,7 @@ public class StreamNode implements Serializable {
 	private String operatorName;
 	private Integer slotSharingID;
 	private boolean isolatedSlot = false;
+	private KeySelector<?,?> statePartitioner;
 
 	private transient StreamOperator<?> operator;
 	private List<OutputSelector<?>> outputSelectors;
@@ -219,4 +221,11 @@ public class StreamNode implements Serializable {
 		return operatorName + id;
 	}
 
+	public KeySelector<?, ?> getStatePartitioner() { // null unless set, e.g. via StreamGraph.setKey
+		return statePartitioner;
+	}
+
+	public void setStatePartitioner(KeySelector<?, ?> statePartitioner) {
+		this.statePartitioner = statePartitioner;
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cad85103/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
index 531fc71..eb34e3f 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
@@ -18,6 +18,7 @@
 package org.apache.flink.streaming.api.graph;
 
 import java.io.IOException;
+import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashMap;
@@ -28,6 +29,7 @@ import java.util.Map.Entry;
 
 import org.apache.commons.lang.StringUtils;
 import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
@@ -268,6 +270,7 @@ public class StreamingJobGraphGenerator {
 		config.setChainedOutputs(chainableOutputs);
 		config.setStateMonitoring(streamGraph.isCheckpointingEnabled());
 		config.setStateHandleProvider(streamGraph.getStateHandleProvider());
+		config.setStatePartitioner((KeySelector<?, Serializable>) vertex.getStatePartitioner());
 
 		Class<? extends AbstractInvokable> vertexClass = vertex.getJobVertexClass();
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cad85103/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java
index 87042ba..80239fd 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java
@@ -99,7 +99,7 @@ public class OneInputStreamTask<IN, OUT> extends StreamTask<OUT, OneInputStreamO
 
 			StreamRecord<IN> nextRecord;
 			while (isRunning && (nextRecord = readNext()) != null) {
-				headContext.setNextInput(nextRecord);
+				headContext.setNextInput(nextRecord.getObject());
 				streamOperator.processElement(nextRecord.getObject());
 			}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cad85103/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java
index 442d8ea..af719f3 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java
@@ -27,6 +27,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
+import java.util.Map.Entry;
 
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.functions.RichMapFunction;
@@ -40,9 +41,13 @@ import org.apache.flink.runtime.state.LocalStateHandle.LocalStateHandleProvider;
 import org.apache.flink.runtime.state.PartitionedStateHandle;
 import org.apache.flink.runtime.state.StateHandle;
 import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.streaming.api.datastream.KeyedDataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.StreamMap;
 import org.apache.flink.streaming.runtime.tasks.StreamingRuntimeContext;
+import org.apache.flink.streaming.util.TestStreamEnvironment;
 import org.apache.flink.util.InstantiationUtil;
 import org.junit.Test;
 
@@ -92,6 +97,27 @@ public class StatefulOperatorTest {
 		assertEquals((Integer) 7, ((StatefulMapper) restoredMap.getUserFunction()).checkpointedCounter);
 
 	}
+	
+	@Test
+	public void apiTest() throws Exception { // runs a stateful map on a keyed stream and checks that repartitioning is rejected
+		StreamExecutionEnvironment env = new TestStreamEnvironment(3, 32);
+		
+		KeyedDataStream<Integer> keyedStream = env.fromCollection(Arrays.asList(0, 1, 2, 3, 4, 5, 6)).keyBy(new ModKey(4));
+		
+		keyedStream.map(new StatefulMapper()).addSink(new SinkFunction<String>() {
+			private static final long serialVersionUID = 1L;
+			public void invoke(String value) throws Exception {}
+		});
+		
+		try {
+			keyedStream.shuffle();
+			fail();
+		} catch (UnsupportedOperationException e) {
+			// expected: a KeyedDataStream does not allow its partitioning to be overridden
+		}
+		
+		env.execute();
+	}
 
 	private void processInputs(StreamMap<Integer, ?> map, List<Integer> input) throws Exception {
 		for (Integer i : input) {
@@ -173,6 +199,18 @@ public class StatefulOperatorTest {
 			} catch (RuntimeException e){
 			}
 		}
+		
+		@SuppressWarnings({ "rawtypes", "unchecked" })
+		@Override
+		public void close() throws Exception {
+			Map<String, StreamOperatorState> states = ((StreamingRuntimeContext) getRuntimeContext()).getOperatorStates(); // verify per-key counts at shutdown
+			PartitionedStreamOperatorState<Integer, Integer, Integer> groupCounter = (PartitionedStreamOperatorState<Integer, Integer, Integer>) states.get("groupCounter");
+			for (Entry<Serializable, Integer> count : groupCounter.getPartitionedState().entrySet()) {
+				Integer key = (Integer) count.getKey();
+				Integer expected = key < 3 ? 2 : 1; // assumes ModKey(4) maps 0..6 to value % 4 (keys 0-2 twice, key 3 once) - verify against ModKey
+				assertEquals(expected, count.getValue());
+			}
+		}
 
 		@Override
 		public Integer snapshotState(long checkpointId, long checkpointTimestamp)

http://git-wip-us.apache.org/repos/asf/flink/blob/cad85103/flink-staging/flink-streaming/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala b/flink-staging/flink-streaming/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala
index 501f7bd..96f951b 100644
--- a/flink-staging/flink-streaming/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala
+++ b/flink-staging/flink-streaming/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala
@@ -30,7 +30,8 @@ import org.apache.flink.api.common.functions.{FilterFunction, FlatMapFunction, F
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.functions.KeySelector
 import org.apache.flink.streaming.api.collector.selector.OutputSelector
-import org.apache.flink.streaming.api.datastream.{DataStream => JavaStream, DataStreamSink, GroupedDataStream, SingleOutputStreamOperator}
+import org.apache.flink.streaming.api.datastream.{DataStream => JavaStream, DataStreamSink, GroupedDataStream, 
+    KeyedDataStream, SingleOutputStreamOperator}
 import org.apache.flink.streaming.api.functions.aggregation.AggregationFunction.AggregationType
 import org.apache.flink.streaming.api.functions.sink.{FileSinkFunctionByMillis, SinkFunction}
 import org.apache.flink.streaming.api.functions.aggregation.{ComparableAggregator, SumAggregator}
@@ -209,6 +210,33 @@ class DataStream[T](javaStream: JavaStream[T]) {
   def connect[T2](dataStream: DataStream[T2]): ConnectedDataStream[T, T2] = 
     javaStream.connect(dataStream.getJavaStream)
 
+
+
+  /**
+   * Partitions the operator states of the DataStream by the given key positions
+   * (for tuple/array types); the Java keyBy it delegates to also hash-partitions the records.
+   */
+  def keyBy(fields: Int*): DataStream[T] = javaStream.keyBy(fields: _*)
+
+  /**
+   * Partitions the operator states of the DataStream by the given field expressions
+   * (at least one is required); the records are also hash-partitioned by that key.
+   */
+  def keyBy(firstField: String, otherFields: String*): DataStream[T] =
+    javaStream.keyBy(firstField +: otherFields.toArray: _*)
+
+
+  /**
+   * Partitions the operator states of the DataStream by the key of type `K` that `fun` extracts.
+   */
+  def keyBy[K: TypeInformation](fun: T => K): DataStream[T] = {
+    val cleanFun = clean(fun)
+    val keyExtractor = new KeySelector[T, K] { // adapts the cleaned Scala closure to the Java KeySelector
+      def getKey(in: T) = cleanFun(in)
+    }
+    javaStream.keyBy(keyExtractor)
+  }
+  
   /**
    * Groups the elements of a DataStream by the given key positions (for tuple/array types) to
    * be used with grouped operators like grouped reduce or grouped aggregations.