You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nemo.apache.org by jo...@apache.org on 2019/01/22 08:43:32 UTC

[incubator-nemo] branch reshaping updated: skew debugged

This is an automated email from the ASF dual-hosted git repository.

johnyangk pushed a commit to branch reshaping
in repository https://gitbox.apache.org/repos/asf/incubator-nemo.git


The following commit(s) were added to refs/heads/reshaping by this push:
     new 409fc67  skew debugged
409fc67 is described below

commit 409fc67cbe6934a243cc0f9e8a75a00387aa67e1
Author: John Yang <jo...@apache.org>
AuthorDate: Tue Jan 22 17:43:21 2019 +0900

    skew debugged
---
 .../nemo/common/coder/PairDecoderFactory.java      | 11 +++++
 .../nemo/common/coder/PairEncoderFactory.java      | 11 +++++
 .../main/java/org/apache/nemo/common/ir/IRDAG.java | 51 ++++++++++------------
 .../vertex/transform/MessageBarrierTransform.java  |  2 +
 .../nemo/compiler/optimizer/PairKeyExtractor.java  | 36 ---------------
 .../compiletime/reshaping/SkewReshapingPass.java   | 32 +++++++++-----
 6 files changed, 67 insertions(+), 76 deletions(-)

diff --git a/common/src/main/java/org/apache/nemo/common/coder/PairDecoderFactory.java b/common/src/main/java/org/apache/nemo/common/coder/PairDecoderFactory.java
index 519a2e2..790690d 100644
--- a/common/src/main/java/org/apache/nemo/common/coder/PairDecoderFactory.java
+++ b/common/src/main/java/org/apache/nemo/common/coder/PairDecoderFactory.java
@@ -63,6 +63,17 @@ public final class PairDecoderFactory<A, B> implements DecoderFactory<Pair<A, B>
     return new PairDecoder<>(inputStream, leftDecoderFactory, rightDecoderFactory);
   }
 
+  @Override
+  public String toString() {
+    final StringBuilder sb = new StringBuilder();
+    sb.append("Pair(");
+    sb.append(leftDecoderFactory.toString());
+    sb.append(", ");
+    sb.append(rightDecoderFactory.toString());
+    sb.append(")");
+    return sb.toString();
+  }
+
   /**
    * PairDecoder.
    * @param <T1> type for the left coder.
diff --git a/common/src/main/java/org/apache/nemo/common/coder/PairEncoderFactory.java b/common/src/main/java/org/apache/nemo/common/coder/PairEncoderFactory.java
index 12edfcc..030c336 100644
--- a/common/src/main/java/org/apache/nemo/common/coder/PairEncoderFactory.java
+++ b/common/src/main/java/org/apache/nemo/common/coder/PairEncoderFactory.java
@@ -62,6 +62,17 @@ public final class PairEncoderFactory<A, B> implements EncoderFactory<Pair<A, B>
     return new PairEncoder<>(outputStream, leftEncoderFactory, rightEncoderFactory);
   }
 
+  @Override
+  public String toString() {
+    final StringBuilder sb = new StringBuilder();
+    sb.append("Pair(");
+    sb.append(leftEncoderFactory.toString());
+    sb.append(", ");
+    sb.append(rightEncoderFactory.toString());
+    sb.append(")");
+    return sb.toString();
+  }
+
   /**
    * PairEncoder.
    * @param <T1> type for the left coder.
diff --git a/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java b/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
index 23b038b..1972e9c 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
@@ -20,16 +20,16 @@ package org.apache.nemo.common.ir;
 
 import org.apache.nemo.common.KeyExtractor;
 import org.apache.nemo.common.Pair;
-import org.apache.nemo.common.coder.*;
 import org.apache.nemo.common.dag.DAG;
 import org.apache.nemo.common.dag.DAGBuilder;
 import org.apache.nemo.common.ir.edge.IREdge;
 import org.apache.nemo.common.ir.edge.executionproperty.*;
 import org.apache.nemo.common.ir.vertex.IRVertex;
-import org.apache.nemo.common.ir.vertex.OperatorVertex;
 import org.apache.nemo.common.ir.vertex.system.MessageAggregationVertex;
 import org.apache.nemo.common.ir.vertex.system.MessageBarrierVertex;
 import org.apache.nemo.common.ir.vertex.system.StreamVertex;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.util.List;
 import java.util.function.Consumer;
@@ -47,6 +47,8 @@ import java.util.function.Predicate;
  * - Reshaping: insert(), delete() on the IRDAG
  */
 public final class IRDAG {
+  private static final Logger LOG = LoggerFactory.getLogger(IRDAG.class.getName());
+
   private DAG<IRVertex, IREdge> dag; // internal DAG, can be updated by reshaping methods.
 
   public IRDAG(final DAG<IRVertex, IREdge> dag) {
@@ -177,6 +179,8 @@ public final class IRDAG {
    */
   public void insert(final MessageBarrierVertex messageBarrierVertex,
                      final MessageAggregationVertex messageAggregationVertex,
+                     final EncoderProperty mbvOutputEncoder,
+                     final DecoderProperty mbvOutputDecoder,
                      final IREdge edgeToGetStatisticsOf) {
     // Create a completely new DAG with the vertex inserted.
     final DAGBuilder builder = new DAGBuilder();
@@ -199,14 +203,18 @@ public final class IRDAG {
             CommunicationPatternProperty.Value.OneToOne, edge.getSrc(), messageBarrierVertex);
           clone.setProperty(EncoderProperty.of(edge.getPropertyValue(EncoderProperty.class).get()));
           clone.setProperty(DecoderProperty.of(edge.getPropertyValue(DecoderProperty.class).get()));
+          edge.getPropertyValue(AdditionalOutputTagProperty.class).ifPresent(tag -> {
+            clone.setProperty(AdditionalOutputTagProperty.of(tag));
+          });
           builder.connectVertices(clone);
 
           // messageBarrierVertex to the messageAggregationVertex
-          final IREdge edgeToABV = edgeBetweenMessageVertices(edge, messageBarrierVertex, messageAggregationVertex);
+          final IREdge edgeToABV = edgeBetweenMessageVertices(
+            messageBarrierVertex, messageAggregationVertex, mbvOutputEncoder, mbvOutputDecoder);
           builder.connectVertices(edgeToABV);
 
           // Connection vertex
-          // Add an control dependency (no output)
+          // Add a control dependency (no output)
           final IREdge emptyEdge =
             new IREdge(CommunicationPatternProperty.Value.BroadCast, messageAggregationVertex, v);
           builder.connectVertices(emptyEdge);
@@ -226,6 +234,8 @@ public final class IRDAG {
         }
       }
     });
+
+    dag = builder.build(); // update the DAG.
   }
 
   ////////////////////////////////////////////////// "Un-safe" direct reshaping (semantic-preserving is not guaranteed).
@@ -237,15 +247,15 @@ public final class IRDAG {
   ////////////////////////////////////////////////// Private helper methods.
 
   /**
-   * @param edge the original shuffle edge.
-   * @param mcv the vertex with MessageBarrierTransform.
-   * @param abv the vertex with MessageAggregateTransform.
+   * @param mbv the vertex with MessageBarrierTransform.
+   * @param mav the vertex with MessageAggregateTransform.
    * @return the generated egde from {@code mcv} to {@code abv}.
    */
-  private IREdge edgeBetweenMessageVertices(final IREdge edge,
-                                            final OperatorVertex mcv,
-                                            final OperatorVertex abv) {
-    final IREdge newEdge = new IREdge(CommunicationPatternProperty.Value.Shuffle, mcv, abv);
+  private IREdge edgeBetweenMessageVertices(final MessageBarrierVertex mbv,
+                                            final MessageAggregationVertex mav,
+                                            final EncoderProperty encoder,
+                                            final DecoderProperty decoder) {
+    final IREdge newEdge = new IREdge(CommunicationPatternProperty.Value.Shuffle, mbv, mav);
     newEdge.setProperty(DataStoreProperty.of(DataStoreProperty.Value.LocalFileStore));
     newEdge.setProperty(DataPersistenceProperty.of(DataPersistenceProperty.Value.Keep));
     newEdge.setProperty(DataFlowProperty.of(DataFlowProperty.Value.Push));
@@ -257,23 +267,8 @@ public final class IRDAG {
       }
     };
     newEdge.setProperty(KeyExtractorProperty.of(pairKeyExtractor));
-
-    // Dynamic optimization handles statistics on key-value data by default.
-    // We need to get coders for encoding/decoding the keys to send data to
-    // vertex with MessageAggregateTransform.
-    if (edge.getPropertyValue(KeyEncoderProperty.class).isPresent()
-      && edge.getPropertyValue(KeyDecoderProperty.class).isPresent()) {
-      final EncoderFactory keyEncoderFactory = edge.getPropertyValue(KeyEncoderProperty.class).get();
-      final DecoderFactory keyDecoderFactory = edge.getPropertyValue(KeyDecoderProperty.class).get();
-      newEdge.setPropertyPermanently(
-        EncoderProperty.of(PairEncoderFactory.of(keyEncoderFactory, LongEncoderFactory.of())));
-      newEdge.setPropertyPermanently(
-        DecoderProperty.of(PairDecoderFactory.of(keyDecoderFactory, LongDecoderFactory.of())));
-    } else {
-      // If not specified, follow encoder/decoder of the given shuffle edge.
-      throw new RuntimeException("Skew optimization request for none key - value format data!");
-    }
-
+    newEdge.setPropertyPermanently(encoder);
+    newEdge.setPropertyPermanently(decoder);
     return newEdge;
   }
 }
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageBarrierTransform.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageBarrierTransform.java
index 9036930..4c9a007 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageBarrierTransform.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/transform/MessageBarrierTransform.java
@@ -23,6 +23,7 @@ import org.apache.nemo.common.ir.OutputCollector;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.HashMap;
 import java.util.Map;
 import java.util.function.BiFunction;
 
@@ -50,6 +51,7 @@ public final class MessageBarrierTransform<I, K, V> extends NoWatermarkEmitTrans
   @Override
   public void prepare(final Context context, final OutputCollector<Pair<K, V>> oc) {
     this.outputCollector = oc;
+    this.holder = new HashMap<>();
   }
 
   @Override
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/PairKeyExtractor.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/PairKeyExtractor.java
deleted file mode 100644
index e7a4101..0000000
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/PairKeyExtractor.java
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.nemo.compiler.optimizer;
-
-import org.apache.nemo.common.KeyExtractor;
-import org.apache.nemo.common.Pair;
-
-/**
- * Extracts the key from a pair element.
- */
-public final class PairKeyExtractor implements KeyExtractor {
-  @Override
-  public Object extractKey(final Object element) {
-    if (element instanceof Pair) {
-      return ((Pair) element).left();
-    } else {
-      throw new IllegalStateException(element.toString());
-    }
-  }
-}
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
index 55edfeb..eaa8045 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
@@ -20,6 +20,7 @@ package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;
 
 import org.apache.nemo.common.KeyExtractor;
 import org.apache.nemo.common.Pair;
+import org.apache.nemo.common.coder.*;
 import org.apache.nemo.common.ir.IRDAG;
 import org.apache.nemo.common.ir.edge.IREdge;
 import org.apache.nemo.common.ir.edge.executionproperty.*;
@@ -59,12 +60,14 @@ public final class SkewReshapingPass extends ReshapingPass {
       for (final IREdge edge : dag.getIncomingEdgesOf(v)) {
         if (CommunicationPatternProperty.Value.Shuffle
           .equals(edge.getPropertyValue(CommunicationPatternProperty.class).get())) {
+          // Shuffle edge has the KeyExtractor, KeyEncoder, and KeyDecoder
+
           // Get the key extractor
           final KeyExtractor keyExtractor = edge.getPropertyValue(KeyExtractorProperty.class).get();
 
           // For collecting the data
-          final BiFunction<Object, Map<Object, Object>, Map<Object, Object>> dynOptDataCollector =
-            (BiFunction<Object, Map<Object, Object>, Map<Object, Object>> & Serializable)
+          final BiFunction<Object, Map<Object, Long>, Map<Object, Long>> dynOptDataCollector =
+            (BiFunction<Object, Map<Object, Long>, Map<Object, Long>> & Serializable)
               (element, dynOptData) -> {
                 Object key = keyExtractor.extractKey(element);
                 if (dynOptData.containsKey(key)) {
@@ -76,28 +79,33 @@ public final class SkewReshapingPass extends ReshapingPass {
               };
 
           // For aggregating the collected data
-          final BiFunction<Object, Map<Object, Long>, Map<Object, Long>> dynOptDataAggregator =
-            (BiFunction<Object, Map<Object, Long>, Map<Object, Long>> & Serializable)
+          final BiFunction<Pair<Object, Long>, Map<Object, Long>, Map<Object, Long>> dynOptDataAggregator =
+            (BiFunction<Pair<Object, Long>, Map<Object, Long>, Map<Object, Long>> & Serializable)
               (element, aggregatedDynOptData) -> {
-                final Object key = ((Pair<Object, Long>) element).left();
-                final Long count = ((Pair<Object, Long>) element).right();
-
-                final Map<Object, Long> aggregatedDynOptDataMap = (Map<Object, Long>) aggregatedDynOptData;
-                if (aggregatedDynOptDataMap.containsKey(key)) {
-                  aggregatedDynOptDataMap.compute(key, (existingKey, accumulatedCount) -> accumulatedCount + count);
+                final Object key = element.left();
+                final Long count = element.right();
+                if (aggregatedDynOptData.containsKey(key)) {
+                  aggregatedDynOptData.compute(key, (existingKey, accumulatedCount) -> accumulatedCount + count);
                 } else {
-                  aggregatedDynOptDataMap.put(key, count);
+                  aggregatedDynOptData.put(key, count);
                 }
                 return aggregatedDynOptData;
               };
 
+          // Coders to use
+          final EncoderProperty encoderProperty = EncoderProperty.of(
+            PairEncoderFactory.of(edge.getPropertyValue(KeyEncoderProperty.class).get(), LongEncoderFactory.of()));
+          final DecoderProperty decoderProperty = DecoderProperty.of(
+            PairDecoderFactory.of(edge.getPropertyValue(KeyDecoderProperty.class).get(), LongDecoderFactory.of()));
+
           // Insert the vertices
           final MessageBarrierVertex mbv = new MessageBarrierVertex<>(dynOptDataCollector);
           final MessageAggregationVertex mav = new MessageAggregationVertex(new HashMap(), dynOptDataAggregator);
-          dag.insert(mbv, mav, edge);
+          dag.insert(mbv, mav, encoderProperty, decoderProperty, edge);
         }
       }
     });
     return dag;
   }
 }
+