You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nemo.apache.org by jo...@apache.org on 2019/01/22 04:25:12 UTC
[incubator-nemo] branch reshaping updated: aggr done
This is an automated email from the ASF dual-hosted git repository.
johnyangk pushed a commit to branch reshaping
in repository https://gitbox.apache.org/repos/asf/incubator-nemo.git
The following commit(s) were added to refs/heads/reshaping by this push:
new 35bf7e9 aggr done
35bf7e9 is described below
commit 35bf7e9894196a24c296c27e074fbb5d9377de51
Author: John Yang <jo...@apache.org>
AuthorDate: Tue Jan 22 13:24:59 2019 +0900
aggr done
---
.../main/java/org/apache/nemo/common/ir/IRDAG.java | 20 +++++++++++-----
.../ir/vertex/system/MessageAggregationVertex.java | 28 +++-------------------
.../transform/MessageAggregateTransform.java | 12 ++++++----
.../compiletime/reshaping/SkewReshapingPass.java | 28 ++++++++++++++++++----
4 files changed, 47 insertions(+), 41 deletions(-)
diff --git a/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java b/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
index e12b7d3..c5de9cb 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
@@ -25,6 +25,7 @@ import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.*;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.OperatorVertex;
+import org.apache.nemo.common.ir.vertex.system.MessageAggregationVertex;
import org.apache.nemo.common.ir.vertex.system.MessageBarrierVertex;
import org.apache.nemo.common.ir.vertex.system.StreamVertex;
import org.apache.nemo.common.ir.vertex.system.SystemIRVertex;
@@ -71,6 +72,8 @@ public class IRDAG {
////////////////////////////////////////////////// Reshaping methods.
/**
+ * Inserts a new vertex that streams data.
+ *
* Before: src > edgeToStreamize > dst
* After: src > edgeToStreamizeWithNewDestination > streamVertex > oneToOneEdge > dst
* (replaces the "Before" relationships)
@@ -120,14 +123,19 @@ public class IRDAG {
}
/**
+ * Inserts a new vertex that analyzes intermediate data, and triggers a dynamic optimization.
+ *
* Before: src > edgeToGetStatisticsOf > dst
- * After: src > oneToOneEdge(a clone of edgeToGetStatisticsOf) > messageBarrierVertex
+ * After: src > oneToOneEdge(a clone of edgeToGetStatisticsOf) > messageBarrierVertex > messageAggregationVertex > dst
* (the "Before" relationships are unmodified)
*
* @param messageBarrierVertex to insert.
+ * @param messageAggregationVertex to insert.
* @param edgeToGetStatisticsOf to clone and examine.
*/
- public void insert(final MessageBarrierVertex messageBarrierVertex, final IREdge edgeToGetStatisticsOf) {
+ public void insert(final MessageBarrierVertex messageBarrierVertex,
+ final MessageAggregationVertex messageAggregationVertex,
+ final IREdge edgeToGetStatisticsOf) {
// Create a completely new DAG with the vertex inserted.
final DAGBuilder builder = new DAGBuilder();
@@ -141,9 +149,8 @@ public class IRDAG {
for (final IREdge edge : dag.getIncomingEdgesOf(v)) {
if (edge.equals(edgeToGetStatisticsOf)) {
// MATCH!
- final OperatorVertex abv = generateMetricAggregationVertex();
builder.addVertex(messageBarrierVertex);
- builder.addVertex(abv);
+ builder.addVertex(messageAggregationVertex);
// Clone the edgeToGetStatisticsOf
final IREdge clone = new IREdge(
@@ -153,12 +160,13 @@ public class IRDAG {
builder.connectVertices(clone);
// messageBarrierVertex to the messageAggregationVertex
- final IREdge edgeToABV = generateEdgeToABV(edge, messageBarrierVertex, abv);
+ final IREdge edgeToABV = generateEdgeToABV(edge, messageBarrierVertex, messageAggregationVertex);
builder.connectVertices(edgeToABV);
// Connection vertex
// Add a control dependency (no output)
- final IREdge emptyEdge = new IREdge(CommunicationPatternProperty.Value.BroadCast, abv, v);
+ final IREdge emptyEdge =
+ new IREdge(CommunicationPatternProperty.Value.BroadCast, messageAggregationVertex, v);
builder.connectVertices(emptyEdge);
// The original edge
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/system/MessageAggregationVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/system/MessageAggregationVertex.java
index 6b1ff1a..f4c634a 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/system/MessageAggregationVertex.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/system/MessageAggregationVertex.java
@@ -19,34 +19,12 @@
package org.apache.nemo.common.ir.vertex.system;
import org.apache.nemo.common.Pair;
-import org.apache.nemo.common.ir.vertex.OperatorVertex;
import org.apache.nemo.common.ir.vertex.transform.MessageAggregateTransform;
-import java.io.Serializable;
-import java.util.HashMap;
-import java.util.Map;
import java.util.function.BiFunction;
-public class MessageAggregationVertex extends SystemIRVertex {
- public MessageAggregationVertex() {
- // Define a custom data aggregator for skew handling.
- // Here, the aggregator gathers key frequency data used in shuffle data repartitioning.
- final BiFunction<Object, Map<Object, Long>, Map<Object, Long>> dynOptDataAggregator =
- (BiFunction<Object, Map<Object, Long>, Map<Object, Long>> & Serializable)
- (element, aggregatedDynOptData) -> {
- final Object key = ((Pair<Object, Long>) element).left();
- final Long count = ((Pair<Object, Long>) element).right();
-
- final Map<Object, Long> aggregatedDynOptDataMap = (Map<Object, Long>) aggregatedDynOptData;
- if (aggregatedDynOptDataMap.containsKey(key)) {
- aggregatedDynOptDataMap.compute(key, (existingKey, accumulatedCount) -> accumulatedCount + count);
- } else {
- aggregatedDynOptDataMap.put(key, count);
- }
- return aggregatedDynOptData;
- };
- final MessageAggregateTransform abt =
- new MessageAggregateTransform<Pair<Object, Long>, Map<Object, Long>>(new HashMap<>(), dynOptDataAggregator);
- return new OperatorVertex(abt);
+public class MessageAggregationVertex<K, V, O> extends SystemIRVertex {
+ public MessageAggregationVertex(final O initialState, final BiFunction<Pair<K, V>, O, O> userFunction) {
+ super(new MessageAggregateTransform<>(initialState, userFunction));
}
}
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageAggregateTransform.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageAggregateTransform.java
index 97e6613..34af7d5 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageAggregateTransform.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageAggregateTransform.java
@@ -18,6 +18,7 @@
*/
package org.apache.nemo.common.ir.vertex.transform;
+import org.apache.nemo.common.Pair;
import org.apache.nemo.common.ir.OutputCollector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -27,14 +28,15 @@ import java.util.function.BiFunction;
/**
* A {@link Transform} that aggregates statistics generated by the {@link MessageBarrierTransform}.
*
- * @param <I> input type.
+ * @param <K> input key type.
+ * @param <V> input value type.
* @param <O> output type.
*/
-public final class MessageAggregateTransform<I, O> extends NoWatermarkEmitTransform<I, O> {
+public final class MessageAggregateTransform<K, V, O> extends NoWatermarkEmitTransform<Pair<K, V>, O> {
private static final Logger LOG = LoggerFactory.getLogger(MessageAggregateTransform.class.getName());
private OutputCollector<O> outputCollector;
private O aggregatedDynOptData;
- private final BiFunction<Object, O, O> dynOptDataAggregator;
+ private final BiFunction<Pair<K, V>, O, O> dynOptDataAggregator;
/**
* Default constructor.
@@ -42,7 +44,7 @@ public final class MessageAggregateTransform<I, O> extends NoWatermarkEmitTransf
* @param dynOptDataAggregator aggregator to use.
*/
public MessageAggregateTransform(final O aggregatedDynOptData,
- final BiFunction<Object, O, O> dynOptDataAggregator) {
+ final BiFunction<Pair<K, V>, O, O> dynOptDataAggregator) {
this.aggregatedDynOptData = aggregatedDynOptData;
this.dynOptDataAggregator = dynOptDataAggregator;
}
@@ -53,7 +55,7 @@ public final class MessageAggregateTransform<I, O> extends NoWatermarkEmitTransf
}
@Override
- public void onData(final I element) {
+ public void onData(final Pair<K, V> element) {
aggregatedDynOptData = dynOptDataAggregator.apply(element, aggregatedDynOptData);
}
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
index fe6f9e9..680b725 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
@@ -19,9 +19,11 @@
package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;
import org.apache.nemo.common.KeyExtractor;
+import org.apache.nemo.common.Pair;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.*;
+import org.apache.nemo.common.ir.vertex.system.MessageAggregationVertex;
import org.apache.nemo.common.ir.vertex.system.MessageBarrierVertex;
import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;
import org.apache.nemo.compiler.optimizer.pass.compiletime.annotating.Annotates;
@@ -29,6 +31,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.Serializable;
+import java.util.HashMap;
import java.util.Map;
import java.util.function.BiFunction;
@@ -59,8 +62,7 @@ public final class SkewReshapingPass extends ReshapingPass {
// Get the key extractor
final KeyExtractor keyExtractor = edge.getPropertyValue(KeyExtractorProperty.class).get();
- // Define a custom data collector for skew handling.
- // Here, the collector gathers key frequency data used in shuffle data repartitioning.
+ // For collecting the data
final BiFunction<Object, Map<Object, Object>, Map<Object, Object>> dynOptDataCollector =
(BiFunction<Object, Map<Object, Object>, Map<Object, Object>> & Serializable)
(element, dynOptData) -> {
@@ -73,10 +75,26 @@ public final class SkewReshapingPass extends ReshapingPass {
return dynOptData;
};
- final MessageBarrierVertex mbv = new MessageBarrierVertex<>(dynOptDataCollector);
+ // For aggregating the collected data
+ final BiFunction<Object, Map<Object, Long>, Map<Object, Long>> dynOptDataAggregator =
+ (BiFunction<Object, Map<Object, Long>, Map<Object, Long>> & Serializable)
+ (element, aggregatedDynOptData) -> {
+ final Object key = ((Pair<Object, Long>) element).left();
+ final Long count = ((Pair<Object, Long>) element).right();
+
+ final Map<Object, Long> aggregatedDynOptDataMap = (Map<Object, Long>) aggregatedDynOptData;
+ if (aggregatedDynOptDataMap.containsKey(key)) {
+ aggregatedDynOptDataMap.compute(key, (existingKey, accumulatedCount) -> accumulatedCount + count);
+ } else {
+ aggregatedDynOptDataMap.put(key, count);
+ }
+ return aggregatedDynOptData;
+ };
- // Insert the vertex
- dag.insert(mbv, edge);
+ // Insert the vertices
+ final MessageBarrierVertex mbv = new MessageBarrierVertex<>(dynOptDataCollector);
+ final MessageAggregationVertex mav = new MessageAggregationVertex(new HashMap(), dynOptDataAggregator);
+ dag.insert(mbv, mav, edge);
}
}
});