You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by se...@apache.org on 2014/11/18 12:22:50 UTC
[2/4] incubator-flink git commit: [FLINK-1237] Add support for custom
partitioners - Functions: GroupReduce, Reduce, Aggregate on UnsortedGrouping,
SortedGrouping,
Join (Java API & Scala API) - Manual partition on DataSet (Java API & S
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/operators/AggregateOperator.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/AggregateOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/AggregateOperator.java
index e906232..66821ae 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/AggregateOperator.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/AggregateOperator.java
@@ -208,10 +208,9 @@ public class AggregateOperator<IN> extends SingleInputOperator<IN, IN, Aggregate
po.setCombinable(true);
- // set input
po.setInput(input);
- // set dop
po.setDegreeOfParallelism(this.getParallelism());
+ po.setCustomPartitioner(grouping.getCustomPartitioner());
SingleInputSemanticProperties props = new SingleInputSemanticProperties();
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/operators/DistinctOperator.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/DistinctOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/DistinctOperator.java
index 126949c..e60c7de 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/DistinctOperator.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/DistinctOperator.java
@@ -22,9 +22,11 @@ import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.operators.Operator;
+import org.apache.flink.api.common.operators.SingleInputSemanticProperties;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.base.GroupReduceOperatorBase;
import org.apache.flink.api.common.operators.base.MapOperatorBase;
+import org.apache.flink.api.common.operators.util.FieldSet;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
@@ -71,7 +73,7 @@ public class DistinctOperator<T> extends SingleInputOperator<T, T, DistinctOpera
}
- // FieldPositionKeys can only be applied on Tuples
+ // FieldPositionKeys can only be applied on Tuples and POJOs
if (keys instanceof Keys.ExpressionKeys && !(input.getType() instanceof CompositeType)) {
throw new InvalidProgramException("Distinction on field positions is only possible on composite type DataSets.");
}
@@ -84,7 +86,7 @@ public class DistinctOperator<T> extends SingleInputOperator<T, T, DistinctOpera
final RichGroupReduceFunction<T, T> function = new DistinctFunction<T>();
- String name = "Distinct at "+distinctLocationName;
+ String name = "Distinct at " + distinctLocationName;
if (keys instanceof Keys.ExpressionKeys) {
@@ -95,7 +97,19 @@ public class DistinctOperator<T> extends SingleInputOperator<T, T, DistinctOpera
po.setCombinable(true);
po.setInput(input);
- po.setDegreeOfParallelism(this.getParallelism());
+ po.setDegreeOfParallelism(getParallelism());
+
+ // make sure that distinct preserves the partitioning for the fields on which they operate
+ if (getType().isTupleType()) {
+ SingleInputSemanticProperties sProps = new SingleInputSemanticProperties();
+
+ for (int field : keys.computeLogicalKeyPositions()) {
+ sProps.setForwardedField(field, new FieldSet(field));
+ }
+
+ po.setSemanticProperties(sProps);
+ }
+
return po;
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/operators/GroupReduceOperator.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/GroupReduceOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/GroupReduceOperator.java
index 327d12f..bef91ed 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/GroupReduceOperator.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/GroupReduceOperator.java
@@ -113,10 +113,14 @@ public class GroupReduceOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT
return this;
}
+ // --------------------------------------------------------------------------------------------
+ // Translation
+ // --------------------------------------------------------------------------------------------
+
@Override
- protected org.apache.flink.api.common.operators.base.GroupReduceOperatorBase<?, OUT, ?> translateToDataFlow(Operator<IN> input) {
+ protected GroupReduceOperatorBase<?, OUT, ?> translateToDataFlow(Operator<IN> input) {
- String name = getName() != null ? getName() : "GroupReduce at "+defaultName;
+ String name = getName() != null ? getName() : "GroupReduce at " + defaultName;
// distinguish between grouped reduce and non-grouped reduce
if (grouper == null) {
@@ -124,9 +128,8 @@ public class GroupReduceOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT
UnaryOperatorInformation<IN, OUT> operatorInfo = new UnaryOperatorInformation<IN, OUT>(getInputType(), getResultType());
GroupReduceOperatorBase<IN, OUT, GroupReduceFunction<IN, OUT>> po =
new GroupReduceOperatorBase<IN, OUT, GroupReduceFunction<IN, OUT>>(function, operatorInfo, new int[0], name);
-
+
po.setCombinable(combinable);
- // set input
po.setInput(input);
// the degree of parallelism for a non grouped reduce can only be 1
po.setDegreeOfParallelism(1);
@@ -141,7 +144,8 @@ public class GroupReduceOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT
PlanUnwrappingReduceGroupOperator<IN, OUT, ?> po = translateSelectorFunctionReducer(
selectorKeys, function, getInputType(), getResultType(), name, input, isCombinable());
- po.setDegreeOfParallelism(this.getParallelism());
+ po.setDegreeOfParallelism(getParallelism());
+ po.setCustomPartitioner(grouper.getCustomPartitioner());
return po;
}
@@ -154,7 +158,8 @@ public class GroupReduceOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT
po.setCombinable(combinable);
po.setInput(input);
- po.setDegreeOfParallelism(this.getParallelism());
+ po.setDegreeOfParallelism(getParallelism());
+ po.setCustomPartitioner(grouper.getCustomPartitioner());
// set group order
if (grouper instanceof SortedGrouping) {
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/operators/Grouping.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/Grouping.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/Grouping.java
index 36a364e..3c0d07f 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/Grouping.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/Grouping.java
@@ -19,7 +19,7 @@
package org.apache.flink.api.java.operators;
import org.apache.flink.api.common.InvalidProgramException;
-
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.java.DataSet;
/**
@@ -40,7 +40,10 @@ public abstract class Grouping<T> {
protected final DataSet<T> dataSet;
protected final Keys<T> keys;
+
+ protected Partitioner<?> customPartitioner;
+
public Grouping(DataSet<T> set, Keys<T> keys) {
if (set == null || keys == null) {
throw new NullPointerException();
@@ -62,5 +65,14 @@ public abstract class Grouping<T> {
public Keys<T> getKeys() {
return this.keys;
}
-
+
+ /**
+ * Gets the custom partitioner to be used for this grouping, or {@code null}, if
+ * none was defined.
+ *
+ * @return The custom partitioner to be used for this grouping.
+ */
+ public Partitioner<?> getCustomPartitioner() {
+ return this.customPartitioner;
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/operators/JoinOperator.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/JoinOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/JoinOperator.java
index 93e0371..21534f1 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/JoinOperator.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/JoinOperator.java
@@ -25,6 +25,7 @@ import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichFlatJoinFunction;
import org.apache.flink.api.common.operators.BinaryOperatorInformation;
import org.apache.flink.api.common.operators.DualInputSemanticProperties;
@@ -46,13 +47,16 @@ import org.apache.flink.api.java.operators.translation.PlanBothUnwrappingJoinOpe
import org.apache.flink.api.java.operators.translation.PlanLeftUnwrappingJoinOperator;
import org.apache.flink.api.java.operators.translation.PlanRightUnwrappingJoinOperator;
import org.apache.flink.api.java.operators.translation.WrappingFunction;
-//CHECKSTYLE.OFF: AvoidStarImport - Needed for TupleGenerator
-import org.apache.flink.api.java.tuple.*;
-//CHECKSTYLE.ON: AvoidStarImport
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.util.Collector;
+//CHECKSTYLE.OFF: AvoidStarImport - Needed for TupleGenerator
+import org.apache.flink.api.java.tuple.*;
+
+import com.google.common.base.Preconditions;
+//CHECKSTYLE.ON: AvoidStarImport
+
/**
* A {@link DataSet} that is the result of a Join transformation.
*
@@ -69,14 +73,25 @@ public abstract class JoinOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1,
private final JoinHint joinHint;
+ private Partitioner<?> customPartitioner;
+
+
protected JoinOperator(DataSet<I1> input1, DataSet<I2> input2,
Keys<I1> keys1, Keys<I2> keys2,
TypeInformation<OUT> returnType, JoinHint hint)
{
super(input1, input2, returnType);
- if (keys1 == null || keys2 == null) {
- throw new NullPointerException();
+ Preconditions.checkNotNull(keys1);
+ Preconditions.checkNotNull(keys2);
+
+ try {
+ if (!keys1.areCompatible(keys2)) {
+ throw new InvalidProgramException("The types of the key fields do not match.");
+ }
+ }
+ catch (IncompatibleKeysException ike) {
+ throw new InvalidProgramException("The types of the key fields do not match: " + ike.getMessage(), ike);
}
// sanity check solution set key mismatches
@@ -110,10 +125,43 @@ public abstract class JoinOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1,
return this.keys2;
}
+ /**
+ * Gets the JoinHint that describes how the join is executed.
+ *
+ * @return The JoinHint.
+ */
public JoinHint getJoinHint() {
return this.joinHint;
}
+ /**
+ * Sets a custom partitioner for this join. The partitioner will be called on the join keys to determine
+ * the partition a key should be assigned to. The partitioner is evaluated on both join inputs in the
+ * same way.
+ * <p>
+ * NOTE: A custom partitioner can only be used with single-field join keys, not with composite join keys.
+ *
+ * @param partitioner The custom partitioner to be used.
+ * @return This join operator, to allow for function chaining.
+ */
+ public JoinOperator<I1, I2, OUT> withPartitioner(Partitioner<?> partitioner) {
+ if (partitioner != null) {
+ keys1.validateCustomPartitioner(partitioner, null);
+ keys2.validateCustomPartitioner(partitioner, null);
+ }
+ this.customPartitioner = partitioner;
+ return this;
+ }
+
+ /**
+ * Gets the custom partitioner used by this join, or {@code null}, if none is set.
+ *
+ * @return The custom partitioner used by this join.
+ */
+ public Partitioner<?> getPartitioner() {
+ return customPartitioner;
+ }
+
// --------------------------------------------------------------------------------------------
// special join types
// --------------------------------------------------------------------------------------------
@@ -206,30 +254,20 @@ public abstract class JoinOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1,
// }
@Override
- protected JoinOperatorBase<?, ?, OUT, ?> translateToDataFlow(
- Operator<I1> input1,
- Operator<I2> input2) {
+ protected JoinOperatorBase<?, ?, OUT, ?> translateToDataFlow(Operator<I1> input1, Operator<I2> input2) {
String name = getName() != null ? getName() : "Join at "+joinLocationName;
- try {
- keys1.areCompatible(super.keys2);
- } catch(IncompatibleKeysException ike) {
- throw new InvalidProgramException("The types of the key fields do not match.", ike);
- }
final JoinOperatorBase<?, ?, OUT, ?> translated;
- if (keys1 instanceof Keys.SelectorFunctionKeys
- && keys2 instanceof Keys.SelectorFunctionKeys) {
+ if (keys1 instanceof Keys.SelectorFunctionKeys && keys2 instanceof Keys.SelectorFunctionKeys) {
// Both join sides have a key selector function, so we need to do the
// tuple wrapping/unwrapping on both sides.
@SuppressWarnings("unchecked")
- Keys.SelectorFunctionKeys<I1, ?> selectorKeys1 =
- (Keys.SelectorFunctionKeys<I1, ?>) keys1;
+ Keys.SelectorFunctionKeys<I1, ?> selectorKeys1 = (Keys.SelectorFunctionKeys<I1, ?>) keys1;
@SuppressWarnings("unchecked")
- Keys.SelectorFunctionKeys<I2, ?> selectorKeys2 =
- (Keys.SelectorFunctionKeys<I2, ?>) keys2;
+ Keys.SelectorFunctionKeys<I2, ?> selectorKeys2 = (Keys.SelectorFunctionKeys<I2, ?>) keys2;
PlanBothUnwrappingJoinOperator<I1, I2, OUT, ?> po =
translateSelectorFunctionJoin(selectorKeys1, selectorKeys2, function,
@@ -304,6 +342,7 @@ public abstract class JoinOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1,
}
translated.setJoinHint(getJoinHint());
+ translated.setCustomPartitioner(getPartitioner());
return translated;
}
@@ -506,22 +545,6 @@ public abstract class JoinOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1,
out.collect (this.wrappedFunction.join(left, right));
}
}
-
- /*
- private static class GeneratedFlatJoinFunction<IN1, IN2, OUT> extends FlatJoinFunction<IN1, IN2, OUT> {
-
- private Joinable<IN1,IN2,OUT> function;
-
- private GeneratedFlatJoinFunction(Joinable<IN1, IN2, OUT> function) {
- this.function = function;
- }
-
- @Override
- public void join(IN1 first, IN2 second, Collector<OUT> out) throws Exception {
- out.collect(function.join(first, second));
- }
- }
- */
/**
* Initiates a ProjectJoin transformation and projects the first join input<br/>
@@ -933,32 +956,6 @@ public abstract class JoinOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1,
}
}
- public static final class LeftSemiFlatJoinFunction<T1, T2> extends RichFlatJoinFunction<T1, T2, T1> {
-
- private static final long serialVersionUID = 1L;
-
- @Override
- //public T1 join(T1 left, T2 right) throws Exception {
- // return left;
- //}
- public void join (T1 left, T2 right, Collector<T1> out) {
- out.collect(left);
- }
- }
-
- public static final class RightSemiFlatJoinFunction<T1, T2> extends RichFlatJoinFunction<T1, T2, T2> {
-
- private static final long serialVersionUID = 1L;
-
- @Override
- //public T2 join(T1 left, T2 right) throws Exception {
- // return right;
- //}
- public void join (T1 left, T2 right, Collector<T2> out) {
- out.collect(right);
- }
- }
-
public static final class JoinProjection<I1, I2> {
private final DataSet<I1> ds1;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java
index 46bbfab..c2a2a8e 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java
@@ -24,13 +24,17 @@ import java.util.LinkedList;
import java.util.List;
import com.google.common.base.Joiner;
+
import org.apache.flink.api.common.InvalidProgramException;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.typeinfo.AtomicType;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.CompositeType.FlatFieldDescriptor;
import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.typeutils.GenericTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -54,6 +58,8 @@ public abstract class Keys<T> {
public abstract int[] computeLogicalKeyPositions();
+ public abstract <E> void validateCustomPartitioner(Partitioner<E> partitioner, TypeInformation<E> typeInfo);
+
// --------------------------------------------------------------------------------------------
// Specializations for expression-based / extractor-based grouping
@@ -146,6 +152,27 @@ public abstract class Keys<T> {
public int[] computeLogicalKeyPositions() {
return logicalKeyFields;
}
+
+ @Override
+ public <E> void validateCustomPartitioner(Partitioner<E> partitioner, TypeInformation<E> typeInfo) {
+ if (logicalKeyFields.length != 1) {
+ throw new InvalidProgramException("Custom partitioners can only be used with keys that have one key field.");
+ }
+
+ if (typeInfo == null) {
+ try {
+ typeInfo = TypeExtractor.getPartitionerTypes(partitioner);
+ }
+ catch (Throwable t) {
+ // best effort check, so we ignore exceptions
+ }
+ }
+
+ if (typeInfo != null && !(typeInfo instanceof GenericTypeInfo) && (!keyType.equals(typeInfo))) {
+ throw new InvalidProgramException("The partitioner is incompatible with the key type. "
+ + "Partitioner type: " + typeInfo + " , key type: " + keyType);
+ }
+ }
@Override
public String toString() {
@@ -299,12 +326,36 @@ public abstract class Keys<T> {
@Override
public int[] computeLogicalKeyPositions() {
- List<Integer> logicalKeys = new LinkedList<Integer>();
- for(FlatFieldDescriptor kd : keyFields) {
- logicalKeys.addAll( Ints.asList(kd.getPosition()));
+ List<Integer> logicalKeys = new ArrayList<Integer>();
+ for (FlatFieldDescriptor kd : keyFields) {
+ logicalKeys.add(kd.getPosition());
}
return Ints.toArray(logicalKeys);
}
+
+ @Override
+ public <E> void validateCustomPartitioner(Partitioner<E> partitioner, TypeInformation<E> typeInfo) {
+ if (keyFields.size() != 1) {
+ throw new InvalidProgramException("Custom partitioners can only be used with keys that have one key field.");
+ }
+
+ if (typeInfo == null) {
+ try {
+ typeInfo = TypeExtractor.getPartitionerTypes(partitioner);
+ }
+ catch (Throwable t) {
+ // best effort check, so we ignore exceptions
+ }
+ }
+
+ if (typeInfo != null && !(typeInfo instanceof GenericTypeInfo)) {
+ TypeInformation<?> keyType = keyFields.get(0).getType();
+ if (!keyType.equals(typeInfo)) {
+ throw new InvalidProgramException("The partitioner is incompatible with the key type. "
+ + "Partitioner type: " + typeInfo + " , key type: " + keyType);
+ }
+ }
+ }
@Override
public String toString() {
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/operators/PartitionOperator.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/PartitionOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/PartitionOperator.java
index 77d5681..22d4d44 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/PartitionOperator.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/PartitionOperator.java
@@ -19,6 +19,7 @@
package org.apache.flink.api.java.operators;
import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.operators.Operator;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.base.MapOperatorBase;
@@ -32,6 +33,8 @@ import org.apache.flink.api.java.operators.translation.KeyRemovingMapper;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import com.google.common.base.Preconditions;
+
/**
* This operator represents a partitioning.
*
@@ -42,66 +45,102 @@ public class PartitionOperator<T> extends SingleInputUdfOperator<T, T, Partition
private final Keys<T> pKeys;
private final PartitionMethod pMethod;
private final String partitionLocationName;
+ private final Partitioner<?> customPartitioner;
+
public PartitionOperator(DataSet<T> input, PartitionMethod pMethod, Keys<T> pKeys, String partitionLocationName) {
+ this(input, pMethod, pKeys, null, null, partitionLocationName);
+ }
+
+ public PartitionOperator(DataSet<T> input, PartitionMethod pMethod, String partitionLocationName) {
+ this(input, pMethod, null, null, null, partitionLocationName);
+ }
+
+ public PartitionOperator(DataSet<T> input, Keys<T> pKeys, Partitioner<?> customPartitioner, String partitionLocationName) {
+ this(input, PartitionMethod.CUSTOM, pKeys, customPartitioner, null, partitionLocationName);
+ }
+
+ public <P> PartitionOperator(DataSet<T> input, Keys<T> pKeys, Partitioner<P> customPartitioner,
+ TypeInformation<P> partitionerTypeInfo, String partitionLocationName)
+ {
+ this(input, PartitionMethod.CUSTOM, pKeys, customPartitioner, partitionerTypeInfo, partitionLocationName);
+ }
+
+ private <P> PartitionOperator(DataSet<T> input, PartitionMethod pMethod, Keys<T> pKeys, Partitioner<P> customPartitioner,
+ TypeInformation<P> partitionerTypeInfo, String partitionLocationName)
+ {
super(input, input.getType());
- this.partitionLocationName = partitionLocationName;
-
- if(pMethod == PartitionMethod.HASH && pKeys == null) {
- throw new IllegalArgumentException("Hash Partitioning requires keys");
- } else if(pMethod == PartitionMethod.RANGE) {
- throw new UnsupportedOperationException("Range Partitioning not yet supported");
+
+ Preconditions.checkNotNull(pMethod);
+ Preconditions.checkArgument(pKeys != null || pMethod == PartitionMethod.REBALANCE, "Partitioning requires keys");
+ Preconditions.checkArgument(pMethod != PartitionMethod.CUSTOM || customPartitioner != null, "Custom partitioning requires a partitioner.");
+ Preconditions.checkArgument(pMethod != PartitionMethod.RANGE, "Range partitioning is not yet supported");
+
+ if (pKeys instanceof Keys.ExpressionKeys<?> && !(input.getType() instanceof CompositeType) ) {
+ throw new IllegalArgumentException("Hash Partitioning with key fields only possible on Tuple or POJO DataSets");
}
- if(pKeys instanceof Keys.ExpressionKeys<?> && !(input.getType() instanceof CompositeType) ) {
- throw new IllegalArgumentException("Hash Partitioning with key fields only possible on Composite-type DataSets");
+ if (customPartitioner != null) {
+ pKeys.validateCustomPartitioner(customPartitioner, partitionerTypeInfo);
}
this.pMethod = pMethod;
this.pKeys = pKeys;
+ this.partitionLocationName = partitionLocationName;
+ this.customPartitioner = customPartitioner;
}
- public PartitionOperator(DataSet<T> input, PartitionMethod pMethod, String partitionLocationName) {
- this(input, pMethod, null, partitionLocationName);
- }
+ // --------------------------------------------------------------------------------------------
+ // Properties
+ // --------------------------------------------------------------------------------------------
- /*
- * Translation of partitioning
+ /**
+ * Gets the custom partitioner from this partitioning.
+ *
+ * @return The custom partitioner.
*/
+ public Partitioner<?> getCustomPartitioner() {
+ return customPartitioner;
+ }
+
+ // --------------------------------------------------------------------------------------------
+ // Translation
+ // --------------------------------------------------------------------------------------------
+
protected org.apache.flink.api.common.operators.SingleInputOperator<?, T, ?> translateToDataFlow(Operator<T> input) {
- String name = "Partition at "+partitionLocationName;
+ String name = "Partition at " + partitionLocationName;
// distinguish between partition types
if (pMethod == PartitionMethod.REBALANCE) {
UnaryOperatorInformation<T, T> operatorInfo = new UnaryOperatorInformation<T, T>(getType(), getType());
PartitionOperatorBase<T> noop = new PartitionOperatorBase<T>(operatorInfo, pMethod, name);
- // set input
+
noop.setInput(input);
- // set DOP
noop.setDegreeOfParallelism(getParallelism());
return noop;
}
- else if (pMethod == PartitionMethod.HASH) {
+ else if (pMethod == PartitionMethod.HASH || pMethod == PartitionMethod.CUSTOM) {
if (pKeys instanceof Keys.ExpressionKeys) {
int[] logicalKeyPositions = pKeys.computeLogicalKeyPositions();
UnaryOperatorInformation<T, T> operatorInfo = new UnaryOperatorInformation<T, T>(getType(), getType());
PartitionOperatorBase<T> noop = new PartitionOperatorBase<T>(operatorInfo, pMethod, logicalKeyPositions, name);
- // set input
+
noop.setInput(input);
- // set DOP
noop.setDegreeOfParallelism(getParallelism());
+ noop.setCustomPartitioner(customPartitioner);
return noop;
- } else if (pKeys instanceof Keys.SelectorFunctionKeys) {
+ }
+ else if (pKeys instanceof Keys.SelectorFunctionKeys) {
@SuppressWarnings("unchecked")
Keys.SelectorFunctionKeys<T, ?> selectorKeys = (Keys.SelectorFunctionKeys<T, ?>) pKeys;
- MapOperatorBase<?, T, ?> po = translateSelectorFunctionReducer(selectorKeys, pMethod, getType(), name, input, getParallelism());
+ MapOperatorBase<?, T, ?> po = translateSelectorFunctionPartitioner(selectorKeys, pMethod, getType(), name, input, getParallelism(), customPartitioner);
return po;
}
else {
@@ -112,14 +151,13 @@ public class PartitionOperator<T> extends SingleInputUdfOperator<T, T, Partition
else if (pMethod == PartitionMethod.RANGE) {
throw new UnsupportedOperationException("Range partitioning not yet supported");
}
-
- return null;
+ else {
+ throw new UnsupportedOperationException("Unsupported partitioning method: " + pMethod.name());
+ }
}
-
- // --------------------------------------------------------------------------------------------
- private static <T, K> MapOperatorBase<Tuple2<K, T>, T, ?> translateSelectorFunctionReducer(Keys.SelectorFunctionKeys<T, ?> rawKeys,
- PartitionMethod pMethod, TypeInformation<T> inputType, String name, Operator<T> input, int partitionDop)
+ private static <T, K> MapOperatorBase<Tuple2<K, T>, T, ?> translateSelectorFunctionPartitioner(Keys.SelectorFunctionKeys<T, ?> rawKeys,
+ PartitionMethod pMethod, TypeInformation<T> inputType, String name, Operator<T> input, int partitionDop, Partitioner<?> customPartitioner)
{
@SuppressWarnings("unchecked")
final Keys.SelectorFunctionKeys<T, K> keys = (Keys.SelectorFunctionKeys<T, K>) rawKeys;
@@ -137,6 +175,8 @@ public class PartitionOperator<T> extends SingleInputUdfOperator<T, T, Partition
noop.setInput(keyExtractingMap);
keyRemovingMap.setInput(noop);
+ noop.setCustomPartitioner(customPartitioner);
+
// set dop
keyExtractingMap.setDegreeOfParallelism(input.getDegreeOfParallelism());
noop.setDegreeOfParallelism(partitionDop);
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/operators/ReduceOperator.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/ReduceOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/ReduceOperator.java
index 7089cf6..02b0ede 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/ReduceOperator.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/ReduceOperator.java
@@ -87,9 +87,8 @@ public class ReduceOperator<IN> extends SingleInputUdfOperator<IN, IN, ReduceOpe
UnaryOperatorInformation<IN, IN> operatorInfo = new UnaryOperatorInformation<IN, IN>(getInputType(), getInputType());
ReduceOperatorBase<IN, ReduceFunction<IN>> po =
new ReduceOperatorBase<IN, ReduceFunction<IN>>(function, operatorInfo, new int[0], name);
- // set input
- po.setInput(input);
+ po.setInput(input);
// the degree of parallelism for a non grouped reduce can only be 1
po.setDegreeOfParallelism(1);
@@ -102,7 +101,9 @@ public class ReduceOperator<IN> extends SingleInputUdfOperator<IN, IN, ReduceOpe
@SuppressWarnings("unchecked")
Keys.SelectorFunctionKeys<IN, ?> selectorKeys = (Keys.SelectorFunctionKeys<IN, ?>) grouper.getKeys();
- MapOperatorBase<?, IN, ?> po = translateSelectorFunctionReducer(selectorKeys, function, getInputType(), name, input, this.getParallelism());
+ MapOperatorBase<?, IN, ?> po = translateSelectorFunctionReducer(selectorKeys, function, getInputType(), name, input, getParallelism());
+ ((PlanUnwrappingReduceOperator<?, ?>) po.getInput()).setCustomPartitioner(grouper.getCustomPartitioner());
+
return po;
}
else if (grouper.getKeys() instanceof Keys.ExpressionKeys) {
@@ -113,17 +114,16 @@ public class ReduceOperator<IN> extends SingleInputUdfOperator<IN, IN, ReduceOpe
ReduceOperatorBase<IN, ReduceFunction<IN>> po =
new ReduceOperatorBase<IN, ReduceFunction<IN>>(function, operatorInfo, logicalKeyPositions, name);
- // set input
+ po.setCustomPartitioner(grouper.getCustomPartitioner());
+
po.setInput(input);
- // set dop
- po.setDegreeOfParallelism(this.getParallelism());
+ po.setDegreeOfParallelism(getParallelism());
return po;
}
else {
throw new UnsupportedOperationException("Unrecognized key type.");
}
-
}
// --------------------------------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/operators/SortedGrouping.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/SortedGrouping.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/SortedGrouping.java
index 36d14ee..63e5a19 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/SortedGrouping.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/SortedGrouping.java
@@ -27,6 +27,7 @@ import java.util.Arrays;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.GroupReduceFunction;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.Keys.ExpressionKeys;
@@ -34,7 +35,6 @@ import org.apache.flink.api.java.typeutils.TypeExtractor;
import com.google.common.base.Preconditions;
-
/**
* SortedGrouping is an intermediate step for a transformation on a grouped and sorted DataSet.<br/>
* The following transformation can be applied on sorted groups:
@@ -84,6 +84,8 @@ public class SortedGrouping<T> extends Grouping<T> {
Arrays.fill(this.groupSortOrders, order); // if field == "*"
}
+ // --------------------------------------------------------------------------------------------
+
protected int[] getGroupSortKeyPositions() {
return this.groupSortKeyPositions;
}
@@ -91,6 +93,21 @@ public class SortedGrouping<T> extends Grouping<T> {
protected Order[] getGroupSortOrders() {
return this.groupSortOrders;
}
+
+ /**
+ * Uses a custom partitioner for the grouping.
+ *
+ * @param partitioner The custom partitioner.
+ * @return The grouping object itself, to allow for method chaining.
+ */
+ public SortedGrouping<T> withPartitioner(Partitioner<?> partitioner) {
+ Preconditions.checkNotNull(partitioner);
+
+ getKeys().validateCustomPartitioner(partitioner, null);
+
+ this.customPartitioner = partitioner;
+ return this;
+ }
/**
* Applies a GroupReduce transformation on a grouped and sorted {@link DataSet}.<br/>
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/operators/UnsortedGrouping.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/UnsortedGrouping.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/UnsortedGrouping.java
index b504e37..d323eae 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/UnsortedGrouping.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/UnsortedGrouping.java
@@ -20,6 +20,7 @@ package org.apache.flink.api.java.operators;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.GroupReduceFunction;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.typeinfo.TypeInformation;
@@ -32,11 +33,27 @@ import org.apache.flink.api.java.functions.SelectByMinFunction;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
+import com.google.common.base.Preconditions;
+
public class UnsortedGrouping<T> extends Grouping<T> {
public UnsortedGrouping(DataSet<T> set, Keys<T> keys) {
super(set, keys);
}
+
+ /**
+ * Uses a custom partitioner for the grouping.
+ *
+ * @param partitioner The custom partitioner.
+ * @return The grouping object itself, to allow for method chaining.
+ */
+ public UnsortedGrouping<T> withPartitioner(Partitioner<?> partitioner) {
+ Preconditions.checkNotNull(partitioner);
+ getKeys().validateCustomPartitioner(partitioner, null);
+
+ this.customPartitioner = partitioner;
+ return this;
+ }
// --------------------------------------------------------------------------------------------
// Operations / Transformations
@@ -213,7 +230,9 @@ public class UnsortedGrouping<T> extends Grouping<T> {
* @see Order
*/
public SortedGrouping<T> sortGroup(int field, Order order) {
- return new SortedGrouping<T>(this.dataSet, this.keys, field, order);
+ SortedGrouping<T> sg = new SortedGrouping<T>(this.dataSet, this.keys, field, order);
+ sg.customPartitioner = getCustomPartitioner();
+ return sg;
}
/**
@@ -228,7 +247,9 @@ public class UnsortedGrouping<T> extends Grouping<T> {
* @see Order
*/
public SortedGrouping<T> sortGroup(String field, Order order) {
- return new SortedGrouping<T>(this.dataSet, this.keys, field, order);
+ SortedGrouping<T> sg = new SortedGrouping<T>(this.dataSet, this.keys, field, order);
+ sg.customPartitioner = getCustomPartitioner();
+ return sg;
}
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/record/io/CsvInputFormat.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/record/io/CsvInputFormat.java b/flink-java/src/main/java/org/apache/flink/api/java/record/io/CsvInputFormat.java
index 15333e8..e3ad06f 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/record/io/CsvInputFormat.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/record/io/CsvInputFormat.java
@@ -20,6 +20,7 @@
package org.apache.flink.api.java.record.io;
import com.google.common.base.Preconditions;
+
import org.apache.flink.api.common.io.GenericCsvInputFormat;
import org.apache.flink.api.common.io.ParseException;
import org.apache.flink.api.common.operators.CompilerHints;
@@ -54,6 +55,7 @@ import java.io.IOException;
* @see Configuration
* @see Record
*/
+@SuppressWarnings("deprecation")
public class CsvInputFormat extends GenericCsvInputFormat<Record> {
private static final long serialVersionUID = 1L;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/record/io/CsvOutputFormat.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/record/io/CsvOutputFormat.java b/flink-java/src/main/java/org/apache/flink/api/java/record/io/CsvOutputFormat.java
index 2c514fe..a5d83c3 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/record/io/CsvOutputFormat.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/record/io/CsvOutputFormat.java
@@ -16,7 +16,6 @@
* limitations under the License.
*/
-
package org.apache.flink.api.java.record.io;
import java.io.BufferedOutputStream;
@@ -52,6 +51,7 @@ import org.apache.flink.types.Value;
* @see Configuration
* @see Record
*/
+@SuppressWarnings("deprecation")
public class CsvOutputFormat extends FileOutputFormat {
private static final long serialVersionUID = 1L;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/record/io/DelimitedOutputFormat.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/record/io/DelimitedOutputFormat.java b/flink-java/src/main/java/org/apache/flink/api/java/record/io/DelimitedOutputFormat.java
index 0818f45..49d9a2a 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/record/io/DelimitedOutputFormat.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/record/io/DelimitedOutputFormat.java
@@ -16,7 +16,6 @@
* limitations under the License.
*/
-
package org.apache.flink.api.java.record.io;
@@ -27,10 +26,10 @@ import org.apache.flink.api.java.record.operators.FileDataSink;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.types.Record;
-
/**
* The base class for output formats that serialize their records into a delimited sequence.
*/
+@SuppressWarnings("deprecation")
public abstract class DelimitedOutputFormat extends FileOutputFormat {
private static final long serialVersionUID = 1L;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/record/operators/ReduceOperator.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/record/operators/ReduceOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/record/operators/ReduceOperator.java
index 5329a69..d09c2dc 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/record/operators/ReduceOperator.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/record/operators/ReduceOperator.java
@@ -53,6 +53,7 @@ import org.apache.flink.util.InstantiationUtil;
*
* @see ReduceFunction
*/
+@SuppressWarnings("deprecation")
public class ReduceOperator extends GroupReduceOperatorBase<Record, Record, GroupReduceFunction<Record, Record>> implements RecordOperator {
private static final String DEFAULT_NAME = "<Unnamed Reducer>"; // the default name for contracts
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java
index d52e1b0..33750b5 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java
@@ -42,6 +42,7 @@ import org.apache.flink.api.common.functions.InvalidTypesException;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.io.InputFormat;
import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
@@ -115,6 +116,10 @@ public class TypeExtractor {
return getUnaryOperatorReturnType((Function) selectorInterface, KeySelector.class, false, false, inType);
}
+ public static <T> TypeInformation<T> getPartitionerTypes(Partitioner<T> partitioner) {
+ return new TypeExtractor().privateCreateTypeInfo(Partitioner.class, partitioner.getClass(), 0, null, null);
+ }
+
@SuppressWarnings("unchecked")
public static <IN> TypeInformation<IN> getInputFormatTypes(InputFormat<IN, ?> inputFormatInterface) {
if(inputFormatInterface instanceof ResultTypeQueryable) {
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/ChannelSelector.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/ChannelSelector.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/ChannelSelector.java
index a62de77..c780f87 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/ChannelSelector.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/ChannelSelector.java
@@ -16,7 +16,6 @@
* limitations under the License.
*/
-
package org.apache.flink.runtime.io.network.api;
import org.apache.flink.core.io.IOReadableWritable;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-runtime/src/main/java/org/apache/flink/runtime/operators/RegularPactTask.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/RegularPactTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/RegularPactTask.java
index b39b402..c1037b5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/RegularPactTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/RegularPactTask.java
@@ -25,6 +25,7 @@ import org.apache.flink.api.common.accumulators.AccumulatorHelper;
import org.apache.flink.api.common.distributions.DataDistribution;
import org.apache.flink.api.common.functions.FlatCombineFunction;
import org.apache.flink.api.common.functions.Function;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
@@ -1269,7 +1270,9 @@ public class RegularPactTask<S extends Function, OT> extends AbstractInvokable i
throw new Exception("Incompatibe serializer-/comparator factories.");
}
final DataDistribution distribution = config.getOutputDataDistribution(i, cl);
- oe = new RecordOutputEmitter(strategy, comparator, distribution);
+ final Partitioner<?> partitioner = config.getOutputPartitioner(i, cl);
+
+ oe = new RecordOutputEmitter(strategy, comparator, partitioner, distribution);
}
writers.add(new RecordWriter<Record>(task, oe));
@@ -1292,17 +1295,17 @@ public class RegularPactTask<S extends Function, OT> extends AbstractInvokable i
// create the OutputEmitter from output ship strategy
final ShipStrategyType strategy = config.getOutputShipStrategy(i);
final TypeComparatorFactory<T> compFactory = config.getOutputComparator(i, cl);
- final DataDistribution dataDist = config.getOutputDataDistribution(i, cl);
final ChannelSelector<SerializationDelegate<T>> oe;
if (compFactory == null) {
oe = new OutputEmitter<T>(strategy);
- } else if (dataDist == null){
- final TypeComparator<T> comparator = compFactory.createComparator();
- oe = new OutputEmitter<T>(strategy, comparator);
- } else {
+ }
+ else {
+ final DataDistribution dataDist = config.getOutputDataDistribution(i, cl);
+ final Partitioner<?> partitioner = config.getOutputPartitioner(i, cl);
+
final TypeComparator<T> comparator = compFactory.createComparator();
- oe = new OutputEmitter<T>(strategy, comparator, dataDist);
+ oe = new OutputEmitter<T>(strategy, comparator, partitioner, dataDist);
}
writers.add(new RecordWriter<SerializationDelegate<T>>(task, oe));
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/HistogramPartitionFunction.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/HistogramPartitionFunction.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/HistogramPartitionFunction.java
deleted file mode 100644
index 54bb901..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/HistogramPartitionFunction.java
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.flink.runtime.operators.shipping;
-
-import java.util.Arrays;
-
-import org.apache.flink.api.common.operators.Order;
-import org.apache.flink.types.Record;
-
-public class HistogramPartitionFunction implements PartitionFunction {
- private final Record[] splitBorders;
- private final Order partitionOrder;
-
- public HistogramPartitionFunction(Record[] splitBorders, Order partitionOrder) {
- this.splitBorders = splitBorders;
- this.partitionOrder = partitionOrder;
- }
-
- @Override
- public void selectChannels(Record data, int numChannels, int[] channels) {
- //TODO: Check partition borders match number of channels
- int pos = Arrays.binarySearch(splitBorders, data);
-
- /*
- *
- * TODO CHECK ONLY FOR KEYS NOT FOR WHOLE RECORD
- *
- */
-
- if(pos < 0) {
- pos++;
- pos = -pos;
- }
-
- if(partitionOrder == Order.ASCENDING || partitionOrder == Order.ANY) {
- channels[0] = pos;
- } else {
- channels[0] = splitBorders.length - pos;
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/OutputEmitter.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/OutputEmitter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/OutputEmitter.java
index 4f297b0..ec92e3f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/OutputEmitter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/OutputEmitter.java
@@ -20,6 +20,7 @@
package org.apache.flink.runtime.operators.shipping;
import org.apache.flink.api.common.distributions.DataDistribution;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.runtime.io.network.api.ChannelSelector;
import org.apache.flink.runtime.plugable.SerializationDelegate;
@@ -33,6 +34,10 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
private int nextChannelToSendTo = 0; // counter to go over channels round robin
private final TypeComparator<T> comparator; // the comparator for hashing / sorting
+
+ private final Partitioner<Object> partitioner;
+
+ private Object[] extractedKeys;
// ------------------------------------------------------------------------
// Constructors
@@ -62,7 +67,7 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
* @param comparator The comparator used to hash / compare the records.
*/
public OutputEmitter(ShipStrategyType strategy, TypeComparator<T> comparator) {
- this(strategy, comparator, null);
+ this(strategy, comparator, null, null);
}
/**
@@ -74,12 +79,22 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
* @param distr The distribution pattern used in the case of a range partitioning.
*/
public OutputEmitter(ShipStrategyType strategy, TypeComparator<T> comparator, DataDistribution distr) {
+ this(strategy, comparator, null, distr);
+ }
+
+ public OutputEmitter(ShipStrategyType strategy, TypeComparator<T> comparator, Partitioner<?> partitioner) {
+ this(strategy, comparator, partitioner, null);
+ }
+
+ @SuppressWarnings("unchecked")
+ public OutputEmitter(ShipStrategyType strategy, TypeComparator<T> comparator, Partitioner<?> partitioner, DataDistribution distr) {
if (strategy == null) {
throw new NullPointerException();
}
this.strategy = strategy;
this.comparator = comparator;
+ this.partitioner = (Partitioner<Object>) partitioner;
switch (strategy) {
case FORWARD:
@@ -87,6 +102,7 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
case PARTITION_RANGE:
case PARTITION_RANDOM:
case PARTITION_FORCED_REBALANCE:
+ case PARTITION_CUSTOM:
case BROADCAST:
break;
default:
@@ -96,6 +112,9 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
if ((strategy == ShipStrategyType.PARTITION_RANGE) && distr == null) {
throw new NullPointerException("Data distribution must not be null when the ship strategy is range partitioning.");
}
+ if (strategy == ShipStrategyType.PARTITION_CUSTOM && partitioner == null) {
+ throw new NullPointerException("Partitioner must not be null when the ship strategy is set to custom partitioning.");
+ }
}
// ------------------------------------------------------------------------
@@ -111,10 +130,12 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
return robin(numberOfChannels);
case PARTITION_HASH:
return hashPartitionDefault(record.getInstance(), numberOfChannels);
- case PARTITION_RANGE:
- return rangePartition(record.getInstance(), numberOfChannels);
case BROADCAST:
return broadcast(numberOfChannels);
+ case PARTITION_CUSTOM:
+ return customPartition(record.getInstance(), numberOfChannels);
+ case PARTITION_RANGE:
+ return rangePartition(record.getInstance(), numberOfChannels);
default:
throw new UnsupportedOperationException("Unsupported distribution strategy: " + strategy.name());
}
@@ -189,4 +210,25 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
private final int[] rangePartition(T record, int numberOfChannels) {
throw new UnsupportedOperationException();
}
+
+ private final int[] customPartition(T record, int numberOfChannels) {
+ if (channels == null) {
+ channels = new int[1];
+ extractedKeys = new Object[1];
+ }
+
+ try {
+ if (comparator.extractKeys(record, extractedKeys, 0) == 1) {
+ final Object key = extractedKeys[0];
+ channels[0] = partitioner.partition(key, numberOfChannels);
+ return channels;
+ }
+ else {
+ throw new RuntimeException("Inconsistency in the key comparator - comparator extracted more than one field.");
+ }
+ }
+ catch (Throwable t) {
+ throw new RuntimeException("Error while calling custom partitioner.", t);
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/PartitionFunction.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/PartitionFunction.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/PartitionFunction.java
deleted file mode 100644
index dadec16..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/PartitionFunction.java
+++ /dev/null
@@ -1,26 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.flink.runtime.operators.shipping;
-
-import org.apache.flink.types.Record;
-
-public interface PartitionFunction {
- public void selectChannels(Record data, int numChannels, int[] channels);
-}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/RecordOutputEmitter.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/RecordOutputEmitter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/RecordOutputEmitter.java
index 8a375e0..9d06aad 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/RecordOutputEmitter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/RecordOutputEmitter.java
@@ -20,6 +20,7 @@
package org.apache.flink.runtime.operators.shipping;
import org.apache.flink.api.common.distributions.DataDistribution;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.runtime.io.network.api.ChannelSelector;
import org.apache.flink.types.Key;
@@ -43,7 +44,11 @@ public class RecordOutputEmitter implements ChannelSelector<Record> {
private final DataDistribution distribution; // the data distribution to create the partition boundaries for range partitioning
+ private final Partitioner<Object> partitioner;
+
private int nextChannelToSendTo; // counter to go over channels round robin
+
+ private Object[] extractedKeys;
// ------------------------------------------------------------------------
// Constructors
@@ -66,7 +71,7 @@ public class RecordOutputEmitter implements ChannelSelector<Record> {
* @param comparator The comparator used to hash / compare the records.
*/
public RecordOutputEmitter(ShipStrategyType strategy, TypeComparator<Record> comparator) {
- this(strategy, comparator, null);
+ this(strategy, comparator, null, null);
}
/**
@@ -78,6 +83,15 @@ public class RecordOutputEmitter implements ChannelSelector<Record> {
* @param distr The distribution pattern used in the case of a range partitioning.
*/
public RecordOutputEmitter(ShipStrategyType strategy, TypeComparator<Record> comparator, DataDistribution distr) {
+ this(strategy, comparator, null, distr);
+ }
+
+ public RecordOutputEmitter(ShipStrategyType strategy, TypeComparator<Record> comparator, Partitioner<?> partitioner) {
+ this(strategy, comparator, partitioner, null);
+ }
+
+ @SuppressWarnings("unchecked")
+ public RecordOutputEmitter(ShipStrategyType strategy, TypeComparator<Record> comparator, Partitioner<?> partitioner, DataDistribution distr) {
if (strategy == null) {
throw new NullPointerException();
}
@@ -85,6 +99,7 @@ public class RecordOutputEmitter implements ChannelSelector<Record> {
this.strategy = strategy;
this.comparator = comparator;
this.distribution = distr;
+ this.partitioner = (Partitioner<Object>) partitioner;
switch (strategy) {
case FORWARD:
@@ -94,6 +109,7 @@ public class RecordOutputEmitter implements ChannelSelector<Record> {
this.channels = new int[1];
break;
case BROADCAST:
+ case PARTITION_CUSTOM:
break;
default:
throw new IllegalArgumentException("Invalid shipping strategy for OutputEmitter: " + strategy.name());
@@ -102,6 +118,9 @@ public class RecordOutputEmitter implements ChannelSelector<Record> {
if ((strategy == ShipStrategyType.PARTITION_RANGE) && distr == null) {
throw new NullPointerException("Data distribution must not be null when the ship strategy is range partitioning.");
}
+ if (strategy == ShipStrategyType.PARTITION_CUSTOM && partitioner == null) {
+ throw new NullPointerException("Partitioner must not be null when the ship strategy is set to custom partitioning.");
+ }
}
// ------------------------------------------------------------------------
@@ -113,13 +132,16 @@ public class RecordOutputEmitter implements ChannelSelector<Record> {
switch (strategy) {
case FORWARD:
case PARTITION_RANDOM:
+ case PARTITION_FORCED_REBALANCE:
return robin(numberOfChannels);
case PARTITION_HASH:
return hashPartitionDefault(record, numberOfChannels);
- case PARTITION_RANGE:
- return rangePartition(record, numberOfChannels);
+ case PARTITION_CUSTOM:
+ return customPartition(record, numberOfChannels);
case BROADCAST:
return broadcast(numberOfChannels);
+ case PARTITION_RANGE:
+ return rangePartition(record, numberOfChannels);
default:
throw new UnsupportedOperationException("Unsupported distribution strategy: " + strategy.name());
}
@@ -200,4 +222,25 @@ public class RecordOutputEmitter implements ChannelSelector<Record> {
"The number of channels to partition among is inconsistent with the partitioners state.");
}
}
+
+ private final int[] customPartition(Record record, int numberOfChannels) {
+ if (channels == null) {
+ channels = new int[1];
+ extractedKeys = new Object[1];
+ }
+
+ try {
+ if (comparator.extractKeys(record, extractedKeys, 0) == 1) {
+ final Object key = extractedKeys[0];
+ channels[0] = partitioner.partition(key, numberOfChannels);
+ return channels;
+ }
+ else {
+ throw new RuntimeException("Inconsistency in the key comparator - comparator extracted more than one field.");
+ }
+ }
+ catch (Throwable t) {
+ throw new RuntimeException("Error while calling custom partitioner.", t);
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/ShipStrategyType.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/ShipStrategyType.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/ShipStrategyType.java
index 45134a1..fb32a6e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/ShipStrategyType.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/ShipStrategyType.java
@@ -51,14 +51,19 @@ public enum ShipStrategyType {
PARTITION_RANGE(true, true),
/**
- * Partitioning the data evenly
+ * Partitioning the data evenly, forced at a specific location (cannot be pushed down by optimizer).
*/
PARTITION_FORCED_REBALANCE(true, false),
/**
* Replicating the data set to all instances.
*/
- BROADCAST(true, false);
+ BROADCAST(true, false),
+
+ /**
+ * Partitioning using a custom partitioner.
+ */
+ PARTITION_CUSTOM(true, true);
// --------------------------------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java
index 1b44a3b..89cf98a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java
@@ -36,6 +36,7 @@ import org.apache.flink.api.common.aggregators.AggregatorWithName;
import org.apache.flink.api.common.aggregators.ConvergenceCriterion;
import org.apache.flink.api.common.distributions.DataDistribution;
import org.apache.flink.api.common.functions.Function;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
import org.apache.flink.api.common.typeutils.TypePairComparatorFactory;
@@ -141,6 +142,8 @@ public class TaskConfig {
private static final String OUTPUT_DATA_DISTRIBUTION_PREFIX = "out.distribution.";
+ private static final String OUTPUT_PARTITIONER = "out.partitioner.";
+
// ------------------------------------- Chaining ---------------------------------------------
private static final String CHAINING_NUM_STUBS = "chaining.num";
@@ -597,6 +600,27 @@ public class TaskConfig {
}
}
+ public void setOutputPartitioner(Partitioner<?> partitioner, int outputNum) {
+ try {
+ InstantiationUtil.writeObjectToConfig(partitioner, config, OUTPUT_PARTITIONER + outputNum);
+ }
+ catch (Throwable t) {
+ throw new RuntimeException("Could not serialize custom partitioner.", t);
+ }
+ }
+
+ public Partitioner<?> getOutputPartitioner(int outputNum, final ClassLoader cl) throws ClassNotFoundException {
+ try {
+ return (Partitioner<?>) InstantiationUtil.readObjectFromConfig(config, OUTPUT_PARTITIONER + outputNum, cl);
+ }
+ catch (ClassNotFoundException e) {
+ throw e;
+ }
+ catch (Throwable t) {
+ throw new RuntimeException("Could not deserialize custom partitioner.", t);
+ }
+ }
+
// --------------------------------------------------------------------------------------------
// Parameters to configure the memory and I/O behavior
// --------------------------------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaAggregateOperator.java
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaAggregateOperator.java b/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaAggregateOperator.java
index 534ef45..3d76921 100644
--- a/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaAggregateOperator.java
+++ b/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaAggregateOperator.java
@@ -232,6 +232,7 @@ public class ScalaAggregateOperator<IN> extends SingleInputOperator<IN, IN, Scal
}
po.setSemanticProperties(props);
+ po.setCustomPartitioner(grouping.getCustomPartitioner());
return po;
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
index ca8e469..d1233e6 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
@@ -1009,9 +1009,75 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
getCallLocationName())
wrap(op)
}
+
+ /**
+ * Partitions a tuple DataSet on the specified key fields using a custom partitioner.
+ * This method takes the key position to partition on, and a partitioner that accepts the key
+ * type.
+ * <p>
+ * Note: This method works only on single field keys.
+ */
+ def partitionCustom[K: TypeInformation](partitioner: Partitioner[K], field: Int) : DataSet[T] = {
+ val op = new PartitionOperator[T](
+ javaSet,
+ new Keys.ExpressionKeys[T](Array[Int](field), javaSet.getType, false),
+ partitioner,
+ implicitly[TypeInformation[K]],
+ getCallLocationName())
+
+ wrap(op)
+ }
+
+ /**
+ * Partitions a POJO DataSet on the specified key fields using a custom partitioner.
+ * This method takes the key expression to partition on, and a partitioner that accepts the key
+ * type.
+ * <p>
+ * Note: This method works only on single field keys.
+ */
+ def partitionCustom[K: TypeInformation](partitioner: Partitioner[K], field: String)
+ : DataSet[T] = {
+ val op = new PartitionOperator[T](
+ javaSet,
+ new Keys.ExpressionKeys[T](Array[String](field), javaSet.getType),
+ partitioner,
+ implicitly[TypeInformation[K]],
+ getCallLocationName())
+
+ wrap(op)
+ }
+
+ /**
+ * Partitions a DataSet on the key returned by the selector, using a custom partitioner.
 + * This method takes the key selector to get the key to partition on, and a partitioner that
+ * accepts the key type.
+ * <p>
+ * Note: This method works only on single field keys, i.e. the selector cannot return tuples
+ * of fields.
+ */
+ def partitionCustom[K: TypeInformation](partitioner: Partitioner[K], fun: T => K)
+ : DataSet[T] = {
+ val keyExtractor = new KeySelector[T, K] {
+ def getKey(in: T) = fun(in)
+ }
+
+ val keyType = implicitly[TypeInformation[K]];
+
+ val op = new PartitionOperator[T](
+ javaSet,
+ new Keys.SelectorFunctionKeys[T, K](
+ keyExtractor,
+ javaSet.getType,
+ keyType),
+ partitioner,
+ keyType,
+ getCallLocationName())
+
+ wrap(op)
+ }
/**
- * Enforces a rebalancing of the DataSet, i.e., the DataSet is evenly distributed over all
+ * Enforces a re-balancing of the DataSet, i.e., the DataSet is evenly distributed over all
* parallel instances of the
* following task. This can help to improve performance in case of heavy data skew and compute
* intensive operations.
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala
index 23edc74..d87426e 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala
@@ -20,9 +20,7 @@ package org.apache.flink.api.scala
import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.java.functions.FirstReducer
import org.apache.flink.api.scala.operators.ScalaAggregateOperator
-
import scala.collection.JavaConverters._
-
import org.apache.commons.lang3.Validate
import org.apache.flink.api.common.functions.{GroupReduceFunction, ReduceFunction}
import org.apache.flink.api.common.operators.Order
@@ -30,9 +28,10 @@ import org.apache.flink.api.java.aggregation.Aggregations
import org.apache.flink.api.java.operators._
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.util.Collector
-
import scala.collection.mutable
import scala.reflect.ClassTag
+import org.apache.flink.api.common.functions.Partitioner
+import com.google.common.base.Preconditions
/**
* A [[DataSet]] to which a grouping key was added. Operations work on groups of elements with the
@@ -49,6 +48,8 @@ class GroupedDataSet[T: ClassTag](
// when using a group-at-a-time reduce function.
private val groupSortKeyPositions = mutable.MutableList[Either[Int, String]]()
private val groupSortOrders = mutable.MutableList[Order]()
+
+ private var partitioner : Partitioner[_] = _
/**
* Adds a secondary sort key to this [[GroupedDataSet]]. This will only have an effect if you
@@ -113,16 +114,51 @@ class GroupedDataSet[T: ClassTag](
}
}
- grouping
+
+ if (partitioner == null) {
+ grouping
+ } else {
+ grouping.withPartitioner(partitioner)
+ }
+
} else {
- new UnsortedGrouping[T](set.javaSet, keys)
+ createUnsortedGrouping()
}
}
/** Convenience methods for creating the [[UnsortedGrouping]] */
- private def createUnsortedGrouping(): Grouping[T] = new UnsortedGrouping[T](set.javaSet, keys)
+ private def createUnsortedGrouping(): Grouping[T] = {
+ val grp = new UnsortedGrouping[T](set.javaSet, keys)
+ if (partitioner == null) {
+ grp
+ } else {
+ grp.withPartitioner(partitioner)
+ }
+ }
/**
+ * Sets a custom partitioner for the grouping.
+ */
+ def withPartitioner[K : TypeInformation](partitioner: Partitioner[K]) : GroupedDataSet[T] = {
+ Preconditions.checkNotNull(partitioner)
+ keys.validateCustomPartitioner(partitioner, implicitly[TypeInformation[K]])
+ this.partitioner = partitioner
+ this
+ }
+
+ /**
+ * Gets the custom partitioner to be used for this grouping, or null, if
+ * none was defined.
+ */
+ def getCustomPartitioner[K]() : Partitioner[K] = {
+ partitioner.asInstanceOf[Partitioner[K]]
+ }
+
+ // ----------------------------------------------------------------------------------------------
+ // Operations
+ // ----------------------------------------------------------------------------------------------
+
+ /**
* Creates a new [[DataSet]] by aggregating the specified tuple field using the given aggregation
* function. Since this is a keyed DataSet the aggregation will be performed on groups of
* tuples with the same key.
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala
index 7062c63..f5b0783 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala
@@ -21,15 +21,15 @@ import org.apache.commons.lang3.Validate
import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.common.functions.{JoinFunction, RichFlatJoinFunction, FlatJoinFunction}
import org.apache.flink.api.common.typeutils.TypeSerializer
-import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint;
+import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint
import org.apache.flink.api.java.operators.JoinOperator.DefaultJoin.WrappingFlatJoinFunction
-import org.apache.flink.api.java.operators.JoinOperator.EquiJoin;
+import org.apache.flink.api.java.operators.JoinOperator.EquiJoin
import org.apache.flink.api.java.operators._
import org.apache.flink.api.scala.typeutils.{CaseClassSerializer, CaseClassTypeInfo}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.util.Collector
-
import scala.reflect.ClassTag
+import org.apache.flink.api.common.functions.Partitioner
/**
* A specific [[DataSet]] that results from a `join` operation. The result of a default join is a
@@ -66,6 +66,8 @@ class JoinDataSet[L, R](
rightKeys: Keys[R])
extends DataSet(defaultJoin) {
+ var customPartitioner : Partitioner[_] = _
+
/**
* Creates a new [[DataSet]] where the result for each pair of joined elements is the result
* of the given function.
@@ -86,8 +88,12 @@ class JoinDataSet[L, R](
implicitly[TypeInformation[O]],
defaultJoin.getJoinHint,
getCallLocationName())
-
- wrap(joinOperator)
+
+ if (customPartitioner != null) {
+ wrap(joinOperator.withPartitioner(customPartitioner))
+ } else {
+ wrap(joinOperator)
+ }
}
/**
@@ -112,7 +118,11 @@ class JoinDataSet[L, R](
defaultJoin.getJoinHint,
getCallLocationName())
- wrap(joinOperator)
+ if (customPartitioner != null) {
+ wrap(joinOperator.withPartitioner(customPartitioner))
+ } else {
+ wrap(joinOperator)
+ }
}
/**
@@ -136,7 +146,11 @@ class JoinDataSet[L, R](
defaultJoin.getJoinHint,
getCallLocationName())
- wrap(joinOperator)
+ if (customPartitioner != null) {
+ wrap(joinOperator.withPartitioner(customPartitioner))
+ } else {
+ wrap(joinOperator)
+ }
}
/**
@@ -161,7 +175,35 @@ class JoinDataSet[L, R](
defaultJoin.getJoinHint,
getCallLocationName())
- wrap(joinOperator)
+ if (customPartitioner != null) {
+ wrap(joinOperator.withPartitioner(customPartitioner))
+ } else {
+ wrap(joinOperator)
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------------
+ // Properties
+ // ----------------------------------------------------------------------------------------------
+
+ def withPartitioner[K : TypeInformation](partitioner : Partitioner[K]) : JoinDataSet[L, R] = {
+ if (partitioner != null) {
+ val typeInfo : TypeInformation[K] = implicitly[TypeInformation[K]]
+
+ leftKeys.validateCustomPartitioner(partitioner, typeInfo)
+ rightKeys.validateCustomPartitioner(partitioner, typeInfo)
+ }
+ this.customPartitioner = partitioner
+ defaultJoin.withPartitioner(partitioner)
+
+ this
+ }
+
+ /**
+ * Gets the custom partitioner used by this join, or null, if none is set.
+ */
+ def getPartitioner[K]() : Partitioner[K] = {
+ customPartitioner.asInstanceOf[Partitioner[K]]
}
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-tests/src/test/java/org/apache/flink/test/cancelling/CancellingTestBase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/cancelling/CancellingTestBase.java b/flink-tests/src/test/java/org/apache/flink/test/cancelling/CancellingTestBase.java
index 5859a4a..bc40df6 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/cancelling/CancellingTestBase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/cancelling/CancellingTestBase.java
@@ -201,6 +201,8 @@ public abstract class CancellingTestBase {
case FAILING:
case CREATED:
break;
+ case RESTARTING:
+ throw new IllegalStateException("Job restarted");
}
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-tests/src/test/java/org/apache/flink/test/iterative/StaticlyNestedIterationsITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/iterative/StaticlyNestedIterationsITCase.java b/flink-tests/src/test/java/org/apache/flink/test/iterative/StaticlyNestedIterationsITCase.java
index 19fc936..975e4aa 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/iterative/StaticlyNestedIterationsITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/iterative/StaticlyNestedIterationsITCase.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.io.DiscardingOuputFormat;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.test.util.JavaProgramTestBase;
@@ -49,7 +50,7 @@ public class StaticlyNestedIterationsITCase extends JavaProgramTestBase {
DataSet<Long> mainResult = mainIteration.closeWith(joined);
- mainResult.print();
+ mainResult.output(new DiscardingOuputFormat<Long>());
env.execute();
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-tests/src/test/java/org/apache/flink/test/iterative/nephele/IterationWithChainingNepheleITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/iterative/nephele/IterationWithChainingNepheleITCase.java b/flink-tests/src/test/java/org/apache/flink/test/iterative/nephele/IterationWithChainingNepheleITCase.java
index 3a7cdb7..9dcdf75 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/iterative/nephele/IterationWithChainingNepheleITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/iterative/nephele/IterationWithChainingNepheleITCase.java
@@ -68,6 +68,7 @@ import org.junit.runners.Parameterized;
*
* {@link IterationWithChainingITCase}
*/
+@SuppressWarnings("deprecation")
@RunWith(Parameterized.class)
public class IterationWithChainingNepheleITCase extends RecordAPITestBase {
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/AggregateTranslationTest.scala
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/AggregateTranslationTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/AggregateTranslationTest.scala
index c4d7dc8..c9b1a3a 100644
--- a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/AggregateTranslationTest.scala
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/AggregateTranslationTest.scala
@@ -15,6 +15,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package org.apache.flink.api.scala.operators.translation
import org.apache.flink.api.common.Plan
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/CustomPartitioningGroupingKeySelectorTest.scala
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/CustomPartitioningGroupingKeySelectorTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/CustomPartitioningGroupingKeySelectorTest.scala
new file mode 100644
index 0000000..17ecc3f
--- /dev/null
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/CustomPartitioningGroupingKeySelectorTest.scala
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.scala.operators.translation
+
+import org.junit.Assert._
+import org.junit.Test
+import org.apache.flink.api.scala._
+import org.apache.flink.api.common.functions.Partitioner
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType
+import org.apache.flink.compiler.plan.SingleInputPlanNode
+import org.apache.flink.test.compiler.util.CompilerTestBase
+import scala.collection.immutable.Seq
+import org.apache.flink.api.common.operators.Order
+import org.apache.flink.api.common.InvalidProgramException
+
+class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase {
+
+ @Test
+ def testCustomPartitioningKeySelectorReduce() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements( (0,0) ).rebalance().setParallelism(4)
+
+ data
+ .groupBy( _._1 ).withPartitioner(new TestPartitionerInt())
+ .reduce( (a,b) => a )
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val keyRemovingMapper = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val reducer = keyRemovingMapper.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val combiner = reducer.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.FORWARD, keyRemovingMapper.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput.getShipStrategy)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningKeySelectorGroupReduce() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements( (0,0) ).rebalance().setParallelism(4)
+
+ data
+ .groupBy( _._1 ).withPartitioner(new TestPartitionerInt())
+ .reduceGroup( iter => Seq(iter.next()) )
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val combiner = reducer.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput.getShipStrategy)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningKeySelectorGroupReduceSorted() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements( (0,0,0) ).rebalance().setParallelism(4)
+
+ data
+ .groupBy( _._1 )
+ .withPartitioner(new TestPartitionerInt())
+ .sortGroup(1, Order.ASCENDING)
+ .reduceGroup( iter => Seq(iter.next()) )
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val combiner = reducer.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput.getShipStrategy)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningKeySelectorGroupReduceSorted2() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements( (0,0,0,0) ).rebalance().setParallelism(4)
+
+ data
+ .groupBy( _._1 ).withPartitioner(new TestPartitionerInt())
+ .sortGroup(1, Order.ASCENDING)
+ .sortGroup(2, Order.DESCENDING)
+ .reduceGroup( iter => Seq(iter.next()) )
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val combiner = reducer.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput.getShipStrategy)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningKeySelectorInvalidType() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements( (0, 0) ).rebalance().setParallelism(4)
+
+ try {
+ data
+ .groupBy( _._1 )
+ .withPartitioner(new TestPartitionerLong())
+ fail("Should throw an exception")
+ }
+ catch {
+ case e: InvalidProgramException =>
+ }
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningKeySelectorInvalidTypeSorted() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements( (0, 0, 0) ).rebalance().setParallelism(4)
+
+ try {
+ data
+ .groupBy( _._1 )
+ .sortGroup(1, Order.ASCENDING)
+ .withPartitioner(new TestPartitionerLong())
+ fail("Should throw an exception")
+ }
+ catch {
+ case e: InvalidProgramException =>
+ }
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningTupleRejectCompositeKey() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements( (0, 0, 0) ).rebalance().setParallelism(4)
+
+ try {
+ data.groupBy( v => (v._1, v._2) ).withPartitioner(new TestPartitionerInt())
+ fail("Should throw an exception")
+ }
+ catch {
+ case e: InvalidProgramException =>
+ }
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------------
+
+ private class TestPartitionerInt extends Partitioner[Int] {
+
+ override def partition(key: Int, numPartitions: Int): Int = 0
+ }
+
+ private class TestPartitionerLong extends Partitioner[Long] {
+
+ override def partition(key: Long, numPartitions: Int): Int = 0
+ }
+}