You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by fh...@apache.org on 2017/04/06 19:29:09 UTC

[02/12] flink git commit: [FLINK-6223] [py] Rework PythonPlanBinder generics

[FLINK-6223] [py] Rework PythonPlanBinder generics


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/bba49d67
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/bba49d67
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/bba49d67

Branch: refs/heads/table-retraction
Commit: bba49d673f9a36d4291bc74f07895e957de75bce
Parents: 8f78e9d
Author: zentol <ch...@apache.org>
Authored: Thu Mar 30 23:45:14 2017 +0200
Committer: zentol <ch...@apache.org>
Committed: Thu Apr 6 10:57:10 2017 +0200

----------------------------------------------------------------------
 .../flink/python/api/PythonPlanBinder.java      | 339 +++++++++----------
 .../apache/flink/python/api/util/SetCache.java  | 204 +++++++++++
 2 files changed, 367 insertions(+), 176 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/bba49d67/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonPlanBinder.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonPlanBinder.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonPlanBinder.java
index a3cae4a..7c228e1 100644
--- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonPlanBinder.java
+++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/PythonPlanBinder.java
@@ -17,11 +17,12 @@ import java.io.IOException;
 import java.net.URI;
 import java.net.URISyntaxException;
 import java.util.Arrays;
-import java.util.HashMap;
 import java.util.Random;
 
 import org.apache.flink.api.common.JobExecutionResult;
+import org.apache.flink.api.common.operators.Keys;
 import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.ExecutionEnvironment;
 import org.apache.flink.api.java.LocalEnvironment;
@@ -29,12 +30,11 @@ import org.apache.flink.api.java.io.PrintingOutputFormat;
 import org.apache.flink.api.java.io.TupleCsvInputFormat;
 import org.apache.flink.api.java.operators.CoGroupRawOperator;
 import org.apache.flink.api.java.operators.CrossOperator.DefaultCross;
-import org.apache.flink.api.java.operators.Grouping;
-import org.apache.flink.api.common.operators.Keys;
 import org.apache.flink.api.java.operators.SortedGrouping;
 import org.apache.flink.api.java.operators.UdfOperator;
 import org.apache.flink.api.java.operators.UnsortedGrouping;
 import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.TupleTypeInfo;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.GlobalConfiguration;
@@ -53,6 +53,7 @@ import org.apache.flink.python.api.functions.util.KeyDiscarder;
 import org.apache.flink.python.api.functions.util.SerializerMap;
 import org.apache.flink.python.api.functions.util.StringDeserializerMap;
 import org.apache.flink.python.api.streaming.plan.PythonPlanStreamer;
+import org.apache.flink.python.api.util.SetCache;
 import org.apache.flink.runtime.filecache.FileCache;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -92,7 +93,7 @@ public class PythonPlanBinder {
 	private static String FLINK_HDFS_PATH = "hdfs:/tmp";
 	public static final String FLINK_TMP_DATA_DIR = System.getProperty("java.io.tmpdir") + File.separator + "flink_data";
 
-	private HashMap<Integer, Object> sets = new HashMap<>();
+	private final SetCache sets = new SetCache();
 	public ExecutionEnvironment env;
 	private int currentEnvironmentID = 0;
 	private PythonPlanStreamer streamer;
@@ -242,7 +243,7 @@ public class PythonPlanBinder {
 	private void receivePlan() throws IOException {
 		env = ExecutionEnvironment.getExecutionEnvironment();
 		//IDs used in HashMap of sets are only unique for each environment
-		sets.clear();
+		sets.reset();
 		receiveParameters();
 		receiveOperations();
 	}
@@ -263,18 +264,18 @@ public class PythonPlanBinder {
 			Tuple value = (Tuple) streamer.getRecord(true);
 			switch (Parameters.valueOf(((String) value.getField(0)).toUpperCase())) {
 				case DOP:
-					Integer dop = (Integer) value.getField(1);
+					Integer dop = value.<Integer>getField(1);
 					env.setParallelism(dop);
 					break;
 				case MODE:
-					FLINK_HDFS_PATH = (Boolean) value.getField(1) ? "file:/tmp/flink" : "hdfs:/tmp/flink";
+					FLINK_HDFS_PATH = value.<Boolean>getField(1) ? "file:/tmp/flink" : "hdfs:/tmp/flink";
 					break;
 				case RETRY:
-					int retry = (Integer) value.getField(1);
+					int retry = value.<Integer>getField(1);
 					env.setRestartStrategy(RestartStrategies.fixedDelayRestart(retry, 10000L));
 					break;
 				case ID:
-					currentEnvironmentID = (Integer) value.getField(1);
+					currentEnvironmentID = value.<Integer>getField(1);
 					break;
 			}
 		}
@@ -352,40 +353,40 @@ public class PythonPlanBinder {
 					createUnionOperation(info);
 					break;
 				case COGROUP:
-					createCoGroupOperation(info);
+					createCoGroupOperation(info, info.types);
 					break;
 				case CROSS:
-					createCrossOperation(NONE, info);
+					createCrossOperation(NONE, info, info.types);
 					break;
 				case CROSS_H:
-					createCrossOperation(HUGE, info);
+					createCrossOperation(HUGE, info, info.types);
 					break;
 				case CROSS_T:
-					createCrossOperation(TINY, info);
+					createCrossOperation(TINY, info, info.types);
 					break;
 				case FILTER:
-					createFilterOperation(info);
+					createFilterOperation(info, info.types);
 					break;
 				case FLATMAP:
-					createFlatMapOperation(info);
+					createFlatMapOperation(info, info.types);
 					break;
 				case GROUPREDUCE:
 					createGroupReduceOperation(info);
 					break;
 				case JOIN:
-					createJoinOperation(NONE, info);
+					createJoinOperation(NONE, info, info.types);
 					break;
 				case JOIN_H:
-					createJoinOperation(HUGE, info);
+					createJoinOperation(HUGE, info, info.types);
 					break;
 				case JOIN_T:
-					createJoinOperation(TINY, info);
+					createJoinOperation(TINY, info, info.types);
 					break;
 				case MAP:
-					createMapOperation(info);
+					createMapOperation(info, info.types);
 					break;
 				case MAPPARTITION:
-					createMapPartitionOperation(info);
+					createMapPartitionOperation(info, info.types);
 					break;
 				case REDUCE:
 					createReduceOperation(info);
@@ -399,57 +400,54 @@ public class PythonPlanBinder {
 	}
 
 	@SuppressWarnings("unchecked")
-	private void createCsvSource(PythonOperationInfo info) throws IOException {
+	private <T extends Tuple> void createCsvSource(PythonOperationInfo info) {
 		if (!(info.types instanceof TupleTypeInfo)) {
 			throw new RuntimeException("The output type of a csv source has to be a tuple. The derived type is " + info);
 		}
 		Path path = new Path(info.path);
 		String lineD = info.lineDelimiter;
 		String fieldD = info.fieldDelimiter;
-		TupleTypeInfo<?> types = (TupleTypeInfo) info.types;
-		sets.put(info.setID, env.createInput(new TupleCsvInputFormat(path, lineD, fieldD, types), info.types).setParallelism(getParallelism(info)).name("CsvSource")
-				.map(new SerializerMap<>()).setParallelism(getParallelism(info)).name("CsvSourcePostStep"));
+		TupleTypeInfo<T> types = (TupleTypeInfo<T>) info.types;
+		sets.add(info.setID, env.createInput(new TupleCsvInputFormat<>(path, lineD, fieldD, types), types).setParallelism(getParallelism(info)).name("CsvSource")
+			.map(new SerializerMap<T>()).setParallelism(getParallelism(info)).name("CsvSourcePostStep"));
 	}
 
-	private void createTextSource(PythonOperationInfo info) throws IOException {
-		sets.put(info.setID, env.readTextFile(info.path).setParallelism(getParallelism(info)).name("TextSource")
-				.map(new SerializerMap<String>()).setParallelism(getParallelism(info)).name("TextSourcePostStep"));
+	private void createTextSource(PythonOperationInfo info) {
+		sets.add(info.setID, env.readTextFile(info.path).setParallelism(getParallelism(info)).name("TextSource")
+			.map(new SerializerMap<String>()).setParallelism(getParallelism(info)).name("TextSourcePostStep"));
 	}
 
-	private void createValueSource(PythonOperationInfo info) throws IOException {
-		sets.put(info.setID, env.fromElements(info.values).setParallelism(getParallelism(info)).name("ValueSource")
-				.map(new SerializerMap<>()).setParallelism(getParallelism(info)).name("ValueSourcePostStep"));
+	private void createValueSource(PythonOperationInfo info) {
+		sets.add(info.setID, env.fromElements(info.values).setParallelism(getParallelism(info)).name("ValueSource")
+			.map(new SerializerMap<>()).setParallelism(getParallelism(info)).name("ValueSourcePostStep"));
 	}
 
-	private void createSequenceSource(PythonOperationInfo info) throws IOException {
-		sets.put(info.setID, env.generateSequence(info.frm, info.to).setParallelism(getParallelism(info)).name("SequenceSource")
-				.map(new SerializerMap<Long>()).setParallelism(getParallelism(info)).name("SequenceSourcePostStep"));
+	private void createSequenceSource(PythonOperationInfo info) {
+		sets.add(info.setID, env.generateSequence(info.frm, info.to).setParallelism(getParallelism(info)).name("SequenceSource")
+			.map(new SerializerMap<Long>()).setParallelism(getParallelism(info)).name("SequenceSourcePostStep"));
 	}
 
-	@SuppressWarnings("unchecked")
-	private void createCsvSink(PythonOperationInfo info) throws IOException {
-		DataSet parent = (DataSet) sets.get(info.parentID);
+	private void createCsvSink(PythonOperationInfo info) {
+		DataSet<byte[]> parent = sets.getDataSet(info.parentID);
 		parent.map(new StringTupleDeserializerMap()).setParallelism(getParallelism(info)).name("CsvSinkPreStep")
 				.writeAsCsv(info.path, info.lineDelimiter, info.fieldDelimiter, info.writeMode).setParallelism(getParallelism(info)).name("CsvSink");
 	}
 
-	@SuppressWarnings("unchecked")
-	private void createTextSink(PythonOperationInfo info) throws IOException {
-		DataSet parent = (DataSet) sets.get(info.parentID);
+	private void createTextSink(PythonOperationInfo info) {
+		DataSet<byte[]> parent = sets.getDataSet(info.parentID);
 		parent.map(new StringDeserializerMap()).setParallelism(getParallelism(info))
 			.writeAsText(info.path, info.writeMode).setParallelism(getParallelism(info)).name("TextSink");
 	}
 
-	@SuppressWarnings("unchecked")
-	private void createPrintSink(PythonOperationInfo info) throws IOException {
-		DataSet parent = (DataSet) sets.get(info.parentID);
+	private void createPrintSink(PythonOperationInfo info) {
+		DataSet<byte[]> parent = sets.getDataSet(info.parentID);
 		parent.map(new StringDeserializerMap()).setParallelism(getParallelism(info)).name("PrintSinkPreStep")
-			.output(new PrintingOutputFormat(info.toError)).setParallelism(getParallelism(info));
+			.output(new PrintingOutputFormat<String>(info.toError)).setParallelism(getParallelism(info));
 	}
 
-	private void createBroadcastVariable(PythonOperationInfo info) throws IOException {
-		UdfOperator<?> op1 = (UdfOperator) sets.get(info.parentID);
-		DataSet<?> op2 = (DataSet) sets.get(info.otherID);
+	private void createBroadcastVariable(PythonOperationInfo info) {
+		UdfOperator<?> op1 = (UdfOperator<?>) sets.getDataSet(info.parentID);
+		DataSet<?> op2 = sets.getDataSet(info.otherID);
 
 		op1.withBroadcastSet(op2, info.name);
 		Configuration c = op1.getParameters();
@@ -465,82 +463,80 @@ public class PythonPlanBinder {
 		op1.withParameters(c);
 	}
 
-	@SuppressWarnings("unchecked")
-	private void createDistinctOperation(PythonOperationInfo info) throws IOException {
-		DataSet op = (DataSet) sets.get(info.parentID);
-		sets.put(info.setID, op.distinct(info.keys).setParallelism(getParallelism(info)).name("Distinct")
-				.map(new KeyDiscarder()).setParallelism(getParallelism(info)).name("DistinctPostStep"));
-	}
-
-	@SuppressWarnings("unchecked")
-	private void createFirstOperation(PythonOperationInfo info) throws IOException {
-		Object op = sets.get(info.parentID);
-		if (op instanceof DataSet) {
-			sets.put(info.setID, ((DataSet) op).first(info.count).setParallelism(getParallelism(info)).name("First"));
-			return;
-		}
-		if (op instanceof UnsortedGrouping) {
-			sets.put(info.setID, ((UnsortedGrouping) op).first(info.count).setParallelism(getParallelism(info)).name("First")
-				.map(new KeyDiscarder()).setParallelism(getParallelism(info)).name("FirstPostStep"));
-			return;
-		}
-		if (op instanceof SortedGrouping) {
-			sets.put(info.setID, ((SortedGrouping) op).first(info.count).setParallelism(getParallelism(info)).name("First")
-				.map(new KeyDiscarder()).setParallelism(getParallelism(info)).name("FirstPostStep"));
+	private <K extends Tuple> void createDistinctOperation(PythonOperationInfo info) {
+		DataSet<Tuple2<K, byte[]>> op = sets.getDataSet(info.parentID);
+		DataSet<byte[]> result = op
+			.distinct(info.keys).setParallelism(getParallelism(info)).name("Distinct")
+			.map(new KeyDiscarder<K>()).setParallelism(getParallelism(info)).name("DistinctPostStep");
+		sets.add(info.setID, result);
+	}
+
+	private <K extends Tuple> void createFirstOperation(PythonOperationInfo info) {
+		if (sets.isDataSet(info.parentID)) {
+			DataSet<byte[]> op = sets.getDataSet(info.parentID);
+			sets.add(info.setID, op
+				.first(info.count).setParallelism(getParallelism(info)).name("First"));
+		} else if (sets.isUnsortedGrouping(info.parentID)) {
+			UnsortedGrouping<Tuple2<K, byte[]>> op = sets.getUnsortedGrouping(info.parentID);
+			sets.add(info.setID, op
+				.first(info.count).setParallelism(getParallelism(info)).name("First")
+				.map(new KeyDiscarder<K>()).setParallelism(getParallelism(info)).name("FirstPostStep"));
+		} else if (sets.isSortedGrouping(info.parentID)) {
+			SortedGrouping<Tuple2<K, byte[]>> op = sets.getSortedGrouping(info.parentID);
+			sets.add(info.setID, op
+				.first(info.count).setParallelism(getParallelism(info)).name("First")
+				.map(new KeyDiscarder<K>()).setParallelism(getParallelism(info)).name("FirstPostStep"));
 		}
 	}
 
-	private void createGroupOperation(PythonOperationInfo info) throws IOException {
-		DataSet op1 = (DataSet) sets.get(info.parentID);
-		sets.put(info.setID, op1.groupBy(info.keys));
+	private void createGroupOperation(PythonOperationInfo info) {
+		DataSet<?> op1 = sets.getDataSet(info.parentID);
+		sets.add(info.setID, op1.groupBy(info.keys));
 	}
 
-	@SuppressWarnings("unchecked")
-	private void createHashPartitionOperation(PythonOperationInfo info) throws IOException {
-		DataSet op1 = (DataSet) sets.get(info.parentID);
-		sets.put(info.setID, op1.partitionByHash(info.keys).setParallelism(getParallelism(info))
-				.map(new KeyDiscarder()).setParallelism(getParallelism(info)).name("HashPartitionPostStep"));
+	private <K extends Tuple> void createHashPartitionOperation(PythonOperationInfo info) {
+		DataSet<Tuple2<K, byte[]>> op1 = sets.getDataSet(info.parentID);
+		DataSet<byte[]> result = op1
+			.partitionByHash(info.keys).setParallelism(getParallelism(info))
+			.map(new KeyDiscarder<K>()).setParallelism(getParallelism(info)).name("HashPartitionPostStep");
+		sets.add(info.setID, result);
 	}
 
-	private void createRebalanceOperation(PythonOperationInfo info) throws IOException {
-		DataSet op = (DataSet) sets.get(info.parentID);
-		sets.put(info.setID, op.rebalance().setParallelism(getParallelism(info)).name("Rebalance"));
+	private void createRebalanceOperation(PythonOperationInfo info) {
+		DataSet<?> op = sets.getDataSet(info.parentID);
+		sets.add(info.setID, op.rebalance().setParallelism(getParallelism(info)).name("Rebalance"));
 	}
 
-	private void createSortOperation(PythonOperationInfo info) throws IOException {
-		Grouping op1 = (Grouping) sets.get(info.parentID);
-		if (op1 instanceof UnsortedGrouping) {
-			sets.put(info.setID, ((UnsortedGrouping) op1).sortGroup(info.field, info.order));
-			return;
-		}
-		if (op1 instanceof SortedGrouping) {
-			sets.put(info.setID, ((SortedGrouping) op1).sortGroup(info.field, info.order));
+	private void createSortOperation(PythonOperationInfo info) {
+		if (sets.isDataSet(info.parentID)) {
+			throw new IllegalArgumentException("sort() can not be applied on a DataSet.");
+		} else if (sets.isUnsortedGrouping(info.parentID)) {
+			sets.add(info.setID, sets.getUnsortedGrouping(info.parentID).sortGroup(info.field, info.order));
+		} else if (sets.isSortedGrouping(info.parentID)) {
+			sets.add(info.setID, sets.getSortedGrouping(info.parentID).sortGroup(info.field, info.order));
 		}
 	}
 
-	@SuppressWarnings("unchecked")
-	private void createUnionOperation(PythonOperationInfo info) throws IOException {
-		DataSet op1 = (DataSet) sets.get(info.parentID);
-		DataSet op2 = (DataSet) sets.get(info.otherID);
-		sets.put(info.setID, op1.union(op2).setParallelism(getParallelism(info)).name("Union"));
+	private <IN> void createUnionOperation(PythonOperationInfo info) {
+		DataSet<IN> op1 = sets.getDataSet(info.parentID);
+		DataSet<IN> op2 = sets.getDataSet(info.otherID);
+		sets.add(info.setID, op1.union(op2).setParallelism(getParallelism(info)).name("Union"));
 	}
 
-	@SuppressWarnings("unchecked")
-	private void createCoGroupOperation(PythonOperationInfo info) {
-		DataSet op1 = (DataSet) sets.get(info.parentID);
-		DataSet op2 = (DataSet) sets.get(info.otherID);
-		Keys.ExpressionKeys<?> key1 = new Keys.ExpressionKeys(info.keys1, op1.getType());
-		Keys.ExpressionKeys<?> key2 = new Keys.ExpressionKeys(info.keys2, op2.getType());
-		PythonCoGroup pcg = new PythonCoGroup(info.envID, info.setID, info.types);
-		sets.put(info.setID, new CoGroupRawOperator(op1, op2, key1, key2, pcg, info.types, info.name).setParallelism(getParallelism(info)));
+	private <IN1, IN2, OUT> void createCoGroupOperation(PythonOperationInfo info, TypeInformation<OUT> type) {
+		DataSet<IN1> op1 = sets.getDataSet(info.parentID);
+		DataSet<IN2> op2 = sets.getDataSet(info.otherID);
+		Keys.ExpressionKeys<IN1> key1 = new Keys.ExpressionKeys<>(info.keys1, op1.getType());
+		Keys.ExpressionKeys<IN2> key2 = new Keys.ExpressionKeys<>(info.keys2, op2.getType());
+		PythonCoGroup<IN1, IN2, OUT> pcg = new PythonCoGroup<>(info.envID, info.setID, type);
+		sets.add(info.setID, new CoGroupRawOperator<>(op1, op2, key1, key2, pcg, type, info.name).setParallelism(getParallelism(info)));
 	}
 
-	@SuppressWarnings("unchecked")
-	private void createCrossOperation(DatasizeHint mode, PythonOperationInfo info) {
-		DataSet op1 = (DataSet) sets.get(info.parentID);
-		DataSet op2 = (DataSet) sets.get(info.otherID);
+	private <IN1, IN2, OUT> void createCrossOperation(DatasizeHint mode, PythonOperationInfo info, TypeInformation<OUT> type) {
+		DataSet<IN1> op1 = sets.getDataSet(info.parentID);
+		DataSet<IN2> op2 = sets.getDataSet(info.otherID);
 
-		DefaultCross defaultResult;
+		DefaultCross<IN1, IN2> defaultResult;
 		switch (mode) {
 			case NONE:
 				defaultResult = op1.cross(op2);
@@ -557,119 +553,110 @@ public class PythonPlanBinder {
 
 		defaultResult.setParallelism(getParallelism(info));
 		if (info.usesUDF) {
-			sets.put(info.setID, defaultResult.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name));
+			sets.add(info.setID, defaultResult.mapPartition(new PythonMapPartition<Tuple2<IN1, IN2>, OUT>(info.envID, info.setID, type)).setParallelism(getParallelism(info)).name(info.name));
 		} else {
-			sets.put(info.setID, defaultResult.name("DefaultCross"));
+			sets.add(info.setID, defaultResult.name("DefaultCross"));
 		}
 	}
 
-	@SuppressWarnings("unchecked")
-	private void createFilterOperation(PythonOperationInfo info) {
-		DataSet op1 = (DataSet) sets.get(info.parentID);
-		sets.put(info.setID, op1.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name));
+	private <IN, OUT> void createFilterOperation(PythonOperationInfo info, TypeInformation<OUT> type) {
+		DataSet<IN> op1 = sets.getDataSet(info.parentID);
+		sets.add(info.setID, op1.mapPartition(new PythonMapPartition<IN, OUT>(info.envID, info.setID, type)).setParallelism(getParallelism(info)).name(info.name));
 	}
 
-	@SuppressWarnings("unchecked")
-	private void createFlatMapOperation(PythonOperationInfo info) {
-		DataSet op1 = (DataSet) sets.get(info.parentID);
-		sets.put(info.setID, op1.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name));
+	private <IN, OUT> void createFlatMapOperation(PythonOperationInfo info, TypeInformation<OUT> type) {
+		DataSet<IN> op1 = sets.getDataSet(info.parentID);
+		sets.add(info.setID, op1.mapPartition(new PythonMapPartition<IN, OUT>(info.envID, info.setID, type)).setParallelism(getParallelism(info)).name(info.name));
 	}
 
 	private void createGroupReduceOperation(PythonOperationInfo info) {
-		Object op1 = sets.get(info.parentID);
-		if (op1 instanceof DataSet) {
-			sets.put(info.setID, applyGroupReduceOperation((DataSet) op1, info));
-			return;
-		}
-		if (op1 instanceof UnsortedGrouping) {
-			sets.put(info.setID, applyGroupReduceOperation((UnsortedGrouping) op1, info));
-			return;
-		}
-		if (op1 instanceof SortedGrouping) {
-			sets.put(info.setID, applyGroupReduceOperation((SortedGrouping) op1, info));
+		if (sets.isDataSet(info.parentID)) {
+			sets.add(info.setID, applyGroupReduceOperation(sets.getDataSet(info.parentID), info, info.types));
+		} else if (sets.isUnsortedGrouping(info.parentID)) {
+			sets.add(info.setID, applyGroupReduceOperation(sets.getUnsortedGrouping(info.parentID), info, info.types));
+		} else if (sets.isSortedGrouping(info.parentID)) {
+			sets.add(info.setID, applyGroupReduceOperation(sets.getSortedGrouping(info.parentID), info, info.types));
 		}
 	}
 
-	@SuppressWarnings("unchecked")
-	private DataSet applyGroupReduceOperation(DataSet op1, PythonOperationInfo info) {
-		return op1.reduceGroup(new IdentityGroupReduce()).setCombinable(false).name("PythonGroupReducePreStep").setParallelism(getParallelism(info))
-				.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name);
+	private <IN, OUT> DataSet<OUT> applyGroupReduceOperation(DataSet<IN> op1, PythonOperationInfo info, TypeInformation<OUT> type) {
+		return op1
+			.reduceGroup(new IdentityGroupReduce<IN>()).setCombinable(false).name("PythonGroupReducePreStep").setParallelism(getParallelism(info))
+			.mapPartition(new PythonMapPartition<IN, OUT>(info.envID, info.setID, type)).setParallelism(getParallelism(info)).name(info.name);
 	}
 
-	@SuppressWarnings("unchecked")
-	private DataSet applyGroupReduceOperation(UnsortedGrouping op1, PythonOperationInfo info) {
-		return op1.reduceGroup(new IdentityGroupReduce()).setCombinable(false).setParallelism(getParallelism(info)).name("PythonGroupReducePreStep")
-				.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name);
+	private <IN, OUT> DataSet<OUT> applyGroupReduceOperation(UnsortedGrouping<IN> op1, PythonOperationInfo info, TypeInformation<OUT> type) {
+		return op1
+			.reduceGroup(new IdentityGroupReduce<IN>()).setCombinable(false).setParallelism(getParallelism(info)).name("PythonGroupReducePreStep")
+			.mapPartition(new PythonMapPartition<IN, OUT>(info.envID, info.setID, type)).setParallelism(getParallelism(info)).name(info.name);
 	}
 
-	@SuppressWarnings("unchecked")
-	private DataSet applyGroupReduceOperation(SortedGrouping op1, PythonOperationInfo info) {
-		return op1.reduceGroup(new IdentityGroupReduce()).setCombinable(false).setParallelism(getParallelism(info)).name("PythonGroupReducePreStep")
-				.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name);
+	private <IN, OUT> DataSet<OUT> applyGroupReduceOperation(SortedGrouping<IN> op1, PythonOperationInfo info, TypeInformation<OUT> type) {
+		return op1
+			.reduceGroup(new IdentityGroupReduce<IN>()).setCombinable(false).setParallelism(getParallelism(info)).name("PythonGroupReducePreStep")
+			.mapPartition(new PythonMapPartition<IN, OUT>(info.envID, info.setID, type)).setParallelism(getParallelism(info)).name(info.name);
 	}
 
-	@SuppressWarnings("unchecked")
-	private void createJoinOperation(DatasizeHint mode, PythonOperationInfo info) {
-		DataSet op1 = (DataSet) sets.get(info.parentID);
-		DataSet op2 = (DataSet) sets.get(info.otherID);
+	private <IN1, IN2, OUT> void createJoinOperation(DatasizeHint mode, PythonOperationInfo info, TypeInformation<OUT> type) {
+		DataSet<IN1> op1 = sets.getDataSet(info.parentID);
+		DataSet<IN2> op2 = sets.getDataSet(info.otherID);
 
 		if (info.usesUDF) {
-			sets.put(info.setID, createDefaultJoin(op1, op2, info.keys1, info.keys2, mode, getParallelism(info))
-					.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name));
+			sets.add(info.setID, createDefaultJoin(op1, op2, info.keys1, info.keys2, mode, getParallelism(info))
+				.mapPartition(new PythonMapPartition<Tuple2<byte[], byte[]>, OUT>(info.envID, info.setID, type)).setParallelism(getParallelism(info)).name(info.name));
 		} else {
-			sets.put(info.setID, createDefaultJoin(op1, op2, info.keys1, info.keys2, mode, getParallelism(info)));
+			sets.add(info.setID, createDefaultJoin(op1, op2, info.keys1, info.keys2, mode, getParallelism(info)));
 		}
 	}
 
-	@SuppressWarnings("unchecked")
-	private DataSet createDefaultJoin(DataSet op1, DataSet op2, String[] firstKeys, String[] secondKeys, DatasizeHint mode, int parallelism) {
+	private <IN1, IN2> DataSet<Tuple2<byte[], byte[]>> createDefaultJoin(DataSet<IN1> op1, DataSet<IN2> op2, String[] firstKeys, String[] secondKeys, DatasizeHint mode, int parallelism) {
 		switch (mode) {
 			case NONE:
-				return op1.join(op2).where(firstKeys).equalTo(secondKeys).setParallelism(parallelism)
-					.map(new NestedKeyDiscarder()).setParallelism(parallelism).name("DefaultJoinPostStep");
+				return op1
+					.join(op2).where(firstKeys).equalTo(secondKeys).setParallelism(parallelism)
+					.map(new NestedKeyDiscarder<Tuple2<IN1, IN2>>()).setParallelism(parallelism).name("DefaultJoinPostStep");
 			case HUGE:
-				return op1.joinWithHuge(op2).where(firstKeys).equalTo(secondKeys).setParallelism(parallelism)
-					.map(new NestedKeyDiscarder()).setParallelism(parallelism).name("DefaultJoinPostStep");
+				return op1
+					.joinWithHuge(op2).where(firstKeys).equalTo(secondKeys).setParallelism(parallelism)
+					.map(new NestedKeyDiscarder<Tuple2<IN1, IN2>>()).setParallelism(parallelism).name("DefaultJoinPostStep");
 			case TINY:
-				return op1.joinWithTiny(op2).where(firstKeys).equalTo(secondKeys).setParallelism(parallelism)
-					.map(new NestedKeyDiscarder()).setParallelism(parallelism).name("DefaultJoinPostStep");
+				return op1
+					.joinWithTiny(op2).where(firstKeys).equalTo(secondKeys).setParallelism(parallelism)
+					.map(new NestedKeyDiscarder<Tuple2<IN1, IN2>>()).setParallelism(parallelism).name("DefaultJoinPostStep");
 			default:
 				throw new IllegalArgumentException("Invalid join mode specified.");
 		}
 	}
 
-	@SuppressWarnings("unchecked")
-	private void createMapOperation(PythonOperationInfo info) {
-		DataSet op1 = (DataSet) sets.get(info.parentID);
-		sets.put(info.setID, op1.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name));
+	private <IN, OUT> void createMapOperation(PythonOperationInfo info, TypeInformation<OUT> type) {
+		DataSet<IN> op1 = sets.getDataSet(info.parentID);
+		sets.add(info.setID, op1.mapPartition(new PythonMapPartition<IN, OUT>(info.envID, info.setID, type)).setParallelism(getParallelism(info)).name(info.name));
 	}
 
-	@SuppressWarnings("unchecked")
-	private void createMapPartitionOperation(PythonOperationInfo info) {
-		DataSet op1 = (DataSet) sets.get(info.parentID);
-		sets.put(info.setID, op1.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name));
+	private <IN, OUT> void createMapPartitionOperation(PythonOperationInfo info, TypeInformation<OUT> type) {
+		DataSet<IN> op1 = sets.getDataSet(info.parentID);
+		sets.add(info.setID, op1.mapPartition(new PythonMapPartition<IN, OUT>(info.envID, info.setID, type)).setParallelism(getParallelism(info)).name(info.name));
 	}
 
 	private void createReduceOperation(PythonOperationInfo info) {
-		Object op1 = sets.get(info.parentID);
-		if (op1 instanceof DataSet) {
-			sets.put(info.setID, applyReduceOperation((DataSet) op1, info));
-			return;
-		}
-		if (op1 instanceof UnsortedGrouping) {
-			sets.put(info.setID, applyReduceOperation((UnsortedGrouping) op1, info));
+		if (sets.isDataSet(info.parentID)) {
+			sets.add(info.setID, applyReduceOperation(sets.getDataSet(info.parentID), info, info.types));
+		} else if (sets.isUnsortedGrouping(info.parentID)) {
+			sets.add(info.setID, applyReduceOperation(sets.getUnsortedGrouping(info.parentID), info, info.types));
+		} else if (sets.isSortedGrouping(info.parentID)) {
+			throw new IllegalArgumentException("Reduce cannot be applied on a SortedGrouping.");
 		}
 	}
 
-	@SuppressWarnings("unchecked")
-	private DataSet applyReduceOperation(DataSet op1, PythonOperationInfo info) {
-		return op1.reduceGroup(new IdentityGroupReduce()).setCombinable(false).setParallelism(getParallelism(info)).name("PythonReducePreStep")
-				.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name);
+	private <IN, OUT> DataSet<OUT> applyReduceOperation(DataSet<IN> op1, PythonOperationInfo info, TypeInformation<OUT> type) {
+		return op1
+			.reduceGroup(new IdentityGroupReduce<IN>()).setCombinable(false).setParallelism(getParallelism(info)).name("PythonReducePreStep")
+			.mapPartition(new PythonMapPartition<IN, OUT>(info.envID, info.setID, type)).setParallelism(getParallelism(info)).name(info.name);
 	}
 
-	@SuppressWarnings("unchecked")
-	private DataSet applyReduceOperation(UnsortedGrouping op1, PythonOperationInfo info) {
-		return op1.reduceGroup(new IdentityGroupReduce()).setCombinable(false).setParallelism(getParallelism(info)).name("PythonReducePreStep")
-				.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name);
+	private <IN, OUT> DataSet<OUT> applyReduceOperation(UnsortedGrouping<IN> op1, PythonOperationInfo info, TypeInformation<OUT> type) {
+		return op1
+			.reduceGroup(new IdentityGroupReduce<IN>()).setCombinable(false).setParallelism(getParallelism(info)).name("PythonReducePreStep")
+			.mapPartition(new PythonMapPartition<IN, OUT>(info.envID, info.setID, type)).setParallelism(getParallelism(info)).name(info.name);
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/bba49d67/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/util/SetCache.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/util/SetCache.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/util/SetCache.java
new file mode 100644
index 0000000..750ba63
--- /dev/null
+++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/util/SetCache.java
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.python.api.util;
+
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.operators.SortedGrouping;
+import org.apache.flink.api.java.operators.UnsortedGrouping;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A container holding {@link DataSet DataSets}, {@link SortedGrouping sorted} and {@link UnsortedGrouping unsorted}
+ * groupings.
+ *
+ * <p>Each set is registered under an integer ID, and an ID may denote at most one set. This class is not
+ * synchronized.
+ */
+public class SetCache {
+
+	/** The kinds of sets held by this cache. {@link #toString()} returns the fully-qualified class name. */
+	private enum SetType {
+
+		DATA_SET(DataSet.class.getName()),
+		UNSORTED_GROUPING(UnsortedGrouping.class.getName()),
+		SORTED_GROUPING(SortedGrouping.class.getName());
+
+		private final String className;
+
+		SetType(String className) {
+			this.className = className;
+		}
+
+		@Override
+		public String toString() {
+			return className;
+		}
+	}
+
+	/** Type of the set registered under each ID. An ID is in use if and only if it has an entry here. */
+	private final Map<Integer, SetType> setTypes = new HashMap<>();
+
+	@SuppressWarnings("rawtypes")
+	private final Map<Integer, DataSet> dataSets = new HashMap<>();
+	@SuppressWarnings("rawtypes")
+	private final Map<Integer, UnsortedGrouping> unsortedGroupings = new HashMap<>();
+	@SuppressWarnings("rawtypes")
+	private final Map<Integer, SortedGrouping> sortedGroupings = new HashMap<>();
+
+	/*
+	 * Single-entry memo for isType(). It only stores IDs whose type was resolved successfully, and reset() clears
+	 * it, so it never returns a stale or missing type. A null cachedType means the memo is empty.
+	 */
+	private int cachedID = -1;
+	private SetType cachedType = null;
+
+	/**
+	 * Adds the given {@link DataSet} to this cache for the given ID.
+	 *
+	 * @param id  Set ID
+	 * @param set DataSet to add
+	 * @param <D> DataSet class
+	 * @throws IllegalStateException if the ID already denotes a set
+	 */
+	public <D extends DataSet<?>> void add(int id, D set) {
+		cacheSetType(id, SetType.DATA_SET);
+		dataSets.put(id, set);
+	}
+
+	/**
+	 * Adds the given {@link UnsortedGrouping} to this cache for the given ID.
+	 *
+	 * @param id  Set ID
+	 * @param set UnsortedGrouping to add
+	 * @param <U> UnsortedGrouping class
+	 * @throws IllegalStateException if the ID already denotes a set
+	 */
+	public <U extends UnsortedGrouping<?>> void add(int id, U set) {
+		cacheSetType(id, SetType.UNSORTED_GROUPING);
+		unsortedGroupings.put(id, set);
+	}
+
+	/**
+	 * Adds the given {@link SortedGrouping} to this cache for the given ID.
+	 *
+	 * @param id  Set ID
+	 * @param set SortedGrouping to add
+	 * @param <S> SortedGrouping class
+	 * @throws IllegalStateException if the ID already denotes a set
+	 */
+	public <S extends SortedGrouping<?>> void add(int id, S set) {
+		cacheSetType(id, SetType.SORTED_GROUPING);
+		sortedGroupings.put(id, set);
+	}
+
+	/**
+	 * Records the type of the set registered under the given ID.
+	 *
+	 * @param id   Set ID
+	 * @param type type of the set being added
+	 * @throws IllegalStateException if the ID already denotes a set
+	 */
+	private void cacheSetType(int id, SetType type) {
+		// Check before writing. A rejected registration must not overwrite the type of the set that already
+		// owns this ID, otherwise setTypes would disagree with the per-type maps.
+		if (setTypes.containsKey(id)) {
+			throw new IllegalStateException("Set ID " + id + " used to denote multiple sets.");
+		}
+		setTypes.put(id, type);
+	}
+
+	/**
+	 * Checks whether the cached set for the given ID is a {@link DataSet}.
+	 *
+	 * @param id id of set to check
+	 * @return true, if the cached set is a DataSet, false otherwise
+	 * @throws IllegalStateException if no set exists for the given ID
+	 */
+	public boolean isDataSet(int id) {
+		return isType(id, SetType.DATA_SET);
+	}
+
+	/**
+	 * Checks whether the cached set for the given ID is an {@link UnsortedGrouping}.
+	 *
+	 * @param id id of set to check
+	 * @return true, if the cached set is an UnsortedGrouping, false otherwise
+	 * @throws IllegalStateException if no set exists for the given ID
+	 */
+	public boolean isUnsortedGrouping(int id) {
+		return isType(id, SetType.UNSORTED_GROUPING);
+	}
+
+	/**
+	 * Checks whether the cached set for the given ID is a {@link SortedGrouping}.
+	 *
+	 * @param id Set ID
+	 * @return true, if the cached set is a SortedGrouping, false otherwise
+	 * @throws IllegalStateException if no set exists for the given ID
+	 */
+	public boolean isSortedGrouping(int id) {
+		return isType(id, SetType.SORTED_GROUPING);
+	}
+
+	/**
+	 * Checks whether the set registered under the given ID has the given type.
+	 *
+	 * @param id   Set ID
+	 * @param type type to check against
+	 * @return true, if the registered set has the given type, false otherwise
+	 * @throws IllegalStateException if no set exists for the given ID
+	 */
+	private boolean isType(int id, SetType type) {
+		if (cachedType == null || cachedID != id) {
+			SetType actualType = setTypes.get(id);
+			if (actualType == null) {
+				// Do not memoize a failed lookup, because the ID may still be registered later.
+				throw new IllegalStateException("No set exists for the given ID " + id);
+			}
+			cachedID = id;
+			cachedType = actualType;
+		}
+		return cachedType == type;
+	}
+
+	/**
+	 * Returns the cached {@link DataSet} for the given ID.
+	 *
+	 * @param id  Set ID
+	 * @param <T> DataSet type
+	 * @return Cached DataSet
+	 * @throws IllegalStateException if no set exists for the given ID, or the cached set is not a DataSet
+	 */
+	@SuppressWarnings("unchecked")
+	public <T> DataSet<T> getDataSet(int id) {
+		return verifyType(id, dataSets.get(id), SetType.DATA_SET);
+	}
+
+	/**
+	 * Returns the cached {@link UnsortedGrouping} for the given ID.
+	 *
+	 * @param id  Set ID
+	 * @param <T> UnsortedGrouping type
+	 * @return Cached UnsortedGrouping
+	 * @throws IllegalStateException if no set exists for the given ID, or the cached set is not an UnsortedGrouping
+	 */
+	@SuppressWarnings("unchecked")
+	public <T> UnsortedGrouping<T> getUnsortedGrouping(int id) {
+		return verifyType(id, unsortedGroupings.get(id), SetType.UNSORTED_GROUPING);
+	}
+
+	/**
+	 * Returns the cached {@link SortedGrouping} for the given ID.
+	 *
+	 * @param id  Set ID
+	 * @param <T> SortedGrouping type
+	 * @return Cached SortedGrouping
+	 * @throws IllegalStateException if no set exists for the given ID, or the cached set is not a SortedGrouping
+	 */
+	@SuppressWarnings("unchecked")
+	public <T> SortedGrouping<T> getSortedGrouping(int id) {
+		return verifyType(id, sortedGroupings.get(id), SetType.SORTED_GROUPING);
+	}
+
+	/**
+	 * Returns the given set, or throws a descriptive exception if it is absent.
+	 *
+	 * @param id   Set ID the set was looked up with
+	 * @param set  set taken from the map for the expected type, or null if that map has no entry for the ID
+	 * @param type expected type of the set
+	 * @param <X>  set class
+	 * @return the given set, if it is non-null
+	 * @throws IllegalStateException if no set exists for the given ID, or no set of the expected type is
+	 *                               registered under it
+	 */
+	private <X> X verifyType(int id, X set, SetType type) {
+		if (set == null) {
+			SetType actualType = setTypes.get(id);
+			if (actualType == null) {
+				throw new IllegalStateException("No set exists for the given ID " + id);
+			}
+			throw new IllegalStateException("Set ID " + id + " did not denote a " + type + ", but a " + actualType + " instead.");
+		}
+		return set;
+	}
+
+	/**
+	 * Resets this SetCache, removing any cached sets.
+	 */
+	public void reset() {
+		setTypes.clear();
+
+		dataSets.clear();
+		unsortedGroupings.clear();
+		sortedGroupings.clear();
+
+		// Clear the isType() memo too. Otherwise an ID reused after the reset would report its old type.
+		cachedID = -1;
+		cachedType = null;
+	}
+}