You are viewing a plain-text version of this content. The canonical link to it (a hyperlink in the original page) was not preserved in this text.
Posted to commits@flink.apache.org by se...@apache.org on 2014/06/13 14:44:56 UTC
[4/5] git commit: Aggregators and convergence criteria as objects
instead of classes
Aggregators and convergence criteria as objects instead of classes
Project: http://git-wip-us.apache.org/repos/asf/incubator-flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-flink/commit/08f189ad
Tree: http://git-wip-us.apache.org/repos/asf/incubator-flink/tree/08f189ad
Diff: http://git-wip-us.apache.org/repos/asf/incubator-flink/diff/08f189ad
Branch: refs/heads/master
Commit: 08f189ad37a64b094ba86def8544687419c131ce
Parents: f82c61f
Author: Stephan Ewen <se...@apache.org>
Authored: Fri Jun 13 13:14:42 2014 +0200
Committer: Stephan Ewen <se...@apache.org>
Committed: Fri Jun 13 14:28:04 2014 +0200
----------------------------------------------------------------------
.../spargel/java/VertexCentricIteration.java | 10 +-
.../spargel/java/SpargelTranslationTest.java | 4 +-
.../plantranslate/NepheleJobGraphGenerator.java | 10 +-
.../api/common/aggregators/Aggregator.java | 4 +-
.../common/aggregators/AggregatorRegistry.java | 21 +-
.../common/aggregators/AggregatorWithName.java | 6 +-
.../aggregators/ConvergenceCriterion.java | 4 +-
.../common/aggregators/DoubleSumAggregator.java | 1 +
.../aggregators/DoubleZeroConvergence.java | 1 +
.../common/aggregators/LongSumAggregator.java | 1 +
.../common/aggregators/LongZeroConvergence.java | 1 +
.../operators/base/BulkIterationBase.java | 4 +-
.../stratosphere/api/java/DeltaIteration.java | 2 +-
.../stratosphere/api/java/IterativeDataSet.java | 4 +-
.../DeltaIterationTranslationTest.java | 2 +-
.../task/IterationSynchronizationSinkTask.java | 11 +-
.../task/RuntimeAggregatorRegistry.java | 4 +-
.../pact/runtime/task/util/TaskConfig.java | 59 ++-
.../api/scala/operators/IterateOperators.scala | 2 +-
.../aggregators/AggregatorsITCase.java | 466 +++++++++++++++++++
...nentsWithParametrizableAggregatorITCase.java | 237 ++++++++++
...entsWithParametrizableConvergenceITCase.java | 223 +++++++++
.../ConnectedComponentsNepheleITCase.java | 6 +-
.../CustomCompensatableDanglingPageRank.java | 6 +-
...mpensatableDanglingPageRankWithCombiner.java | 8 +-
.../CompensatableDanglingPageRank.java | 6 +-
.../DiffL1NormConvergenceCriterion.java | 1 +
.../PageRankStatsAggregator.java | 1 +
.../test/recordJobs/graph/DanglingPageRank.java | 3 +-
.../DiffL1NormConvergenceCriterion.java | 1 +
.../pageRankUtil/PageRankStatsAggregator.java | 1 +
31 files changed, 1037 insertions(+), 73 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-addons/spargel/src/main/java/eu/stratosphere/spargel/java/VertexCentricIteration.java
----------------------------------------------------------------------
diff --git a/stratosphere-addons/spargel/src/main/java/eu/stratosphere/spargel/java/VertexCentricIteration.java b/stratosphere-addons/spargel/src/main/java/eu/stratosphere/spargel/java/VertexCentricIteration.java
index 98ce446..adbcadf 100644
--- a/stratosphere-addons/spargel/src/main/java/eu/stratosphere/spargel/java/VertexCentricIteration.java
+++ b/stratosphere-addons/spargel/src/main/java/eu/stratosphere/spargel/java/VertexCentricIteration.java
@@ -76,7 +76,7 @@ public class VertexCentricIteration<VertexKey extends Comparable<VertexKey>, Ver
private final DataSet<Tuple3<VertexKey, VertexKey, EdgeValue>> edgesWithValue;
- private final Map<String, Class<? extends Aggregator<?>>> aggregators;
+ private final Map<String, Aggregator<?>> aggregators;
private final int maximumNumberOfIterations;
@@ -118,7 +118,7 @@ public class VertexCentricIteration<VertexKey extends Comparable<VertexKey>, Ver
this.edgesWithoutValue = edgesWithoutValue;
this.edgesWithValue = null;
this.maximumNumberOfIterations = maximumNumberOfIterations;
- this.aggregators = new HashMap<String, Class<? extends Aggregator<?>>>();
+ this.aggregators = new HashMap<String, Aggregator<?>>();
this.messageType = getMessageType(mf);
}
@@ -150,7 +150,7 @@ public class VertexCentricIteration<VertexKey extends Comparable<VertexKey>, Ver
this.edgesWithoutValue = null;
this.edgesWithValue = edgesWithValue;
this.maximumNumberOfIterations = maximumNumberOfIterations;
- this.aggregators = new HashMap<String, Class<? extends Aggregator<?>>>();
+ this.aggregators = new HashMap<String, Aggregator<?>>();
this.messageType = getMessageType(mf);
}
@@ -167,7 +167,7 @@ public class VertexCentricIteration<VertexKey extends Comparable<VertexKey>, Ver
* @param name The name of the aggregator, used to retrieve it and its aggregates during execution.
* @param aggregator The aggregator.
*/
- public void registerAggregator(String name, Class<? extends Aggregator<?>> aggregator) {
+ public void registerAggregator(String name, Aggregator<?> aggregator) {
this.aggregators.put(name, aggregator);
}
@@ -286,7 +286,7 @@ public class VertexCentricIteration<VertexKey extends Comparable<VertexKey>, Ver
iteration.parallelism(parallelism);
// register all aggregators
- for (Map.Entry<String, Class<? extends Aggregator<?>>> entry : this.aggregators.entrySet()) {
+ for (Map.Entry<String, Aggregator<?>> entry : this.aggregators.entrySet()) {
iteration.registerAggregator(entry.getKey(), entry.getValue());
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-addons/spargel/src/test/java/eu/stratosphere/spargel/java/SpargelTranslationTest.java
----------------------------------------------------------------------
diff --git a/stratosphere-addons/spargel/src/test/java/eu/stratosphere/spargel/java/SpargelTranslationTest.java b/stratosphere-addons/spargel/src/test/java/eu/stratosphere/spargel/java/SpargelTranslationTest.java
index 3ae5c3f..deedca5 100644
--- a/stratosphere-addons/spargel/src/test/java/eu/stratosphere/spargel/java/SpargelTranslationTest.java
+++ b/stratosphere-addons/spargel/src/test/java/eu/stratosphere/spargel/java/SpargelTranslationTest.java
@@ -75,7 +75,7 @@ public class SpargelTranslationTest {
vertexIteration.setName(ITERATION_NAME);
vertexIteration.setParallelism(ITERATION_DOP);
- vertexIteration.registerAggregator(AGGREGATOR_NAME, LongSumAggregator.class);
+ vertexIteration.registerAggregator(AGGREGATOR_NAME, new LongSumAggregator());
result = initialVertices.runOperation(vertexIteration);
}
@@ -154,7 +154,7 @@ public class SpargelTranslationTest {
vertexIteration.setName(ITERATION_NAME);
vertexIteration.setParallelism(ITERATION_DOP);
- vertexIteration.registerAggregator(AGGREGATOR_NAME, LongSumAggregator.class);
+ vertexIteration.registerAggregator(AGGREGATOR_NAME, new LongSumAggregator());
result = initialVertices.runOperation(vertexIteration);
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plantranslate/NepheleJobGraphGenerator.java
----------------------------------------------------------------------
diff --git a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plantranslate/NepheleJobGraphGenerator.java b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plantranslate/NepheleJobGraphGenerator.java
index fdda1bf..4817095 100644
--- a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plantranslate/NepheleJobGraphGenerator.java
+++ b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plantranslate/NepheleJobGraphGenerator.java
@@ -1285,7 +1285,7 @@ public class NepheleJobGraphGenerator implements Visitor<PlanNode> {
syncConfig.addIterationAggregators(allAggregators);
String convAggName = aggs.getConvergenceCriterionAggregatorName();
- Class<? extends ConvergenceCriterion<?>> convCriterion = aggs.getConvergenceCriterion();
+ ConvergenceCriterion<?> convCriterion = aggs.getConvergenceCriterion();
if (convCriterion != null || convAggName != null) {
if (convCriterion == null) {
@@ -1474,15 +1474,15 @@ public class NepheleJobGraphGenerator implements Visitor<PlanNode> {
syncConfig.addIterationAggregators(allAggregators);
String convAggName = aggs.getConvergenceCriterionAggregatorName();
- Class<? extends ConvergenceCriterion<?>> convCriterion = aggs.getConvergenceCriterion();
+ ConvergenceCriterion<?> convCriterion = aggs.getConvergenceCriterion();
if (convCriterion != null || convAggName != null) {
throw new CompilerException("Error: Cannot use custom convergence criterion with workset iteration. Workset iterations have implicit convergence criterion where workset is empty.");
}
- headConfig.addIterationAggregator(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, LongSumAggregator.class);
- syncConfig.addIterationAggregator(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, LongSumAggregator.class);
- syncConfig.setConvergenceCriterion(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, WorksetEmptyConvergenceCriterion.class);
+ headConfig.addIterationAggregator(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new LongSumAggregator());
+ syncConfig.addIterationAggregator(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new LongSumAggregator());
+ syncConfig.setConvergenceCriterion(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new WorksetEmptyConvergenceCriterion());
}
// -------------------------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/Aggregator.java
----------------------------------------------------------------------
diff --git a/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/Aggregator.java b/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/Aggregator.java
index a35849c..bbc305e 100644
--- a/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/Aggregator.java
+++ b/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/Aggregator.java
@@ -13,6 +13,8 @@
package eu.stratosphere.api.common.aggregators;
+import java.io.Serializable;
+
import eu.stratosphere.types.Value;
/**
@@ -65,7 +67,7 @@ import eu.stratosphere.types.Value;
*
* @param <T> The type of the aggregated value.
*/
-public interface Aggregator<T extends Value> {
+public interface Aggregator<T extends Value> extends Serializable {
/**
* Gets the aggregator's current aggregate.
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/AggregatorRegistry.java
----------------------------------------------------------------------
diff --git a/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/AggregatorRegistry.java b/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/AggregatorRegistry.java
index 8f00806..b914a41 100644
--- a/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/AggregatorRegistry.java
+++ b/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/AggregatorRegistry.java
@@ -24,15 +24,15 @@ import eu.stratosphere.types.Value;
*/
public class AggregatorRegistry {
- private final Map<String, Class<? extends Aggregator<?>>> registry = new HashMap<String, Class<? extends Aggregator<?>>>();
+ private final Map<String, Aggregator<?>> registry = new HashMap<String, Aggregator<?>>();
- private Class<? extends ConvergenceCriterion<? extends Value>> convergenceCriterion;
+ private ConvergenceCriterion<? extends Value> convergenceCriterion;
private String convergenceCriterionAggregatorName;
// --------------------------------------------------------------------------------------------
- public void registerAggregator(String name, Class<? extends Aggregator<?>> aggregator) {
+ public void registerAggregator(String name, Aggregator<?> aggregator) {
if (name == null || aggregator == null) {
throw new IllegalArgumentException("Name and aggregator must not be null");
}
@@ -42,32 +42,31 @@ public class AggregatorRegistry {
this.registry.put(name, aggregator);
}
- public Class<? extends Aggregator<?>> unregisterAggregator(String name) {
+ public Aggregator<?> unregisterAggregator(String name) {
return this.registry.remove(name);
}
public Collection<AggregatorWithName<?>> getAllRegisteredAggregators() {
ArrayList<AggregatorWithName<?>> list = new ArrayList<AggregatorWithName<?>>(this.registry.size());
- for (Map.Entry<String, Class<? extends Aggregator<?>>> entry : this.registry.entrySet()) {
+ for (Map.Entry<String, Aggregator<?>> entry : this.registry.entrySet()) {
@SuppressWarnings("unchecked")
- Class<Aggregator<Value>> valAgg = (Class<Aggregator<Value>>) (Class<?>) entry.getValue();
+ Aggregator<Value> valAgg = (Aggregator<Value>) entry.getValue();
list.add(new AggregatorWithName<Value>(entry.getKey(), valAgg));
}
return list;
}
public <T extends Value> void registerAggregationConvergenceCriterion(
- String name, Class<? extends Aggregator<T>> aggregator, Class<? extends ConvergenceCriterion<T>> convergenceCheck)
+ String name, Aggregator<T> aggregator, ConvergenceCriterion<T> convergenceCheck)
{
if (name == null || aggregator == null || convergenceCheck == null) {
throw new IllegalArgumentException("Name, aggregator, or convergence criterion must not be null");
}
- @SuppressWarnings("unchecked")
- Class<Aggregator<?>> genAgg = (Class<Aggregator<?>>) (Class<?>) aggregator;
+ Aggregator<?> genAgg = (Aggregator<?>) aggregator;
- Class<? extends Aggregator<?>> previous = this.registry.get(name);
+ Aggregator<?> previous = this.registry.get(name);
if (previous != null && previous != genAgg) {
throw new RuntimeException("An aggregator is already registered under the given name.");
}
@@ -81,7 +80,7 @@ public class AggregatorRegistry {
return this.convergenceCriterionAggregatorName;
}
- public Class<? extends ConvergenceCriterion<?>> getConvergenceCriterion() {
+ public ConvergenceCriterion<?> getConvergenceCriterion() {
return this.convergenceCriterion;
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/AggregatorWithName.java
----------------------------------------------------------------------
diff --git a/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/AggregatorWithName.java b/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/AggregatorWithName.java
index 75a4e50..51585c2 100644
--- a/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/AggregatorWithName.java
+++ b/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/AggregatorWithName.java
@@ -21,7 +21,7 @@ public class AggregatorWithName<T extends Value> {
private final String name;
- private final Class<? extends Aggregator<T>> aggregator;
+ private final Aggregator<T> aggregator;
/**
* Creates a new instance for the given aggregator and name.
@@ -29,7 +29,7 @@ public class AggregatorWithName<T extends Value> {
* @param name The name that the aggregator is registered under.
* @param aggregator The aggregator.
*/
- public AggregatorWithName(String name, Class<Aggregator<T>> aggregator) {
+ public AggregatorWithName(String name, Aggregator<T> aggregator) {
this.name = name;
this.aggregator = aggregator;
}
@@ -48,7 +48,7 @@ public class AggregatorWithName<T extends Value> {
*
* @return The aggregator.
*/
- public Class<? extends Aggregator<T>> getAggregator() {
+ public Aggregator<T> getAggregator() {
return aggregator;
}
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/ConvergenceCriterion.java
----------------------------------------------------------------------
diff --git a/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/ConvergenceCriterion.java b/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/ConvergenceCriterion.java
index eb3ef25..804e0ad 100644
--- a/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/ConvergenceCriterion.java
+++ b/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/ConvergenceCriterion.java
@@ -13,12 +13,14 @@
package eu.stratosphere.api.common.aggregators;
+import java.io.Serializable;
+
import eu.stratosphere.types.Value;
/**
* Used to check for convergence.
*/
-public interface ConvergenceCriterion<T extends Value> {
+public interface ConvergenceCriterion<T extends Value> extends Serializable {
/**
* Decide whether the iterative algorithm has converged
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/DoubleSumAggregator.java
----------------------------------------------------------------------
diff --git a/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/DoubleSumAggregator.java b/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/DoubleSumAggregator.java
index b2fe9b0..0156cf4 100644
--- a/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/DoubleSumAggregator.java
+++ b/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/DoubleSumAggregator.java
@@ -18,6 +18,7 @@ import eu.stratosphere.types.DoubleValue;
/**
* An {@link Aggregator} that sums up {@link DoubleValue} values.
*/
+@SuppressWarnings("serial")
public class DoubleSumAggregator implements Aggregator<DoubleValue> {
private DoubleValue wrapper = new DoubleValue();
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/DoubleZeroConvergence.java
----------------------------------------------------------------------
diff --git a/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/DoubleZeroConvergence.java b/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/DoubleZeroConvergence.java
index 78f4b69..240dd59 100644
--- a/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/DoubleZeroConvergence.java
+++ b/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/DoubleZeroConvergence.java
@@ -20,6 +20,7 @@ import eu.stratosphere.types.DoubleValue;
* A {@link ConvergenceCriterion} over an {@link Aggregator} that defines convergence as reached once the aggregator
* holds the value zero. The aggregated data type is a {@link DoubleValue}.
*/
+@SuppressWarnings("serial")
public class DoubleZeroConvergence implements ConvergenceCriterion<DoubleValue> {
/**
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/LongSumAggregator.java
----------------------------------------------------------------------
diff --git a/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/LongSumAggregator.java b/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/LongSumAggregator.java
index 9cbd487..d2d2055 100644
--- a/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/LongSumAggregator.java
+++ b/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/LongSumAggregator.java
@@ -17,6 +17,7 @@ import eu.stratosphere.types.LongValue;
/**
* An {@link Aggregator} that sums up long values.
*/
+@SuppressWarnings("serial")
public class LongSumAggregator implements Aggregator<LongValue> {
private long sum; // the sum
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/LongZeroConvergence.java
----------------------------------------------------------------------
diff --git a/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/LongZeroConvergence.java b/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/LongZeroConvergence.java
index fde6ab7..7cbe883 100644
--- a/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/LongZeroConvergence.java
+++ b/stratosphere-core/src/main/java/eu/stratosphere/api/common/aggregators/LongZeroConvergence.java
@@ -20,6 +20,7 @@ import eu.stratosphere.types.LongValue;
* A {@link ConvergenceCriterion} over an {@link Aggregator} that defines convergence as reached once the aggregator
* holds the value zero. The aggregated data type is a {@link LongValue}.
*/
+@SuppressWarnings("serial")
public class LongZeroConvergence implements ConvergenceCriterion<LongValue> {
/**
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-core/src/main/java/eu/stratosphere/api/common/operators/base/BulkIterationBase.java
----------------------------------------------------------------------
diff --git a/stratosphere-core/src/main/java/eu/stratosphere/api/common/operators/base/BulkIterationBase.java b/stratosphere-core/src/main/java/eu/stratosphere/api/common/operators/base/BulkIterationBase.java
index 47842a0..6de9959 100644
--- a/stratosphere-core/src/main/java/eu/stratosphere/api/common/operators/base/BulkIterationBase.java
+++ b/stratosphere-core/src/main/java/eu/stratosphere/api/common/operators/base/BulkIterationBase.java
@@ -122,7 +122,7 @@ public class BulkIterationBase<T> extends SingleInputOperator<T, T, AbstractFunc
mapper.setInput(criterion);
this.terminationCriterion = mapper;
- this.getAggregators().registerAggregationConvergenceCriterion(TERMINATION_CRITERION_AGGREGATOR_NAME, TerminationCriterionAggregator.class, TerminationCriterionAggregationConvergence.class);
+ this.getAggregators().registerAggregationConvergenceCriterion(TERMINATION_CRITERION_AGGREGATOR_NAME, new TerminationCriterionAggregator(), new TerminationCriterionAggregationConvergence());
}
/**
@@ -244,6 +244,7 @@ public class BulkIterationBase<T> extends SingleInputOperator<T, T, AbstractFunc
/**
* Aggregator that basically only adds 1 for every output tuple of the termination criterion branch
*/
+ @SuppressWarnings("serial")
public static class TerminationCriterionAggregator implements Aggregator<LongValue> {
private long count;
@@ -271,6 +272,7 @@ public class BulkIterationBase<T> extends SingleInputOperator<T, T, AbstractFunc
/**
* Convergence for the termination criterion is reached if no tuple is output at current iteration for the termination criterion branch
*/
+ @SuppressWarnings("serial")
public static class TerminationCriterionAggregationConvergence implements ConvergenceCriterion<LongValue> {
private static final Log log = LogFactory.getLog(TerminationCriterionAggregationConvergence.class);
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-java/src/main/java/eu/stratosphere/api/java/DeltaIteration.java
----------------------------------------------------------------------
diff --git a/stratosphere-java/src/main/java/eu/stratosphere/api/java/DeltaIteration.java b/stratosphere-java/src/main/java/eu/stratosphere/api/java/DeltaIteration.java
index 17dfad3..7fa6638 100644
--- a/stratosphere-java/src/main/java/eu/stratosphere/api/java/DeltaIteration.java
+++ b/stratosphere-java/src/main/java/eu/stratosphere/api/java/DeltaIteration.java
@@ -192,7 +192,7 @@ public class DeltaIteration<ST, WT> {
*
* @return The DeltaIteration itself, to allow chaining function calls.
*/
- public DeltaIteration<ST, WT> registerAggregator(String name, Class<? extends Aggregator<?>> aggregator) {
+ public DeltaIteration<ST, WT> registerAggregator(String name, Aggregator<?> aggregator) {
this.aggregators.registerAggregator(name, aggregator);
return this;
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-java/src/main/java/eu/stratosphere/api/java/IterativeDataSet.java
----------------------------------------------------------------------
diff --git a/stratosphere-java/src/main/java/eu/stratosphere/api/java/IterativeDataSet.java b/stratosphere-java/src/main/java/eu/stratosphere/api/java/IterativeDataSet.java
index 4e24d98..bab990d 100644
--- a/stratosphere-java/src/main/java/eu/stratosphere/api/java/IterativeDataSet.java
+++ b/stratosphere-java/src/main/java/eu/stratosphere/api/java/IterativeDataSet.java
@@ -95,7 +95,7 @@ public class IterativeDataSet<T> extends SingleInputOperator<T, T, IterativeData
*
* @return The IterativeDataSet itself, to allow chaining function calls.
*/
- public IterativeDataSet<T> registerAggregator(String name, Class<? extends Aggregator<?>> aggregator) {
+ public IterativeDataSet<T> registerAggregator(String name, Aggregator<?> aggregator) {
this.aggregators.registerAggregator(name, aggregator);
return this;
}
@@ -115,7 +115,7 @@ public class IterativeDataSet<T> extends SingleInputOperator<T, T, IterativeData
* @return The IterativeDataSet itself, to allow chaining function calls.
*/
public <X extends Value> IterativeDataSet<T> registerAggregationConvergenceCriterion(
- String name, Class<? extends Aggregator<X>> aggregator, Class<? extends ConvergenceCriterion<X>> convergenceCheck)
+ String name, Aggregator<X> aggregator, ConvergenceCriterion<X> convergenceCheck)
{
this.aggregators.registerAggregationConvergenceCriterion(name, aggregator, convergenceCheck);
return this;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-java/src/test/java/eu/stratosphere/api/java/operators/translation/DeltaIterationTranslationTest.java
----------------------------------------------------------------------
diff --git a/stratosphere-java/src/test/java/eu/stratosphere/api/java/operators/translation/DeltaIterationTranslationTest.java b/stratosphere-java/src/test/java/eu/stratosphere/api/java/operators/translation/DeltaIterationTranslationTest.java
index 01faa5e..ad170a2 100644
--- a/stratosphere-java/src/test/java/eu/stratosphere/api/java/operators/translation/DeltaIterationTranslationTest.java
+++ b/stratosphere-java/src/test/java/eu/stratosphere/api/java/operators/translation/DeltaIterationTranslationTest.java
@@ -72,7 +72,7 @@ public class DeltaIterationTranslationTest implements java.io.Serializable {
DeltaIteration<Tuple3<Double, Long, String>, Tuple2<Double, String>> iteration = initialSolutionSet.iterateDelta(initialWorkSet, NUM_ITERATIONS, ITERATION_KEYS);
iteration.name(ITERATION_NAME).parallelism(ITERATION_DOP);
- iteration.registerAggregator(AGGREGATOR_NAME, LongSumAggregator.class);
+ iteration.registerAggregator(AGGREGATOR_NAME, new LongSumAggregator());
// test that multiple workset consumers are supported
DataSet<Tuple2<Double, String>> worksetSelfJoin =
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/iterative/task/IterationSynchronizationSinkTask.java
----------------------------------------------------------------------
diff --git a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/iterative/task/IterationSynchronizationSinkTask.java b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/iterative/task/IterationSynchronizationSinkTask.java
index 94f9b9f..4e7286b 100644
--- a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/iterative/task/IterationSynchronizationSinkTask.java
+++ b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/iterative/task/IterationSynchronizationSinkTask.java
@@ -37,7 +37,6 @@ import eu.stratosphere.pact.runtime.iterative.event.WorkerDoneEvent;
import eu.stratosphere.pact.runtime.task.RegularPactTask;
import eu.stratosphere.pact.runtime.task.util.TaskConfig;
import eu.stratosphere.types.Value;
-import eu.stratosphere.util.InstantiationUtil;
/**
* The task responsible for synchronizing all iteration heads, implemented as an {@link AbstractOutputTask}. This task
@@ -80,17 +79,15 @@ public class IterationSynchronizationSinkTask extends AbstractOutputTask impleme
userCodeClassLoader = LibraryCacheManager.getClassLoader(getEnvironment().getJobID());
TaskConfig taskConfig = new TaskConfig(getTaskConfiguration());
- // instantiate all aggregators
+ // store all aggregators
this.aggregators = new HashMap<String, Aggregator<?>>();
for (AggregatorWithName<?> aggWithName : taskConfig.getIterationAggregators()) {
- Aggregator<?> agg = InstantiationUtil.instantiate(aggWithName.getAggregator(), Aggregator.class);
- aggregators.put(aggWithName.getName(), agg);
+ aggregators.put(aggWithName.getName(), aggWithName.getAggregator());
}
- // instantiate the aggregator convergence criterion
+ // store the aggregator convergence criterion
if (taskConfig.usesConvergenceCriterion()) {
- Class<? extends ConvergenceCriterion<Value>> convClass = taskConfig.getConvergenceCriterion();
- convergenceCriterion = InstantiationUtil.instantiate(convClass, ConvergenceCriterion.class);
+ convergenceCriterion = taskConfig.getConvergenceCriterion();
convergenceAggregatorName = taskConfig.getConvergenceCriterionAggregatorName();
Preconditions.checkNotNull(convergenceAggregatorName);
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/iterative/task/RuntimeAggregatorRegistry.java
----------------------------------------------------------------------
diff --git a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/iterative/task/RuntimeAggregatorRegistry.java b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/iterative/task/RuntimeAggregatorRegistry.java
index ec440bb..375c8ed 100644
--- a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/iterative/task/RuntimeAggregatorRegistry.java
+++ b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/iterative/task/RuntimeAggregatorRegistry.java
@@ -19,7 +19,6 @@ import java.util.Map;
import eu.stratosphere.api.common.aggregators.Aggregator;
import eu.stratosphere.api.common.aggregators.AggregatorWithName;
import eu.stratosphere.types.Value;
-import eu.stratosphere.util.InstantiationUtil;
/**
@@ -36,8 +35,7 @@ public class RuntimeAggregatorRegistry {
this.previousGlobalAggregate = new HashMap<String, Value>();
for (AggregatorWithName<?> agg : aggs) {
- Aggregator<?> aggregator = InstantiationUtil.instantiate(agg.getAggregator(), Aggregator.class);
- this.aggregators.put(agg.getName(), aggregator);
+ this.aggregators.put(agg.getName(), agg.getAggregator());
}
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/util/TaskConfig.java
----------------------------------------------------------------------
diff --git a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/util/TaskConfig.java b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/util/TaskConfig.java
index e10d534..947e22a 100644
--- a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/util/TaskConfig.java
+++ b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/util/TaskConfig.java
@@ -866,10 +866,14 @@ public class TaskConfig {
ITERATION_SOLUTION_SET_COMPARATOR_PARAMETERS, cl);
}
- public void addIterationAggregator(String name, Class<? extends Aggregator<?>> aggregator) {
+ public void addIterationAggregator(String name, Aggregator<?> aggregator) {
int num = this.config.getInteger(ITERATION_NUM_AGGREGATORS, 0);
this.config.setString(ITERATION_AGGREGATOR_NAME_PREFIX + num, name);
- this.config.setClass(ITERATION_AGGREGATOR_PREFIX + num, aggregator);
+ try {
+ InstantiationUtil.writeObjectToConfig(aggregator, this.config, ITERATION_AGGREGATOR_PREFIX + num);
+ } catch (IOException e) {
+ throw new RuntimeException("Error while writing the aggregator object to the task configuration.");
+ }
this.config.setInteger(ITERATION_NUM_AGGREGATORS, num + 1);
}
@@ -877,12 +881,17 @@ public class TaskConfig {
int num = this.config.getInteger(ITERATION_NUM_AGGREGATORS, 0);
for (AggregatorWithName<?> awn : aggregators) {
this.config.setString(ITERATION_AGGREGATOR_NAME_PREFIX + num, awn.getName());
- this.config.setClass(ITERATION_AGGREGATOR_PREFIX + num, awn.getAggregator());
+ // serialize the aggregator object itself, so that state set on the instance reaches the tasks
+ try {
+ InstantiationUtil.writeObjectToConfig(awn.getAggregator(), this.config, ITERATION_AGGREGATOR_PREFIX + num);
+ } catch (IOException e) {
+ throw new RuntimeException("Error while writing the aggregator object to the task configuration.", e);
+ }
num++;
}
this.config.setInteger(ITERATION_NUM_AGGREGATORS, num);
}
+ /**
+ * Reads back the aggregators registered in this configuration. Each aggregator object is deserialized
+ * with the configuration's class loader and paired with its registered name, in registration order.
+ *
+ * @return The registered aggregators together with their names.
+ * @throws RuntimeException If an aggregator or name entry is missing, or an aggregator cannot be
+ * deserialized (the underlying exception is attached as cause).
+ */
+ // the cast of the deserialized Object to Aggregator<Value> cannot be checked at runtime
+ @SuppressWarnings("unchecked")
public Collection<AggregatorWithName<?>> getIterationAggregators() {
final int numAggs = this.config.getInteger(ITERATION_NUM_AGGREGATORS, 0);
if (numAggs == 0) {
@@ -891,33 +900,53 @@ public class TaskConfig {
List<AggregatorWithName<?>> list = new ArrayList<AggregatorWithName<?>>(numAggs);
for (int i = 0; i < numAggs; i++) {
- @SuppressWarnings("unchecked")
- Class<Aggregator<Value>> aggClass = (Class<Aggregator<Value>>) (Class<?>) this.config.getClass(ITERATION_AGGREGATOR_PREFIX + i, null);
- if (aggClass == null) {
+ Aggregator<Value> aggObj;
+ try {
+ aggObj = (Aggregator<Value>) InstantiationUtil.readObjectFromConfig(
+ this.config, ITERATION_AGGREGATOR_PREFIX + i, getConfiguration().getClassLoader());
+ } catch (IOException e) {
+ throw new RuntimeException("Error while reading the aggregator object from the task configuration.", e);
+ } catch (ClassNotFoundException e) {
+ throw new RuntimeException("Error while reading the aggregator object from the task configuration. " +
+ "Aggregator class not found.", e);
+ }
+ if (aggObj == null) {
throw new RuntimeException("Missing config entry for aggregator.");
}
String name = this.config.getString(ITERATION_AGGREGATOR_NAME_PREFIX + i, null);
if (name == null) {
throw new RuntimeException("Missing config entry for aggregator.");
}
- list.add(new AggregatorWithName<Value>(name, aggClass));
+ list.add(new AggregatorWithName<Value>(name, aggObj));
}
return list;
}
- public void setConvergenceCriterion(String aggregatorName, Class<? extends ConvergenceCriterion<?>> convergenceCriterionClass) {
- this.config.setClass(ITERATION_CONVERGENCE_CRITERION, convergenceCriterionClass);
+ /**
+ * Serializes the convergence criterion object into the task configuration and stores the name of the
+ * aggregator it is associated with.
+ *
+ * @param aggregatorName The name of the aggregator associated with the criterion.
+ * @param convCriterion The convergence criterion instance to serialize.
+ * @throws RuntimeException If the criterion cannot be serialized; the underlying IOException is the cause.
+ */
+ public void setConvergenceCriterion(String aggregatorName, ConvergenceCriterion<?> convCriterion) {
+ try {
+ InstantiationUtil.writeObjectToConfig(convCriterion, this.config, ITERATION_CONVERGENCE_CRITERION);
+ } catch (IOException e) {
+ throw new RuntimeException("Error while writing the convergence criterion object to the task configuration.", e);
+ }
this.config.setString(ITERATION_CONVERGENCE_CRITERION_AGG_NAME, aggregatorName);
}
- public <T extends Value> Class<? extends ConvergenceCriterion<T>> getConvergenceCriterion() {
- @SuppressWarnings("unchecked")
- Class<? extends ConvergenceCriterion<T>> clazz = (Class<? extends ConvergenceCriterion<T>>) (Class<?>)
- this.config.getClass(ITERATION_CONVERGENCE_CRITERION, null, ConvergenceCriterion.class);
- if (clazz == null) {
+ /**
+ * Deserializes the convergence criterion object stored in this configuration, using the
+ * configuration's class loader.
+ *
+ * @return The deserialized convergence criterion.
+ * @throws NullPointerException If no convergence criterion entry is present.
+ * @throws RuntimeException If the criterion cannot be deserialized (the underlying exception is the cause).
+ */
+ // the cast of the deserialized Object to ConvergenceCriterion<T> cannot be checked at runtime
+ @SuppressWarnings("unchecked")
+ public <T extends Value> ConvergenceCriterion<T> getConvergenceCriterion() {
+ ConvergenceCriterion<T> convCriterionObj;
+ try {
+ convCriterionObj = (ConvergenceCriterion<T>) InstantiationUtil.readObjectFromConfig(
+ this.config, ITERATION_CONVERGENCE_CRITERION, getConfiguration().getClassLoader());
+ } catch (IOException e) {
+ throw new RuntimeException("Error while reading the convergence criterion object from the task configuration.", e);
+ } catch (ClassNotFoundException e) {
+ throw new RuntimeException("Error while reading the convergence criterion object from the task configuration. " +
+ "ConvergenceCriterion class not found.", e);
+ }
+ if (convCriterionObj == null) {
- throw new NullPointerException();
+ throw new NullPointerException("Missing config entry for the convergence criterion.");
}
- return clazz;
+ return convCriterionObj;
}
public boolean usesConvergenceCriterion() {
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-scala/src/main/scala/eu/stratosphere/api/scala/operators/IterateOperators.scala
----------------------------------------------------------------------
diff --git a/stratosphere-scala/src/main/scala/eu/stratosphere/api/scala/operators/IterateOperators.scala b/stratosphere-scala/src/main/scala/eu/stratosphere/api/scala/operators/IterateOperators.scala
index 7247817..9a62b11 100644
--- a/stratosphere-scala/src/main/scala/eu/stratosphere/api/scala/operators/IterateOperators.scala
+++ b/stratosphere-scala/src/main/scala/eu/stratosphere/api/scala/operators/IterateOperators.scala
@@ -223,4 +223,4 @@ object WorksetIterateMacros {
return result
}
-}
\ No newline at end of file
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/aggregators/AggregatorsITCase.java
----------------------------------------------------------------------
diff --git a/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/aggregators/AggregatorsITCase.java b/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/aggregators/AggregatorsITCase.java
new file mode 100644
index 0000000..0302e07
--- /dev/null
+++ b/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/aggregators/AggregatorsITCase.java
@@ -0,0 +1,466 @@
+/***********************************************************************************************************************
+ *
+ * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ *
+ **********************************************************************************************************************/
+package eu.stratosphere.test.iterative.aggregators;
+
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.LinkedList;
+import java.util.Random;
+
+import junit.framework.Assert;
+
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+import eu.stratosphere.api.common.aggregators.ConvergenceCriterion;
+import eu.stratosphere.api.common.aggregators.LongSumAggregator;
+import eu.stratosphere.api.java.DataSet;
+import eu.stratosphere.api.java.DeltaIteration;
+import eu.stratosphere.api.java.ExecutionEnvironment;
+import eu.stratosphere.api.java.IterativeDataSet;
+import eu.stratosphere.api.java.functions.FlatMapFunction;
+import eu.stratosphere.api.java.functions.MapFunction;
+import eu.stratosphere.api.java.tuple.Tuple2;
+import eu.stratosphere.configuration.Configuration;
+import eu.stratosphere.test.javaApiOperators.util.CollectionDataSets;
+import eu.stratosphere.test.util.JavaProgramTestBase;
+import eu.stratosphere.types.LongValue;
+import eu.stratosphere.util.Collector;
+
+/**
+ * Tests aggregators and convergence criteria in bulk and delta iterations, including instances that carry
+ * constructor parameters (these rely on the aggregator/criterion objects, not just their classes, being
+ * shipped to the tasks).
+ */
+@RunWith(Parameterized.class)
+public class AggregatorsITCase extends JavaProgramTestBase {
+
+    private static final int NUM_PROGRAMS = 5;
+    private static final int MAX_ITERATIONS = 20;
+    private static final int DOP = 2;
+
+    // Field initializers run after super(config) returns; 'config' is presumably the configuration
+    // stored by the JavaProgramTestBase constructor -- TODO confirm.
+    private int curProgId = config.getInteger("ProgramId", -1);
+    private String resultPath;
+    private String expectedResult;
+
+    public AggregatorsITCase(Configuration config) {
+        super(config);
+    }
+
+    @Override
+    protected void preSubmit() throws Exception {
+        resultPath = getTempDirPath("result");
+    }
+
+    @Override
+    protected void testProgram() throws Exception {
+        expectedResult = AggregatorProgs.runProgram(curProgId, resultPath);
+    }
+
+    @Override
+    protected void postSubmit() throws Exception {
+        compareResultsByLinesInMemory(expectedResult, resultPath);
+    }
+
+    /** Creates one test configuration per program id, 1 to NUM_PROGRAMS. */
+    @Parameters
+    public static Collection<Object[]> getConfigurations() throws FileNotFoundException, IOException {
+        LinkedList<Configuration> tConfigs = new LinkedList<Configuration>();
+
+        for (int i = 1; i <= NUM_PROGRAMS; i++) {
+            Configuration config = new Configuration();
+            config.setInteger("ProgramId", i);
+            tConfigs.add(config);
+        }
+
+        return toParameterList(tConfigs);
+    }
+
+    private static class AggregatorProgs {
+
+        private static final String NEGATIVE_ELEMENTS_AGGR = "count.negative.elements";
+
+        /** Expected output of the bulk iteration programs (1 to 3). */
+        private static final String EXPECTED_BULK_RESULT = "-3\n" + "-2\n" + "-2\n" + "-1\n" + "-1\n"
+                + "-1\n" + "0\n" + "0\n" + "0\n" + "0\n"
+                + "1\n" + "1\n" + "1\n" + "1\n" + "1\n";
+
+        /** Expected output of the delta iteration programs (4 and 5). */
+        private static final String EXPECTED_DELTA_RESULT = "1\n" + "2\n" + "2\n" + "3\n" + "3\n"
+                + "3\n" + "4\n" + "4\n" + "4\n" + "4\n"
+                + "5\n" + "5\n" + "5\n" + "5\n" + "5\n";
+
+        /**
+         * Runs the test program with the given id and returns the result it is expected to write.
+         *
+         * @param progId id of the program to run, 1 to NUM_PROGRAMS
+         * @param resultPath path the program writes its result to
+         * @return the expected content of the result
+         * @throws IllegalArgumentException if progId does not denote a program
+         */
+        public static String runProgram(int progId, String resultPath) throws Exception {
+
+            switch (progId) {
+            case 1: {
+                /*
+                 * Test aggregator without parameter for iterate
+                 */
+                final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+                env.setDegreeOfParallelism(DOP);
+
+                DataSet<Integer> initialSolutionSet = CollectionDataSets.getIntegerDataSet(env);
+                IterativeDataSet<Integer> iteration = initialSolutionSet.iterate(MAX_ITERATIONS);
+
+                // register aggregator
+                LongSumAggregator aggr = new LongSumAggregator();
+                iteration.registerAggregator(NEGATIVE_ELEMENTS_AGGR, aggr);
+
+                // register convergence criterion
+                iteration.registerAggregationConvergenceCriterion(NEGATIVE_ELEMENTS_AGGR, aggr,
+                        new NegativeElementsConvergenceCriterion());
+
+                DataSet<Integer> updatedDs = iteration.map(new SubtractOneMap());
+                iteration.closeWith(updatedDs).writeAsText(resultPath);
+                env.execute();
+
+                return EXPECTED_BULK_RESULT;
+            }
+            case 2: {
+                /*
+                 * Test aggregator with parameter for iterate
+                 */
+                final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+                env.setDegreeOfParallelism(DOP);
+
+                DataSet<Integer> initialSolutionSet = CollectionDataSets.getIntegerDataSet(env);
+                IterativeDataSet<Integer> iteration = initialSolutionSet.iterate(MAX_ITERATIONS);
+
+                // register aggregator; its parameter (0) is read by SubtractOneMapWithParam
+                LongSumAggregatorWithParameter aggr = new LongSumAggregatorWithParameter(0);
+                iteration.registerAggregator(NEGATIVE_ELEMENTS_AGGR, aggr);
+
+                // register convergence criterion
+                iteration.registerAggregationConvergenceCriterion(NEGATIVE_ELEMENTS_AGGR, aggr,
+                        new NegativeElementsConvergenceCriterion());
+
+                DataSet<Integer> updatedDs = iteration.map(new SubtractOneMapWithParam());
+                iteration.closeWith(updatedDs).writeAsText(resultPath);
+                env.execute();
+
+                return EXPECTED_BULK_RESULT;
+            }
+            case 3: {
+                /*
+                 * Test convergence criterion with parameter for iterate
+                 */
+                final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+                env.setDegreeOfParallelism(DOP);
+
+                DataSet<Integer> initialSolutionSet = CollectionDataSets.getIntegerDataSet(env);
+                IterativeDataSet<Integer> iteration = initialSolutionSet.iterate(MAX_ITERATIONS);
+
+                // register aggregator
+                LongSumAggregator aggr = new LongSumAggregator();
+                iteration.registerAggregator(NEGATIVE_ELEMENTS_AGGR, aggr);
+
+                // register convergence criterion
+                iteration.registerAggregationConvergenceCriterion(NEGATIVE_ELEMENTS_AGGR, aggr,
+                        new NegativeElementsConvergenceCriterionWithParam(3));
+
+                DataSet<Integer> updatedDs = iteration.map(new SubtractOneMap());
+                iteration.closeWith(updatedDs).writeAsText(resultPath);
+                env.execute();
+
+                return EXPECTED_BULK_RESULT;
+            }
+            case 4: {
+                /*
+                 * Test aggregator without parameter for iterateDelta
+                 */
+                final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+                env.setDegreeOfParallelism(DOP);
+
+                DataSet<Tuple2<Integer, Integer>> initialSolutionSet = CollectionDataSets.getIntegerDataSet(env).map(new TupleMakerMap());
+
+                DeltaIteration<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> iteration = initialSolutionSet.iterateDelta(
+                        initialSolutionSet, MAX_ITERATIONS, 0);
+
+                // register aggregator
+                LongSumAggregator aggr = new LongSumAggregator();
+                iteration.registerAggregator(NEGATIVE_ELEMENTS_AGGR, aggr);
+
+                DataSet<Tuple2<Integer, Integer>> updatedDs = iteration.getWorkset().map(new AggregateMapDelta());
+
+                DataSet<Tuple2<Integer, Integer>> newElements = updatedDs.join(iteration.getSolutionSet())
+                        .where(0).equalTo(0).flatMap(new UpdateFilter());
+
+                DataSet<Tuple2<Integer, Integer>> iterationRes = iteration.closeWith(newElements, newElements);
+                DataSet<Integer> result = iterationRes.map(new ProjectSecondMapper());
+                result.writeAsText(resultPath);
+
+                env.execute();
+
+                return EXPECTED_DELTA_RESULT;
+            }
+            case 5: {
+                /*
+                 * Test aggregator with parameter for iterateDelta
+                 */
+                final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+                env.setDegreeOfParallelism(DOP);
+
+                DataSet<Tuple2<Integer, Integer>> initialSolutionSet = CollectionDataSets.getIntegerDataSet(env).map(new TupleMakerMap());
+
+                DeltaIteration<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> iteration = initialSolutionSet.iterateDelta(
+                        initialSolutionSet, MAX_ITERATIONS, 0);
+
+                // register aggregator; the expected previous aggregates checked in
+                // AggregateMapDeltaWithParam assume exactly this parameter (4)
+                LongSumAggregatorWithParameter aggr = new LongSumAggregatorWithParameter(4);
+                iteration.registerAggregator(NEGATIVE_ELEMENTS_AGGR, aggr);
+
+                // use the mapper that reads the aggregator parameter, so the parameter is actually exercised
+                DataSet<Tuple2<Integer, Integer>> updatedDs = iteration.getWorkset().map(new AggregateMapDeltaWithParam());
+
+                DataSet<Tuple2<Integer, Integer>> newElements = updatedDs.join(iteration.getSolutionSet())
+                        .where(0).equalTo(0).flatMap(new UpdateFilter());
+
+                DataSet<Tuple2<Integer, Integer>> iterationRes = iteration.closeWith(newElements, newElements);
+                DataSet<Integer> result = iterationRes.map(new ProjectSecondMapper());
+                result.writeAsText(resultPath);
+
+                env.execute();
+
+                return EXPECTED_DELTA_RESULT;
+            }
+            default:
+                throw new IllegalArgumentException("Invalid program id");
+            }
+        }
+    }
+
+    /** Converges as soon as the aggregated count exceeds 3. */
+    @SuppressWarnings("serial")
+    public static final class NegativeElementsConvergenceCriterion implements ConvergenceCriterion<LongValue> {
+
+        @Override
+        public boolean isConverged(int iteration, LongValue value) {
+            return value.getValue() > 3;
+        }
+    }
+
+    /** Converges as soon as the aggregated count exceeds the threshold given at construction. */
+    @SuppressWarnings("serial")
+    public static final class NegativeElementsConvergenceCriterionWithParam implements ConvergenceCriterion<LongValue> {
+
+        // named 'threshold' so it is not shadowed by the 'value' parameter of isConverged
+        private final int threshold;
+
+        public NegativeElementsConvergenceCriterionWithParam(int val) {
+            this.threshold = val;
+        }
+
+        public int getValue() {
+            return this.threshold;
+        }
+
+        @Override
+        public boolean isConverged(int iteration, LongValue value) {
+            return value.getValue() > this.threshold;
+        }
+    }
+
+    /** Decrements every element and counts the elements that became negative. */
+    @SuppressWarnings("serial")
+    public static final class SubtractOneMap extends MapFunction<Integer, Integer> {
+
+        private LongSumAggregator aggr;
+
+        @Override
+        public void open(Configuration conf) {
+            aggr = getIterationRuntimeContext().getIterationAggregator(AggregatorProgs.NEGATIVE_ELEMENTS_AGGR);
+        }
+
+        @Override
+        public Integer map(Integer value) {
+            int newValue = value.intValue() - 1;
+            // count negative numbers
+            if (newValue < 0) {
+                aggr.aggregate(1L);
+            }
+            return Integer.valueOf(newValue);
+        }
+    }
+
+    /** Decrements every element and counts the elements below the aggregator's parameter. */
+    @SuppressWarnings("serial")
+    public static final class SubtractOneMapWithParam extends MapFunction<Integer, Integer> {
+
+        private LongSumAggregatorWithParameter aggr;
+
+        @Override
+        public void open(Configuration conf) {
+            aggr = getIterationRuntimeContext().getIterationAggregator(AggregatorProgs.NEGATIVE_ELEMENTS_AGGR);
+        }
+
+        @Override
+        public Integer map(Integer value) {
+            int newValue = value.intValue() - 1;
+            // count numbers less than the aggregator parameter
+            if (newValue < aggr.getValue()) {
+                aggr.aggregate(1L);
+            }
+            return Integer.valueOf(newValue);
+        }
+    }
+
+    /** A LongSumAggregator that carries an int parameter, serialized together with the aggregator. */
+    @SuppressWarnings("serial")
+    public static class LongSumAggregatorWithParameter extends LongSumAggregator {
+
+        private final int value;
+
+        public LongSumAggregatorWithParameter(int val) {
+            this.value = val;
+        }
+
+        public int getValue() {
+            return this.value;
+        }
+    }
+
+    /**
+     * Wraps each value into a (key, value) tuple. Keys are unique across all parallel instances:
+     * instance i of n emits i, i + n, i + 2n, ... Random keys could collide and corrupt the
+     * solution set of the delta iteration.
+     */
+    @SuppressWarnings("serial")
+    public static final class TupleMakerMap extends MapFunction<Integer, Tuple2<Integer, Integer>> {
+
+        private int nextKey;
+        private int keyStep;
+
+        @Override
+        public void open(Configuration conf) {
+            nextKey = getRuntimeContext().getIndexOfThisSubtask();
+            keyStep = getRuntimeContext().getNumberOfParallelSubtasks();
+        }
+
+        @Override
+        public Tuple2<Integer, Integer> map(Integer value) throws Exception {
+            Integer nodeId = Integer.valueOf(nextKey);
+            nextKey += keyStep;
+            return new Tuple2<Integer, Integer>(nodeId, value);
+        }
+    }
+
+    /**
+     * Counts the workset elements equal to the superstep number and checks that the previous superstep's
+     * aggregate equals superstep - 1.
+     */
+    @SuppressWarnings("serial")
+    public static final class AggregateMapDelta extends MapFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> {
+
+        private LongSumAggregator aggr;
+        private LongValue previousAggr;
+        private int superstep;
+
+        @Override
+        public void open(Configuration conf) {
+            aggr = getIterationRuntimeContext().getIterationAggregator(AggregatorProgs.NEGATIVE_ELEMENTS_AGGR);
+            superstep = getIterationRuntimeContext().getSuperstepNumber();
+
+            if (superstep > 1) {
+                previousAggr = getIterationRuntimeContext().getPreviousIterationAggregate(AggregatorProgs.NEGATIVE_ELEMENTS_AGGR);
+                // check previous aggregator value
+                Assert.assertEquals(superstep - 1, previousAggr.getValue());
+            }
+        }
+
+        @Override
+        public Tuple2<Integer, Integer> map(Tuple2<Integer, Integer> value) {
+            // count the elements that are equal to the superstep number
+            if (value.f1.intValue() == superstep) {
+                aggr.aggregate(1L);
+            }
+            return value;
+        }
+    }
+
+    /** Keeps the joined workset element only if its value is larger than the current superstep number. */
+    @SuppressWarnings("serial")
+    public static final class UpdateFilter extends FlatMapFunction<Tuple2<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>>,
+            Tuple2<Integer, Integer>> {
+
+        private int superstep;
+
+        @Override
+        public void open(Configuration conf) {
+            superstep = getIterationRuntimeContext().getSuperstepNumber();
+        }
+
+        @Override
+        public void flatMap(Tuple2<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> value,
+                Collector<Tuple2<Integer, Integer>> out) throws Exception {
+
+            if (value.f0.f1 > superstep) {
+                out.collect(value.f0);
+            }
+        }
+    }
+
+    /** Projects a (key, value) tuple onto its value. */
+    @SuppressWarnings("serial")
+    public static final class ProjectSecondMapper extends MapFunction<Tuple2<Integer, Integer>, Integer> {
+
+        @Override
+        public Integer map(Tuple2<Integer, Integer> value) {
+            return value.f1;
+        }
+    }
+
+    /**
+     * Counts the workset elements smaller than the aggregator's parameter and checks the previous superstep's
+     * aggregate. The expected values assume program 5's setup (parameter 4, workset of superstep s holding the
+     * input values >= s): 6, 5, 3 and 0 elements below 4 in supersteps 1 to 4.
+     */
+    @SuppressWarnings("serial")
+    public static final class AggregateMapDeltaWithParam extends MapFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> {
+
+        private LongSumAggregatorWithParameter aggr;
+        private LongValue previousAggr;
+        private int superstep;
+
+        @Override
+        public void open(Configuration conf) {
+            aggr = getIterationRuntimeContext().getIterationAggregator(AggregatorProgs.NEGATIVE_ELEMENTS_AGGR);
+            superstep = getIterationRuntimeContext().getSuperstepNumber();
+
+            if (superstep > 1) {
+                previousAggr = getIterationRuntimeContext().getPreviousIterationAggregate(AggregatorProgs.NEGATIVE_ELEMENTS_AGGR);
+
+                // check previous aggregator value; every case needs its own break, otherwise a superstep
+                // falls through into the assertions meant for later supersteps
+                switch (superstep) {
+                case 2:
+                    Assert.assertEquals(6, previousAggr.getValue());
+                    break;
+                case 3:
+                    Assert.assertEquals(5, previousAggr.getValue());
+                    break;
+                case 4:
+                    Assert.assertEquals(3, previousAggr.getValue());
+                    break;
+                case 5:
+                    Assert.assertEquals(0, previousAggr.getValue());
+                    break;
+                default:
+                    break;
+                }
+            }
+        }
+
+        @Override
+        public Tuple2<Integer, Integer> map(Tuple2<Integer, Integer> value) {
+            // count the elements that are smaller than the aggregator parameter
+            if (value.f1.intValue() < aggr.getValue()) {
+                aggr.aggregate(1L);
+            }
+            return value;
+        }
+    }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/aggregators/ConnectedComponentsWithParametrizableAggregatorITCase.java
----------------------------------------------------------------------
diff --git a/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/aggregators/ConnectedComponentsWithParametrizableAggregatorITCase.java b/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/aggregators/ConnectedComponentsWithParametrizableAggregatorITCase.java
new file mode 100644
index 0000000..42538eb
--- /dev/null
+++ b/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/aggregators/ConnectedComponentsWithParametrizableAggregatorITCase.java
@@ -0,0 +1,237 @@
+/***********************************************************************************************************************
+ * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ **********************************************************************************************************************/
+
+package eu.stratosphere.test.iterative.aggregators;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import junit.framework.Assert;
+
+import eu.stratosphere.api.common.aggregators.LongSumAggregator;
+import eu.stratosphere.api.java.DataSet;
+import eu.stratosphere.api.java.ExecutionEnvironment;
+import eu.stratosphere.api.java.IterativeDataSet;
+import eu.stratosphere.api.java.functions.FlatMapFunction;
+import eu.stratosphere.api.java.functions.GroupReduceFunction;
+import eu.stratosphere.api.java.functions.JoinFunction;
+import eu.stratosphere.api.java.tuple.Tuple2;
+import eu.stratosphere.configuration.Configuration;
+import eu.stratosphere.test.util.JavaProgramTestBase;
+import eu.stratosphere.types.LongValue;
+import eu.stratosphere.util.Collector;
+
+
+/**
+ * Connected Components test case that uses a parametrizable aggregator: the aggregator carries the id
+ * of the component whose members it counts.
+ */
+public class ConnectedComponentsWithParametrizableAggregatorITCase extends JavaProgramTestBase {
+
+    private static final int MAX_ITERATIONS = 5;
+    private static final int DOP = 1;
+
+    protected static List<Tuple2<Long, Long>> verticesInput = new ArrayList<Tuple2<Long, Long>>();
+    protected static List<Tuple2<Long, Long>> edgesInput = new ArrayList<Tuple2<Long, Long>>();
+    private String resultPath;
+    private String expectedResult;
+
+    @Override
+    protected void preSubmit() throws Exception {
+        // the input lists are static: start from a clean state so repeated runs do not duplicate them
+        verticesInput.clear();
+        edgesInput.clear();
+
+        // vertices input: (vertexId, initial componentId), every vertex starts in its own component
+        for (long id = 1L; id <= 9L; id++) {
+            verticesInput.add(new Tuple2<Long, Long>(id, id));
+        }
+
+        // edges input: (sourceId, targetId)
+        final long[][] edges = {
+            {1, 2}, {1, 3}, {2, 3}, {2, 4}, {2, 1}, {3, 1}, {3, 2}, {4, 2}, {4, 6},
+            {5, 6}, {6, 4}, {6, 5}, {7, 8}, {7, 9}, {8, 7}, {8, 9}, {9, 7}, {9, 8}
+        };
+        for (long[] edge : edges) {
+            edgesInput.add(new Tuple2<Long, Long>(edge[0], edge[1]));
+        }
+
+        resultPath = getTempDirPath("result");
+
+        expectedResult = "(1, 1)\n" + "(2, 1)\n" + "(3, 1)\n" + "(4, 1)\n" +
+                "(5, 1)\n" + "(6, 1)\n" + "(7, 7)\n" + "(8, 7)\n" + "(9, 7)\n";
+    }
+
+    @Override
+    protected void testProgram() throws Exception {
+        ConnectedComponentsWithAggregatorProgram.runProgram(resultPath);
+    }
+
+    @Override
+    protected void postSubmit() throws Exception {
+        compareResultsByLinesInMemory(expectedResult, resultPath);
+        long[] aggr_values = ConnectedComponentsWithAggregatorProgram.aggr_value;
+
+        // MinimumIdFilter stores the aggregate of the previous superstep at index (superstep - 2), so index i
+        // holds the result of superstep i + 1. The aggregate of the last superstep (MAX_ITERATIONS) is never
+        // read because no further superstep follows; aggr_values[MAX_ITERATIONS - 1] stays unset.
+        Assert.assertEquals(3, aggr_values[0]);
+        Assert.assertEquals(4, aggr_values[1]);
+        Assert.assertEquals(5, aggr_values[2]);
+        Assert.assertEquals(6, aggr_values[3]);
+    }
+
+
+    private static class ConnectedComponentsWithAggregatorProgram {
+
+        private static final String ELEMENTS_IN_COMPONENT = "elements.in.component.aggregator";
+        private static final long componentId = 1L;
+        private static long[] aggr_value = new long[MAX_ITERATIONS];
+
+        /**
+         * Builds and executes the connected components iteration. Every superstep counts the vertices that
+         * belong to component {@code componentId} with a parametrized aggregator.
+         *
+         * @param resultPath where the final (vertexId, componentId) pairs are written
+         * @return the given result path
+         */
+        public static String runProgram(String resultPath) throws Exception {
+            final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+            env.setDegreeOfParallelism(DOP);
+
+            DataSet<Tuple2<Long, Long>> vertices = env.fromCollection(verticesInput);
+            DataSet<Tuple2<Long, Long>> edges = env.fromCollection(edgesInput);
+
+            IterativeDataSet<Tuple2<Long, Long>> iteration = vertices.iterate(MAX_ITERATIONS);
+
+            // the aggregator instance carries the id of the component whose members are counted
+            LongSumAggregatorWithParameter componentCounter = new LongSumAggregatorWithParameter(componentId);
+            iteration.registerAggregator(ELEMENTS_IN_COMPONENT, componentCounter);
+
+            // offer each vertex's component id to its neighbors ...
+            DataSet<Tuple2<Long, Long>> candidates = iteration
+                    .join(edges).where(0).equalTo(0)
+                    .with(new NeighborWithComponentIDJoin());
+
+            // ... keep the smallest offer per vertex ...
+            DataSet<Tuple2<Long, Long>> minimalCandidates = candidates
+                    .groupBy(0)
+                    .reduceGroup(new MinimumReduce());
+
+            // ... and keep it only if it is smaller than the vertex's current component id
+            DataSet<Tuple2<Long, Long>> nextComponents = minimalCandidates
+                    .join(iteration).where(0).equalTo(0)
+                    .flatMap(new MinimumIdFilter());
+
+            iteration.closeWith(nextComponents).writeAsText(resultPath);
+            env.execute();
+
+            return resultPath;
+        }
+    }
+
+    /**
+     * Re-keys a (vertexId, componentId) tuple with the target of the joined edge, so the component id is
+     * offered to that neighbor. The vertex tuple is modified in place and returned.
+     */
+    public static final class NeighborWithComponentIDJoin extends JoinFunction
+            <Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> {
+
+        private static final long serialVersionUID = 1L;
+
+        @Override
+        public Tuple2<Long, Long> join(Tuple2<Long, Long> vertexWithCompId,
+                Tuple2<Long, Long> edge) throws Exception {
+            final Long neighborId = edge.f1;
+            vertexWithCompId.f0 = neighborId;
+            return vertexWithCompId;
+        }
+    }
+
+    /**
+     * Reduces all component-id candidates of one vertex to the smallest one. The emitted tuple is a
+     * single instance reused across calls.
+     */
+    public static final class MinimumReduce extends GroupReduceFunction
+            <Tuple2<Long, Long>, Tuple2<Long, Long>> {
+
+        private static final long serialVersionUID = 1L;
+        final Tuple2<Long, Long> resultVertex = new Tuple2<Long, Long>();
+
+        @Override
+        public void reduce(Iterator<Tuple2<Long, Long>> values,
+                Collector<Tuple2<Long, Long>> out) throws Exception {
+
+            Tuple2<Long, Long> current = values.next();
+            // capture the fields before advancing, the iterator may hand out reused objects
+            final Long vertexId = current.f0;
+            Long smallest = current.f1;
+
+            while (values.hasNext()) {
+                current = values.next();
+                final Long candidate = current.f1;
+                smallest = (candidate < smallest) ? candidate : smallest;
+            }
+
+            resultVertex.f0 = vertexId;
+            resultVertex.f1 = smallest;
+            out.collect(resultVertex);
+        }
+    }
+
+    /**
+     * For a joined (candidate, current) pair of one vertex, emits the entry with the smaller component id
+     * (the current one on ties). It counts emitted vertices that belong to the aggregator's component.
+     * From the second superstep on, it also stores the previous superstep's aggregate in
+     * ConnectedComponentsWithAggregatorProgram.aggr_value[superstep - 2].
+     */
+    @SuppressWarnings("serial")
+    public static final class MinimumIdFilter extends FlatMapFunction
+            <Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>>, Tuple2<Long, Long>> {
+
+        // Instance field (was static): each function instance must use the aggregator it obtained from
+        // its own iteration runtime context, not one shared across instances through a static.
+        private LongSumAggregatorWithParameter aggr;
+
+        @Override
+        public void open(Configuration conf) {
+            aggr = getIterationRuntimeContext().getIterationAggregator(
+                    ConnectedComponentsWithAggregatorProgram.ELEMENTS_IN_COMPONENT);
+
+            int superstep = getIterationRuntimeContext().getSuperstepNumber();
+
+            if (superstep > 1) {
+                LongValue val = getIterationRuntimeContext().getPreviousIterationAggregate(
+                        ConnectedComponentsWithAggregatorProgram.ELEMENTS_IN_COMPONENT);
+                ConnectedComponentsWithAggregatorProgram.aggr_value[superstep - 2] = val.getValue();
+            }
+        }
+
+        @Override
+        public void flatMap(
+                Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>> vertexWithNewAndOldId,
+                Collector<Tuple2<Long, Long>> out) throws Exception {
+
+            final Tuple2<Long, Long> candidate = vertexWithNewAndOldId.f0;
+            final Tuple2<Long, Long> current = vertexWithNewAndOldId.f1;
+            final Tuple2<Long, Long> winner = (candidate.f1 < current.f1) ? candidate : current;
+
+            out.collect(winner);
+            if (winner.f1 == aggr.getComponentId()) {
+                aggr.aggregate(1L);
+            }
+        }
+    }
+
+    /**
+     * A {@link LongSumAggregator} that additionally carries the id of the component whose members are
+     * counted. The id is a regular (non-transient) field, so it is serialized together with the aggregator.
+     */
+    @SuppressWarnings("serial")
+    public static final class LongSumAggregatorWithParameter extends LongSumAggregator {
+
+        private final long componentId;
+
+        public LongSumAggregatorWithParameter(long compId) {
+            this.componentId = compId;
+        }
+
+        public long getComponentId() {
+            return this.componentId;
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/aggregators/ConnectedComponentsWithParametrizableConvergenceITCase.java
----------------------------------------------------------------------
diff --git a/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/aggregators/ConnectedComponentsWithParametrizableConvergenceITCase.java b/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/aggregators/ConnectedComponentsWithParametrizableConvergenceITCase.java
new file mode 100644
index 0000000..ef798cc
--- /dev/null
+++ b/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/aggregators/ConnectedComponentsWithParametrizableConvergenceITCase.java
@@ -0,0 +1,223 @@
+/***********************************************************************************************************************
+ * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ **********************************************************************************************************************/
+
+package eu.stratosphere.test.iterative.aggregators;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import eu.stratosphere.api.common.aggregators.ConvergenceCriterion;
+import eu.stratosphere.api.common.aggregators.LongSumAggregator;
+import eu.stratosphere.api.java.DataSet;
+import eu.stratosphere.api.java.ExecutionEnvironment;
+import eu.stratosphere.api.java.IterativeDataSet;
+import eu.stratosphere.api.java.functions.FlatMapFunction;
+import eu.stratosphere.api.java.functions.GroupReduceFunction;
+import eu.stratosphere.api.java.functions.JoinFunction;
+import eu.stratosphere.api.java.tuple.Tuple2;
+import eu.stratosphere.configuration.Configuration;
+import eu.stratosphere.test.util.JavaProgramTestBase;
+import eu.stratosphere.types.LongValue;
+import eu.stratosphere.util.Collector;
+
+
+/**
+ *
+ * Connected Components test case that uses a parametrizable convergence criterion
+ *
+ */
+public class ConnectedComponentsWithParametrizableConvergenceITCase extends JavaProgramTestBase {
+
+ private static final int MAX_ITERATIONS = 10;
+ private static final int DOP = 1;
+
+ protected static List<Tuple2<Long, Long>> verticesInput = new ArrayList<Tuple2<Long, Long>>();
+ protected static List<Tuple2<Long, Long>> edgesInput = new ArrayList<Tuple2<Long, Long>>();
+ private String resultPath;
+ private String expectedResult;
+
+ @Override
+ protected void preSubmit() throws Exception {
+ // vertices input
+ verticesInput.add(new Tuple2<Long, Long>(1l,1l));
+ verticesInput.add(new Tuple2<Long, Long>(2l,2l));
+ verticesInput.add(new Tuple2<Long, Long>(3l,3l));
+ verticesInput.add(new Tuple2<Long, Long>(4l,4l));
+ verticesInput.add(new Tuple2<Long, Long>(5l,5l));
+ verticesInput.add(new Tuple2<Long, Long>(6l,6l));
+ verticesInput.add(new Tuple2<Long, Long>(7l,7l));
+ verticesInput.add(new Tuple2<Long, Long>(8l,8l));
+ verticesInput.add(new Tuple2<Long, Long>(9l,9l));
+
+ // vertices input
+ edgesInput.add(new Tuple2<Long, Long>(1l,2l));
+ edgesInput.add(new Tuple2<Long, Long>(1l,3l));
+ edgesInput.add(new Tuple2<Long, Long>(2l,3l));
+ edgesInput.add(new Tuple2<Long, Long>(2l,4l));
+ edgesInput.add(new Tuple2<Long, Long>(2l,1l));
+ edgesInput.add(new Tuple2<Long, Long>(3l,1l));
+ edgesInput.add(new Tuple2<Long, Long>(3l,2l));
+ edgesInput.add(new Tuple2<Long, Long>(4l,2l));
+ edgesInput.add(new Tuple2<Long, Long>(4l,6l));
+ edgesInput.add(new Tuple2<Long, Long>(5l,6l));
+ edgesInput.add(new Tuple2<Long, Long>(6l,4l));
+ edgesInput.add(new Tuple2<Long, Long>(6l,5l));
+ edgesInput.add(new Tuple2<Long, Long>(7l,8l));
+ edgesInput.add(new Tuple2<Long, Long>(7l,9l));
+ edgesInput.add(new Tuple2<Long, Long>(8l,7l));
+ edgesInput.add(new Tuple2<Long, Long>(8l,9l));
+ edgesInput.add(new Tuple2<Long, Long>(9l,7l));
+ edgesInput.add(new Tuple2<Long, Long>(9l,8l));
+
+ resultPath = getTempDirPath("result");
+
+ expectedResult = "(1, 1)\n" + "(2, 1)\n" + "(3, 1)\n" + "(4, 1)\n" +
+ "(5, 2)\n" + "(6, 1)\n" + "(7, 7)\n" + "(8, 7)\n" + "(9, 7)\n";
+ }
+
+ @Override
+ protected void testProgram() throws Exception {
+ ConnectedComponentsWithConvergenceProgram.runProgram(resultPath);
+ }
+
+ @Override
+ protected void postSubmit() throws Exception {
+ compareResultsByLinesInMemory(expectedResult, resultPath);
+ }
+
+
+ private static class ConnectedComponentsWithConvergenceProgram {
+
+ private static final String UPDATED_ELEMENTS = "updated.elements.aggr";
+ private static final long convergence_threshold = 3; // the iteration stops if less than this number os elements change value
+
+ public static String runProgram(String resultPath) throws Exception {
+
+ final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ env.setDegreeOfParallelism(DOP);
+
+ DataSet<Tuple2<Long, Long>> initialSolutionSet = env.fromCollection(verticesInput);
+ DataSet<Tuple2<Long, Long>> edges = env.fromCollection(edgesInput);
+
+ IterativeDataSet<Tuple2<Long, Long>> iteration =
+ initialSolutionSet.iterate(MAX_ITERATIONS);
+
+ // register the convergence criterion
+ iteration.registerAggregationConvergenceCriterion(UPDATED_ELEMENTS,
+ new LongSumAggregator(), new UpdatedElementsConvergenceCriterion(convergence_threshold));
+
+ DataSet<Tuple2<Long, Long>> verticesWithNewComponents = iteration.join(edges).where(0).equalTo(0)
+ .with(new NeighborWithComponentIDJoin())
+ .groupBy(0).reduceGroup(new MinimumReduce());
+
+ DataSet<Tuple2<Long, Long>> updatedComponentId =
+ verticesWithNewComponents.join(iteration).where(0).equalTo(0)
+ .flatMap(new MinimumIdFilter());
+
+ iteration.closeWith(updatedComponentId).writeAsText(resultPath);
+
+ env.execute();
+
+ return resultPath;
+ }
+ }
+
+ public static final class NeighborWithComponentIDJoin extends JoinFunction
+ <Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> {
+
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public Tuple2<Long, Long> join(Tuple2<Long, Long> vertexWithCompId,
+ Tuple2<Long, Long> edge) throws Exception {
+
+ vertexWithCompId.setField(edge.f1, 0);
+ return vertexWithCompId;
+ }
+ }
+
+ public static final class MinimumReduce extends GroupReduceFunction
+ <Tuple2<Long, Long>, Tuple2<Long, Long>> {
+
+ private static final long serialVersionUID = 1L;
+ final Tuple2<Long, Long> resultVertex = new Tuple2<Long, Long>();
+
+ @Override
+ public void reduce(Iterator<Tuple2<Long, Long>> values,
+ Collector<Tuple2<Long, Long>> out) throws Exception {
+
+ final Tuple2<Long, Long> first = values.next();
+ final Long vertexId = first.f0;
+ Long minimumCompId = first.f1;
+
+ while (values.hasNext()) {
+ Long candidateCompId = values.next().f1;
+ if (candidateCompId < minimumCompId) {
+ minimumCompId = candidateCompId;
+ }
+ }
+ resultVertex.setField(vertexId, 0);
+ resultVertex.setField(minimumCompId, 1);
+
+ out.collect(resultVertex);
+ }
+ }
+
+ @SuppressWarnings("serial")
+ public static final class MinimumIdFilter extends FlatMapFunction
+ <Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>>, Tuple2<Long, Long>> {
+
+ private static LongSumAggregator aggr;
+
+ @Override
+ public void open(Configuration conf) {
+ aggr = getIterationRuntimeContext().getIterationAggregator(
+ ConnectedComponentsWithConvergenceProgram.UPDATED_ELEMENTS);
+ }
+
+ @Override
+ public void flatMap(
+ Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>> vertexWithNewAndOldId,
+ Collector<Tuple2<Long, Long>> out) throws Exception {
+
+ if (vertexWithNewAndOldId.f0.f1 < vertexWithNewAndOldId.f1.f1) {
+ out.collect(vertexWithNewAndOldId.f0);
+ aggr.aggregate(1l);
+ } else {
+ out.collect(vertexWithNewAndOldId.f1);
+ }
+ }
+ }
+
+ // A Convergence Criterion with one parameter
+ @SuppressWarnings("serial")
+ public static final class UpdatedElementsConvergenceCriterion implements ConvergenceCriterion<LongValue> {
+
+ private long threshold;
+
+ public UpdatedElementsConvergenceCriterion(long u_threshold) {
+ this.threshold = u_threshold;
+ }
+
+ public long getThreshold() {
+ return this.threshold;
+ }
+
+ @Override
+ public boolean isConverged(int iteration, LongValue value) {
+ return value.getValue() < this.threshold;
+ }
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/ConnectedComponentsNepheleITCase.java
----------------------------------------------------------------------
diff --git a/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/ConnectedComponentsNepheleITCase.java b/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/ConnectedComponentsNepheleITCase.java
index 40068b7..7eff1aa 100644
--- a/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/ConnectedComponentsNepheleITCase.java
+++ b/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/ConnectedComponentsNepheleITCase.java
@@ -276,7 +276,7 @@ public class ConnectedComponentsNepheleITCase extends RecordAPITestBase {
headConfig.setMemoryDriver(MEM_PER_CONSUMER * JobGraphUtils.MEGABYTE);
headConfig.addIterationAggregator(
- WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, LongSumAggregator.class);
+ WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new LongSumAggregator());
}
return head;
@@ -351,9 +351,9 @@ public class ConnectedComponentsNepheleITCase extends RecordAPITestBase {
syncConfig.setNumberOfIterations(maxIterations);
syncConfig.setIterationId(ITERATION_ID);
syncConfig.addIterationAggregator(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME,
- LongSumAggregator.class);
+ new LongSumAggregator());
syncConfig.setConvergenceCriterion(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME,
- WorksetEmptyConvergenceCriterion.class);
+ new WorksetEmptyConvergenceCriterion());
return sync;
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/customdanglingpagerank/CustomCompensatableDanglingPageRank.java
----------------------------------------------------------------------
diff --git a/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/customdanglingpagerank/CustomCompensatableDanglingPageRank.java b/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/customdanglingpagerank/CustomCompensatableDanglingPageRank.java
index 06badd9..b112741 100644
--- a/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/customdanglingpagerank/CustomCompensatableDanglingPageRank.java
+++ b/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/customdanglingpagerank/CustomCompensatableDanglingPageRank.java
@@ -190,7 +190,7 @@ public class CustomCompensatableDanglingPageRank {
headConfig.setStubParameter("compensation.failingWorker", failingWorkers);
headConfig.setStubParameter("compensation.failingIteration", String.valueOf(failingIteration));
headConfig.setStubParameter("compensation.messageLoss", String.valueOf(messageLoss));
- headConfig.addIterationAggregator(CustomCompensatableDotProductCoGroup.AGGREGATOR_NAME, PageRankStatsAggregator.class);
+ headConfig.addIterationAggregator(CustomCompensatableDotProductCoGroup.AGGREGATOR_NAME, new PageRankStatsAggregator());
// --------------- the join ---------------------
@@ -277,8 +277,8 @@ public class CustomCompensatableDanglingPageRank {
JobOutputVertex sync = JobGraphUtils.createSync(jobGraph, degreeOfParallelism);
TaskConfig syncConfig = new TaskConfig(sync.getConfiguration());
syncConfig.setNumberOfIterations(numIterations);
- syncConfig.addIterationAggregator(CustomCompensatableDotProductCoGroup.AGGREGATOR_NAME, PageRankStatsAggregator.class);
- syncConfig.setConvergenceCriterion(CustomCompensatableDotProductCoGroup.AGGREGATOR_NAME, DiffL1NormConvergenceCriterion.class);
+ syncConfig.addIterationAggregator(CustomCompensatableDotProductCoGroup.AGGREGATOR_NAME, new PageRankStatsAggregator());
+ syncConfig.setConvergenceCriterion(CustomCompensatableDotProductCoGroup.AGGREGATOR_NAME, new DiffL1NormConvergenceCriterion());
syncConfig.setIterationId(ITERATION_ID);
// --------------- the wiring ---------------------
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/customdanglingpagerank/CustomCompensatableDanglingPageRankWithCombiner.java
----------------------------------------------------------------------
diff --git a/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/customdanglingpagerank/CustomCompensatableDanglingPageRankWithCombiner.java b/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/customdanglingpagerank/CustomCompensatableDanglingPageRankWithCombiner.java
index 3e51808..37fab39 100644
--- a/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/customdanglingpagerank/CustomCompensatableDanglingPageRankWithCombiner.java
+++ b/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/customdanglingpagerank/CustomCompensatableDanglingPageRankWithCombiner.java
@@ -190,7 +190,7 @@ public class CustomCompensatableDanglingPageRankWithCombiner {
headConfig.setStubParameter("compensation.failingWorker", failingWorkers);
headConfig.setStubParameter("compensation.failingIteration", String.valueOf(failingIteration));
headConfig.setStubParameter("compensation.messageLoss", String.valueOf(messageLoss));
- headConfig.addIterationAggregator(CustomCompensatableDotProductCoGroup.AGGREGATOR_NAME, PageRankStatsAggregator.class);
+ headConfig.addIterationAggregator(CustomCompensatableDotProductCoGroup.AGGREGATOR_NAME, new PageRankStatsAggregator());
// --------------- the join ---------------------
@@ -257,7 +257,7 @@ public class CustomCompensatableDanglingPageRankWithCombiner {
tailConfig.setMemoryInput(1, coGroupSortMemory * JobGraphUtils.MEGABYTE);
tailConfig.setFilehandlesInput(1, NUM_FILE_HANDLES_PER_SORT);
tailConfig.setSpillingThresholdInput(1, SORT_SPILL_THRESHOLD);
- tailConfig.addIterationAggregator(CustomCompensatableDotProductCoGroup.AGGREGATOR_NAME, PageRankStatsAggregator.class);
+ tailConfig.addIterationAggregator(CustomCompensatableDotProductCoGroup.AGGREGATOR_NAME, new PageRankStatsAggregator());
// output
tailConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
@@ -289,8 +289,8 @@ public class CustomCompensatableDanglingPageRankWithCombiner {
JobOutputVertex sync = JobGraphUtils.createSync(jobGraph, degreeOfParallelism);
TaskConfig syncConfig = new TaskConfig(sync.getConfiguration());
syncConfig.setNumberOfIterations(numIterations);
- syncConfig.addIterationAggregator(CustomCompensatableDotProductCoGroup.AGGREGATOR_NAME, PageRankStatsAggregator.class);
- syncConfig.setConvergenceCriterion(CustomCompensatableDotProductCoGroup.AGGREGATOR_NAME, DiffL1NormConvergenceCriterion.class);
+ syncConfig.addIterationAggregator(CustomCompensatableDotProductCoGroup.AGGREGATOR_NAME, new PageRankStatsAggregator());
+ syncConfig.setConvergenceCriterion(CustomCompensatableDotProductCoGroup.AGGREGATOR_NAME, new DiffL1NormConvergenceCriterion());
syncConfig.setIterationId(ITERATION_ID);
// --------------- the wiring ---------------------
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/danglingpagerank/CompensatableDanglingPageRank.java
----------------------------------------------------------------------
diff --git a/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/danglingpagerank/CompensatableDanglingPageRank.java b/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/danglingpagerank/CompensatableDanglingPageRank.java
index b50c33e..944f13b 100644
--- a/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/danglingpagerank/CompensatableDanglingPageRank.java
+++ b/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/danglingpagerank/CompensatableDanglingPageRank.java
@@ -170,7 +170,7 @@ public class CompensatableDanglingPageRank {
headConfig.setStubParameter("compensation.failingWorker", failingWorkers);
headConfig.setStubParameter("compensation.failingIteration", String.valueOf(failingIteration));
headConfig.setStubParameter("compensation.messageLoss", String.valueOf(messageLoss));
- headConfig.addIterationAggregator(CompensatableDotProductCoGroup.AGGREGATOR_NAME, PageRankStatsAggregator.class);
+ headConfig.addIterationAggregator(CompensatableDotProductCoGroup.AGGREGATOR_NAME, new PageRankStatsAggregator());
// --------------- the join ---------------------
@@ -257,8 +257,8 @@ public class CompensatableDanglingPageRank {
JobOutputVertex sync = JobGraphUtils.createSync(jobGraph, degreeOfParallelism);
TaskConfig syncConfig = new TaskConfig(sync.getConfiguration());
syncConfig.setNumberOfIterations(numIterations);
- syncConfig.addIterationAggregator(CompensatableDotProductCoGroup.AGGREGATOR_NAME, PageRankStatsAggregator.class);
- syncConfig.setConvergenceCriterion(CompensatableDotProductCoGroup.AGGREGATOR_NAME, DiffL1NormConvergenceCriterion.class);
+ syncConfig.addIterationAggregator(CompensatableDotProductCoGroup.AGGREGATOR_NAME, new PageRankStatsAggregator());
+ syncConfig.setConvergenceCriterion(CompensatableDotProductCoGroup.AGGREGATOR_NAME, new DiffL1NormConvergenceCriterion());
syncConfig.setIterationId(ITERATION_ID);
// --------------- the wiring ---------------------
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/danglingpagerank/DiffL1NormConvergenceCriterion.java
----------------------------------------------------------------------
diff --git a/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/danglingpagerank/DiffL1NormConvergenceCriterion.java b/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/danglingpagerank/DiffL1NormConvergenceCriterion.java
index 3dbc0c5..1cc0651 100644
--- a/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/danglingpagerank/DiffL1NormConvergenceCriterion.java
+++ b/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/danglingpagerank/DiffL1NormConvergenceCriterion.java
@@ -18,6 +18,7 @@ import eu.stratosphere.api.common.aggregators.ConvergenceCriterion;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+@SuppressWarnings("serial")
public class DiffL1NormConvergenceCriterion implements ConvergenceCriterion<PageRankStats> {
private static final double EPSILON = 0.00005;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/danglingpagerank/PageRankStatsAggregator.java
----------------------------------------------------------------------
diff --git a/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/danglingpagerank/PageRankStatsAggregator.java b/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/danglingpagerank/PageRankStatsAggregator.java
index 108ef63..062eef4 100644
--- a/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/danglingpagerank/PageRankStatsAggregator.java
+++ b/stratosphere-tests/src/test/java/eu/stratosphere/test/iterative/nephele/danglingpagerank/PageRankStatsAggregator.java
@@ -15,6 +15,7 @@ package eu.stratosphere.test.iterative.nephele.danglingpagerank;
import eu.stratosphere.api.common.aggregators.Aggregator;
+@SuppressWarnings("serial")
public class PageRankStatsAggregator implements Aggregator<PageRankStats> {
private double diff = 0;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-tests/src/test/java/eu/stratosphere/test/recordJobs/graph/DanglingPageRank.java
----------------------------------------------------------------------
diff --git a/stratosphere-tests/src/test/java/eu/stratosphere/test/recordJobs/graph/DanglingPageRank.java b/stratosphere-tests/src/test/java/eu/stratosphere/test/recordJobs/graph/DanglingPageRank.java
index ed104c7..d244c06 100644
--- a/stratosphere-tests/src/test/java/eu/stratosphere/test/recordJobs/graph/DanglingPageRank.java
+++ b/stratosphere-tests/src/test/java/eu/stratosphere/test/recordJobs/graph/DanglingPageRank.java
@@ -82,7 +82,8 @@ public class DanglingPageRank implements Program, ProgramDescription {
iteration.setNextPartialSolution(rankAggregation);
iteration.setMaximumNumberOfIterations(numIterations);
- iteration.getAggregators().registerAggregationConvergenceCriterion(DotProductCoGroup.AGGREGATOR_NAME, PageRankStatsAggregator.class, DiffL1NormConvergenceCriterion.class);
+ iteration.getAggregators().registerAggregationConvergenceCriterion(DotProductCoGroup.AGGREGATOR_NAME, new PageRankStatsAggregator(),
+ new DiffL1NormConvergenceCriterion());
FileDataSink out = new FileDataSink(new PageWithRankOutFormat(), outputPath, iteration, "Final Ranks");
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-tests/src/test/java/eu/stratosphere/test/recordJobs/graph/pageRankUtil/DiffL1NormConvergenceCriterion.java
----------------------------------------------------------------------
diff --git a/stratosphere-tests/src/test/java/eu/stratosphere/test/recordJobs/graph/pageRankUtil/DiffL1NormConvergenceCriterion.java b/stratosphere-tests/src/test/java/eu/stratosphere/test/recordJobs/graph/pageRankUtil/DiffL1NormConvergenceCriterion.java
index cc7b06a..3224a6f 100644
--- a/stratosphere-tests/src/test/java/eu/stratosphere/test/recordJobs/graph/pageRankUtil/DiffL1NormConvergenceCriterion.java
+++ b/stratosphere-tests/src/test/java/eu/stratosphere/test/recordJobs/graph/pageRankUtil/DiffL1NormConvergenceCriterion.java
@@ -18,6 +18,7 @@ import org.apache.commons.logging.LogFactory;
import eu.stratosphere.api.common.aggregators.ConvergenceCriterion;
+@SuppressWarnings("serial")
public class DiffL1NormConvergenceCriterion implements ConvergenceCriterion<PageRankStats> {
private static final double EPSILON = 0.00005;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/08f189ad/stratosphere-tests/src/test/java/eu/stratosphere/test/recordJobs/graph/pageRankUtil/PageRankStatsAggregator.java
----------------------------------------------------------------------
diff --git a/stratosphere-tests/src/test/java/eu/stratosphere/test/recordJobs/graph/pageRankUtil/PageRankStatsAggregator.java b/stratosphere-tests/src/test/java/eu/stratosphere/test/recordJobs/graph/pageRankUtil/PageRankStatsAggregator.java
index 569f2e0..7e3bc7e 100644
--- a/stratosphere-tests/src/test/java/eu/stratosphere/test/recordJobs/graph/pageRankUtil/PageRankStatsAggregator.java
+++ b/stratosphere-tests/src/test/java/eu/stratosphere/test/recordJobs/graph/pageRankUtil/PageRankStatsAggregator.java
@@ -15,6 +15,7 @@ package eu.stratosphere.test.recordJobs.graph.pageRankUtil;
import eu.stratosphere.api.common.aggregators.Aggregator;
+@SuppressWarnings("serial")
public class PageRankStatsAggregator implements Aggregator<PageRankStats> {
private double diff = 0;