You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nemo.apache.org by jo...@apache.org on 2019/01/22 04:25:12 UTC

[incubator-nemo] branch reshaping updated: aggr done

This is an automated email from the ASF dual-hosted git repository.

johnyangk pushed a commit to branch reshaping
in repository https://gitbox.apache.org/repos/asf/incubator-nemo.git


The following commit(s) were added to refs/heads/reshaping by this push:
     new 35bf7e9  aggr done
35bf7e9 is described below

commit 35bf7e9894196a24c296c27e074fbb5d9377de51
Author: John Yang <jo...@apache.org>
AuthorDate: Tue Jan 22 13:24:59 2019 +0900

    aggr done
---
 .../main/java/org/apache/nemo/common/ir/IRDAG.java | 20 +++++++++++-----
 .../ir/vertex/system/MessageAggregationVertex.java | 28 +++-------------------
 .../transform/MessageAggregateTransform.java       | 12 ++++++----
 .../compiletime/reshaping/SkewReshapingPass.java   | 28 ++++++++++++++++++----
 4 files changed, 47 insertions(+), 41 deletions(-)

diff --git a/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java b/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
index e12b7d3..c5de9cb 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
@@ -25,6 +25,7 @@ import org.apache.nemo.common.ir.edge.IREdge;
 import org.apache.nemo.common.ir.edge.executionproperty.*;
 import org.apache.nemo.common.ir.vertex.IRVertex;
 import org.apache.nemo.common.ir.vertex.OperatorVertex;
+import org.apache.nemo.common.ir.vertex.system.MessageAggregationVertex;
 import org.apache.nemo.common.ir.vertex.system.MessageBarrierVertex;
 import org.apache.nemo.common.ir.vertex.system.StreamVertex;
 import org.apache.nemo.common.ir.vertex.system.SystemIRVertex;
@@ -71,6 +72,8 @@ public class IRDAG {
   ////////////////////////////////////////////////// Reshaping methods.
 
   /**
+   * Inserts a new vertex that streams data.
+   *
    * Before: src > edgeToStreamize > dst
    * After: src > edgeToStreamizeWithNewDestination > streamVertex > oneToOneEdge > dst
    * (replaces the "Before" relationships)
@@ -120,14 +123,19 @@ public class IRDAG {
   }
 
   /**
+   * Inserts a new vertex that analyzes intermediate data, and triggers a dynamic optimization.
+   *
    * Before: src > edgeToGetStatisticsOf > dst
-   * After: src > oneToOneEdge(a clone of edgeToGetStatisticsOf) > messageBarrierVertex
+   * After: src > oneToOneEdge(a clone of edgeToGetStatisticsOf) > messageBarrierVertex > messageAggregationVertex > dst
    * (the "Before" relationships are unmodified)
    *
    * @param messageBarrierVertex to insert.
+   * @param messageAggregationVertex to insert.
    * @param edgeToGetStatisticsOf to clone and examine.
    */
-  public void insert(final MessageBarrierVertex messageBarrierVertex, final IREdge edgeToGetStatisticsOf) {
+  public void insert(final MessageBarrierVertex messageBarrierVertex,
+                     final MessageAggregationVertex messageAggregationVertex,
+                     final IREdge edgeToGetStatisticsOf) {
     // Create a completely new DAG with the vertex inserted.
     final DAGBuilder builder = new DAGBuilder();
 
@@ -141,9 +149,8 @@ public class IRDAG {
       for (final IREdge edge : dag.getIncomingEdgesOf(v)) {
         if (edge.equals(edgeToGetStatisticsOf)) {
           // MATCH!
-          final OperatorVertex abv = generateMetricAggregationVertex();
           builder.addVertex(messageBarrierVertex);
-          builder.addVertex(abv);
+          builder.addVertex(messageAggregationVertex);
 
           // Clone the edgeToGetStatisticsOf
           final IREdge clone = new IREdge(
@@ -153,12 +160,13 @@ public class IRDAG {
           builder.connectVertices(clone);
 
           // messageBarrierVertex to the messageAggregationVertex
-          final IREdge edgeToABV = generateEdgeToABV(edge, messageBarrierVertex, abv);
+          final IREdge edgeToABV = generateEdgeToABV(edge, messageBarrierVertex, messageAggregationVertex);
           builder.connectVertices(edgeToABV);
 
           // Connection vertex
           // Add an control dependency (no output)
-          final IREdge emptyEdge = new IREdge(CommunicationPatternProperty.Value.BroadCast, abv, v);
+          final IREdge emptyEdge =
+            new IREdge(CommunicationPatternProperty.Value.BroadCast, messageAggregationVertex, v);
           builder.connectVertices(emptyEdge);
 
           // The original edge
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/system/MessageAggregationVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/system/MessageAggregationVertex.java
index 6b1ff1a..f4c634a 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/system/MessageAggregationVertex.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/system/MessageAggregationVertex.java
@@ -19,34 +19,12 @@
 package org.apache.nemo.common.ir.vertex.system;
 
 import org.apache.nemo.common.Pair;
-import org.apache.nemo.common.ir.vertex.OperatorVertex;
 import org.apache.nemo.common.ir.vertex.transform.MessageAggregateTransform;
 
-import java.io.Serializable;
-import java.util.HashMap;
-import java.util.Map;
 import java.util.function.BiFunction;
 
-public class MessageAggregationVertex extends SystemIRVertex {
-  public MessageAggregationVertex() {
-    // Define a custom data aggregator for skew handling.
-    // Here, the aggregator gathers key frequency data used in shuffle data repartitioning.
-    final BiFunction<Object, Map<Object, Long>, Map<Object, Long>> dynOptDataAggregator =
-      (BiFunction<Object, Map<Object, Long>, Map<Object, Long>> & Serializable)
-        (element, aggregatedDynOptData) -> {
-          final Object key = ((Pair<Object, Long>) element).left();
-          final Long count = ((Pair<Object, Long>) element).right();
-
-          final Map<Object, Long> aggregatedDynOptDataMap = (Map<Object, Long>) aggregatedDynOptData;
-          if (aggregatedDynOptDataMap.containsKey(key)) {
-            aggregatedDynOptDataMap.compute(key, (existingKey, accumulatedCount) -> accumulatedCount + count);
-          } else {
-            aggregatedDynOptDataMap.put(key, count);
-          }
-          return aggregatedDynOptData;
-        };
-    final MessageAggregateTransform abt =
-      new MessageAggregateTransform<Pair<Object, Long>, Map<Object, Long>>(new HashMap<>(), dynOptDataAggregator);
-    return new OperatorVertex(abt);
+public class MessageAggregationVertex<K, V, O> extends SystemIRVertex {
+  public MessageAggregationVertex(final O initialState, final BiFunction<Pair<K, V>, O, O> userFunction) {
+    super(new MessageAggregateTransform<>(initialState, userFunction));
   }
 }
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageAggregateTransform.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageAggregateTransform.java
index 97e6613..34af7d5 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageAggregateTransform.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageAggregateTransform.java
@@ -18,6 +18,7 @@
  */
 package org.apache.nemo.common.ir.vertex.transform;
 
+import org.apache.nemo.common.Pair;
 import org.apache.nemo.common.ir.OutputCollector;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -27,14 +28,15 @@ import java.util.function.BiFunction;
 /**
  * A {@link Transform} that aggregates statistics generated by the {@link MessageBarrierTransform}.
  *
- * @param <I> input type.
+ * @param <K> input key type.
+ * @param <V> input value type.
  * @param <O> output type.
  */
-public final class MessageAggregateTransform<I, O> extends NoWatermarkEmitTransform<I, O> {
+public final class MessageAggregateTransform<K, V, O> extends NoWatermarkEmitTransform<Pair<K, V>, O> {
   private static final Logger LOG = LoggerFactory.getLogger(MessageAggregateTransform.class.getName());
   private OutputCollector<O> outputCollector;
   private O aggregatedDynOptData;
-  private final BiFunction<Object, O, O> dynOptDataAggregator;
+  private final BiFunction<Pair<K, V>, O, O> dynOptDataAggregator;
 
   /**
    * Default constructor.
@@ -42,7 +44,7 @@ public final class MessageAggregateTransform<I, O> extends NoWatermarkEmitTransf
    * @param dynOptDataAggregator aggregator to use.
    */
   public MessageAggregateTransform(final O aggregatedDynOptData,
-                                   final BiFunction<Object, O, O> dynOptDataAggregator) {
+                                   final BiFunction<Pair<K, V>, O, O> dynOptDataAggregator) {
     this.aggregatedDynOptData = aggregatedDynOptData;
     this.dynOptDataAggregator = dynOptDataAggregator;
   }
@@ -53,7 +55,7 @@ public final class MessageAggregateTransform<I, O> extends NoWatermarkEmitTransf
   }
 
   @Override
-  public void onData(final I element) {
+  public void onData(final Pair<K, V> element) {
     aggregatedDynOptData = dynOptDataAggregator.apply(element, aggregatedDynOptData);
   }
 
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
index fe6f9e9..680b725 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
@@ -19,9 +19,11 @@
 package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;
 
 import org.apache.nemo.common.KeyExtractor;
+import org.apache.nemo.common.Pair;
 import org.apache.nemo.common.ir.IRDAG;
 import org.apache.nemo.common.ir.edge.IREdge;
 import org.apache.nemo.common.ir.edge.executionproperty.*;
+import org.apache.nemo.common.ir.vertex.system.MessageAggregationVertex;
 import org.apache.nemo.common.ir.vertex.system.MessageBarrierVertex;
 import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;
 import org.apache.nemo.compiler.optimizer.pass.compiletime.annotating.Annotates;
@@ -29,6 +31,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.Serializable;
+import java.util.HashMap;
 import java.util.Map;
 import java.util.function.BiFunction;
 
@@ -59,8 +62,7 @@ public final class SkewReshapingPass extends ReshapingPass {
           // Get the key extractor
           final KeyExtractor keyExtractor = edge.getPropertyValue(KeyExtractorProperty.class).get();
 
-          // Define a custom data collector for skew handling.
-          // Here, the collector gathers key frequency data used in shuffle data repartitioning.
+          // For collecting the data: gathers key frequency data used in shuffle data repartitioning.
           final BiFunction<Object, Map<Object, Object>, Map<Object, Object>> dynOptDataCollector =
             (BiFunction<Object, Map<Object, Object>, Map<Object, Object>> & Serializable)
               (element, dynOptData) -> {
@@ -73,10 +75,26 @@ public final class SkewReshapingPass extends ReshapingPass {
                 return dynOptData;
               };
 
-          final MessageBarrierVertex mbv = new MessageBarrierVertex<>(dynOptDataCollector);
+          // For aggregating the collected data: sums the per-key counts into a single map.
+          final BiFunction<Object, Map<Object, Long>, Map<Object, Long>> dynOptDataAggregator =
+            (BiFunction<Object, Map<Object, Long>, Map<Object, Long>> & Serializable)
+              (element, aggregatedDynOptData) -> {
+                final Object key = ((Pair<Object, Long>) element).left();
+                final Long count = ((Pair<Object, Long>) element).right();
+
+                final Map<Object, Long> aggregatedDynOptDataMap = (Map<Object, Long>) aggregatedDynOptData;
+                if (aggregatedDynOptDataMap.containsKey(key)) {
+                  aggregatedDynOptDataMap.compute(key, (existingKey, accumulatedCount) -> accumulatedCount + count);
+                } else {
+                  aggregatedDynOptDataMap.put(key, count);
+                }
+                return aggregatedDynOptData;
+              };
 
-          // Insert the vertex
-          dag.insert(mbv, edge);
+          // Insert the vertices
+          final MessageBarrierVertex mbv = new MessageBarrierVertex<>(dynOptDataCollector);
+          final MessageAggregationVertex mav = new MessageAggregationVertex(new HashMap(), dynOptDataAggregator);
+          dag.insert(mbv, mav, edge);
         }
       }
     });