You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nemo.apache.org by sa...@apache.org on 2019/02/21 07:38:58 UTC
[incubator-nemo] branch master updated: [NEMO-338] SkewSamplingPass
(#193)
This is an automated email from the ASF dual-hosted git repository.
sanha pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-nemo.git
The following commit(s) were added to refs/heads/master by this push:
new a2b02dc [NEMO-338] SkewSamplingPass (#193)
a2b02dc is described below
commit a2b02dc85b433c68b8feafe63248e05c0724b4f8
Author: John Yang <jo...@gmail.com>
AuthorDate: Thu Feb 21 16:38:54 2019 +0900
[NEMO-338] SkewSamplingPass (#193)
JIRA: [NEMO-338: SkewSamplingPass](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-338)
**Major changes:**
- SamplingSkewReshapingPass: Inserts SkewSampling and MessageBarrier vertices
- SamplingVertex: Instantiated with (originalVertex, desiredSampleRate)
- IRDAG: Automatically inserts IREdges from/to SamplingVertex objects, similar to other insert() methods
- PhysicalPlanGenerator: Handles SamplingVertex objects appropriately
- Stage: Uses getTaskIndices() to determine the tasks to execute; it returns only a subset of the tasks if the stage consists of SamplingVertex objects
**Minor changes to note:**
- Refactors other insert() methods to share code as much as possible
**Tests for the changes:**
- PerKeyMedianITCase#testLargeShuffleSamplingSkew (combines large shuffle + skew handling optimizations)
**Other comments:**
- Sanha (@sanha) wrote the original code. I refactored the code and added comments to create this PR.
Closes #193
---
.../java/org/apache/nemo/client/JobLauncher.java | 2 +-
.../src/main/java/org/apache/nemo/common/Util.java | 104 +++++++-
.../main/java/org/apache/nemo/common/dag/DAG.java | 5 +
.../org/apache/nemo/common/dag/DAGBuilder.java | 27 +-
.../org/apache/nemo/common/dag/DAGInterface.java | 18 +-
.../main/java/org/apache/nemo/common/ir/IRDAG.java | 295 ++++++++++++++++-----
.../org/apache/nemo/common/ir/edge/IREdge.java | 15 +-
...eIdProperty.java => MessageIdEdgeProperty.java} | 10 +-
.../apache/nemo/common/ir/vertex/LoopVertex.java | 2 +-
.../apache/nemo/common/ir/vertex/SourceVertex.java | 2 +-
.../MessageIdVertexProperty.java} | 14 +-
.../ir/vertex/utility/MessageAggregatorVertex.java | 11 +-
.../common/ir/vertex/utility/SamplingVertex.java | 119 +++++++++
.../java/org/apache/nemo/common/util/UtilTest.java | 1 +
.../compiler/backend/nemo/NemoPlanRewriter.java | 6 +-
.../annotating/LargeShuffleAnnotatingPass.java | 18 +-
.../reshaping/SamplingSkewReshapingPass.java | 137 ++++++++++
.../compiletime/reshaping/SkewHandlingUtil.java | 77 ++++++
.../compiletime/reshaping/SkewReshapingPass.java | 49 +---
.../policy/SamplingLargeShuffleSkewPolicy.java | 58 ++++
.../compiler/backend/nemo/DAGConverterTest.java | 4 +-
.../nemo/examples/beam/PerKeyMedianITCase.java | 13 +
.../beam_test_one_executor_resources.json | 2 +-
.../runtime/common/plan/PhysicalPlanGenerator.java | 80 +++++-
.../org/apache/nemo/runtime/common/plan/Stage.java | 18 +-
.../apache/nemo/runtime/common/plan/StageEdge.java | 8 +-
.../common/plan/PhysicalPlanGeneratorTest.java | 5 +-
.../org/apache/nemo/runtime/executor/TestUtil.java | 4 +-
.../executor/datatransfer/DataTransferTest.java | 7 +-
.../nemo/runtime/master/PlanStateManager.java | 46 ++--
.../runtime/master/scheduler/BatchScheduler.java | 32 +--
31 files changed, 950 insertions(+), 239 deletions(-)
diff --git a/client/src/main/java/org/apache/nemo/client/JobLauncher.java b/client/src/main/java/org/apache/nemo/client/JobLauncher.java
index 2a89a5e..ba6194b 100644
--- a/client/src/main/java/org/apache/nemo/client/JobLauncher.java
+++ b/client/src/main/java/org/apache/nemo/client/JobLauncher.java
@@ -181,7 +181,7 @@ public final class JobLauncher {
LOG.info("Wait for the driver to finish");
driverLauncher.wait();
} catch (final InterruptedException e) {
- LOG.warn("Interrupted: " + e);
+ LOG.warn("Interrupted: ", e);
// clean up state...
Thread.currentThread().interrupt();
}
diff --git a/common/src/main/java/org/apache/nemo/common/Util.java b/common/src/main/java/org/apache/nemo/common/Util.java
index 288c5c2..f888ceb 100644
--- a/common/src/main/java/org/apache/nemo/common/Util.java
+++ b/common/src/main/java/org/apache/nemo/common/Util.java
@@ -16,14 +16,22 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.nemo.common.util;
+package org.apache.nemo.common;
+import org.apache.nemo.common.ir.edge.IREdge;
+import org.apache.nemo.common.ir.edge.executionproperty.*;
+import org.apache.nemo.common.ir.vertex.IRVertex;
+
+import java.util.Collection;
import java.util.function.IntPredicate;
+import java.util.stream.Collectors;
/**
* Class to hold the utility methods.
*/
public final class Util {
+ // Assume that this tag is never used in user application
+ public static final String CONTROL_EDGE_TAG = "CONTROL_EDGE";
/**
* Private constructor for utility class.
@@ -42,7 +50,7 @@ public final class Util {
* @return whether or not we can say that they are equal.
*/
public static boolean checkEqualityOfIntPredicates(final IntPredicate firstPredicate,
- final IntPredicate secondPredicate, final int noOfTimes) {
+ final IntPredicate secondPredicate, final int noOfTimes) {
for (int value = 0; value <= noOfTimes; value++) {
if (firstPredicate.test(value) != secondPredicate.test(value)) {
return false;
@@ -51,4 +59,96 @@ public final class Util {
return true;
}
+ /**
+ * @param edgeToClone to copy execution properties from.
+ * @param newSrc of the new edge.
+ * @param newDst of the new edge.
+ * @return the new edge.
+ */
+ public static IREdge cloneEdge(final IREdge edgeToClone,
+ final IRVertex newSrc,
+ final IRVertex newDst) {
+ return cloneEdge(
+ edgeToClone.getPropertyValue(CommunicationPatternProperty.class).get(), edgeToClone, newSrc, newDst);
+ }
+
+ /**
+ * Creates a new edge with several execution properties same as the given edge.
+ * The copied execution properties include those minimally required for execution, such as encoder/decoders.
+ *
+ * @param commPattern to use.
+ * @param edgeToClone to copy execution properties from.
+ * @param newSrc of the new edge.
+ * @param newDst of the new edge.
+ * @return the new edge.
+ */
+ public static IREdge cloneEdge(final CommunicationPatternProperty.Value commPattern,
+ final IREdge edgeToClone,
+ final IRVertex newSrc,
+ final IRVertex newDst) {
+ final IREdge clone = new IREdge(commPattern, newSrc, newDst);
+
+ if (edgeToClone.getPropertySnapshot().containsKey(EncoderProperty.class)) {
+ clone.setProperty(edgeToClone.getPropertySnapshot().get(EncoderProperty.class));
+ } else {
+ clone.setProperty(EncoderProperty.of(edgeToClone.getPropertyValue(EncoderProperty.class)
+ .orElseThrow(IllegalStateException::new)));
+ }
+
+ if (edgeToClone.getPropertySnapshot().containsKey(DecoderProperty.class)) {
+ clone.setProperty(edgeToClone.getPropertySnapshot().get(DecoderProperty.class));
+ } else {
+ clone.setProperty(DecoderProperty.of(edgeToClone.getPropertyValue(DecoderProperty.class)
+ .orElseThrow(IllegalStateException::new)));
+ }
+
+ edgeToClone.getPropertyValue(AdditionalOutputTagProperty.class).ifPresent(tag -> {
+ clone.setProperty(AdditionalOutputTagProperty.of(tag));
+ });
+
+ edgeToClone.getPropertyValue(PartitionerProperty.class).ifPresent(p -> {
+ if (p.right() == PartitionerProperty.NUM_EQUAL_TO_DST_PARALLELISM) {
+ clone.setProperty(PartitionerProperty.of(p.left()));
+ } else {
+ clone.setProperty(PartitionerProperty.of(p.left(), p.right()));
+ }
+ });
+
+ edgeToClone.getPropertyValue(KeyExtractorProperty.class).ifPresent(ke -> {
+ clone.setProperty(KeyExtractorProperty.of(ke));
+ });
+
+ return clone;
+ }
+
+ /**
+ * A control edge enforces an execution ordering between the source vertex and the destination vertex.
+ * The additional output tag property of control edges is set such that no actual data element is transferred
+ * via the edges. This minimizes the run-time overhead of executing control edges.
+ *
+ * @param src vertex.
+ * @param dst vertex.
+ * @return the control edge.
+ */
+ public static IREdge createControlEdge(final IRVertex src, final IRVertex dst) {
+ final IREdge controlEdge = new IREdge(CommunicationPatternProperty.Value.BroadCast, src, dst);
+ controlEdge.setPropertyPermanently(AdditionalOutputTagProperty.of(CONTROL_EDGE_TAG));
+ return controlEdge;
+ }
+
+ /**
+ * @param vertices to stringify ids of.
+ * @return the string of ids.
+ */
+ public static String stringifyIRVertexIds(final Collection<IRVertex> vertices) {
+ return vertices.stream().map(IRVertex::getId).collect(Collectors.toSet()).toString();
+ }
+
+ /**
+ * @param edges to stringify ids of.
+ * @return the string of ids.
+ */
+ public static String stringifyIREdgeIds(final Collection<IREdge> edges) {
+ return edges.stream().map(IREdge::getId).collect(Collectors.toSet()).toString();
+ }
}
diff --git a/common/src/main/java/org/apache/nemo/common/dag/DAG.java b/common/src/main/java/org/apache/nemo/common/dag/DAG.java
index e6a55b7..de11d98 100644
--- a/common/src/main/java/org/apache/nemo/common/dag/DAG.java
+++ b/common/src/main/java/org/apache/nemo/common/dag/DAG.java
@@ -103,6 +103,11 @@ public final class DAG<V extends Vertex, E extends Edge<V>> implements DAGInterf
}
@Override
+ public List<E> getEdges() {
+ return incomingEdges.values().stream().flatMap(List::stream).collect(Collectors.toList());
+ }
+
+ @Override
public List<V> getRootVertices() {
return rootVertices;
}
diff --git a/common/src/main/java/org/apache/nemo/common/dag/DAGBuilder.java b/common/src/main/java/org/apache/nemo/common/dag/DAGBuilder.java
index 1fe2b13..bd5cc95 100644
--- a/common/src/main/java/org/apache/nemo/common/dag/DAGBuilder.java
+++ b/common/src/main/java/org/apache/nemo/common/dag/DAGBuilder.java
@@ -19,12 +19,11 @@
package org.apache.nemo.common.dag;
import org.apache.nemo.common.exception.CompileTimeOptimizationException;
-import org.apache.nemo.common.ir.edge.IREdge;
-import org.apache.nemo.common.ir.edge.executionproperty.DataFlowProperty;
-import org.apache.nemo.common.ir.edge.executionproperty.MessageIdProperty;
import org.apache.nemo.common.ir.vertex.*;
import org.apache.nemo.common.exception.IllegalVertexOperationException;
+import org.apache.nemo.common.ir.vertex.executionproperty.MessageIdVertexProperty;
import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
+import org.apache.nemo.common.ir.vertex.utility.SamplingVertex;
import java.io.Serializable;
import java.util.*;
@@ -227,7 +226,8 @@ public final class DAGBuilder<V extends Vertex, E extends Edge<V>> implements Se
final Supplier<Stream<V>> verticesToObserve = () -> vertices.stream().filter(v -> incomingEdges.get(v).isEmpty())
.filter(v -> v instanceof IRVertex);
// They should all match SourceVertex
- if (verticesToObserve.get().anyMatch(v -> !(v instanceof SourceVertex))) {
+ if (!(verticesToObserve.get().allMatch(v -> (v instanceof SourceVertex)
+ || (v instanceof SamplingVertex && ((SamplingVertex) v).getCloneOfOriginalVertex() instanceof SourceVertex)))) {
final String problematicVertices = verticesToObserve.get()
.filter(v -> !(v instanceof SourceVertex))
.map(V::getId)
@@ -258,16 +258,15 @@ public final class DAGBuilder<V extends Vertex, E extends Edge<V>> implements Se
* Helper method to check that all execution properties are correct and makes sense.
*/
private void executionPropertyCheck() {
- // DataSizeMetricCollection is not compatible with Push (All data have to be stored before the data collection)
- vertices.forEach(v -> incomingEdges.get(v).stream().filter(e -> e instanceof IREdge).map(e -> (IREdge) e)
- .filter(e -> e.getPropertyValue(MessageIdProperty.class).isPresent())
- .filter(e -> !(e.getDst() instanceof OperatorVertex
- && e.getDst() instanceof MessageAggregatorVertex))
- .filter(e -> DataFlowProperty.Value.Push.equals(e.getPropertyValue(DataFlowProperty.class).get()))
- .forEach(e -> {
- throw new CompileTimeOptimizationException("DAG execution property check: "
- + "DataSizeMetricCollection edge is not compatible with push" + e.getId());
- }));
+ final long numOfMAV = vertices.stream().filter(v -> v instanceof MessageAggregatorVertex).count();
+ final long numOfDistinctMessageIds = vertices.stream()
+ .filter(v -> v instanceof MessageAggregatorVertex)
+ .map(v -> ((MessageAggregatorVertex) v).getPropertyValue(MessageIdVertexProperty.class).get())
+ .distinct()
+ .count();
+ if (numOfMAV != numOfDistinctMessageIds) {
+ throw getException("A unique message id must exist for each MessageAggregator", "");
+ }
}
/**
diff --git a/common/src/main/java/org/apache/nemo/common/dag/DAGInterface.java b/common/src/main/java/org/apache/nemo/common/dag/DAGInterface.java
index 40cb72c..eba2e4a 100644
--- a/common/src/main/java/org/apache/nemo/common/dag/DAGInterface.java
+++ b/common/src/main/java/org/apache/nemo/common/dag/DAGInterface.java
@@ -44,21 +44,27 @@ public interface DAGInterface<V extends Vertex, E extends Edge<V>> extends Seria
/**
* Retrieves the vertices of this DAG.
- * @return the set of vertices.
+ * @return the list of vertices.
* Note that the result is never null, ensured by {@link DAGBuilder}.
*/
List<V> getVertices();
/**
+ * Retrieves the edges of this DAG.
+ * @return the list of edges.
+ */
+ List<E> getEdges();
+
+ /**
* Retrieves the root vertices of this DAG.
- * @return the set of root vertices.
+ * @return the list of root vertices.
*/
List<V> getRootVertices();
/**
* Retrieves the incoming edges of the given vertex.
* @param v the subject vertex.
- * @return the set of incoming edges to the vertex.
+ * @return the list of incoming edges to the vertex.
* Note that the result is never null, ensured by {@link DAGBuilder}.
*/
List<E> getIncomingEdgesOf(final V v);
@@ -66,7 +72,7 @@ public interface DAGInterface<V extends Vertex, E extends Edge<V>> extends Seria
/**
* Retrieves the incoming edges of the given vertex.
* @param vertexId the ID of the subject vertex.
- * @return the set of incoming edges to the vertex.
+ * @return the list of incoming edges to the vertex.
* Note that the result is never null, ensured by {@link DAGBuilder}.
*/
List<E> getIncomingEdgesOf(final String vertexId);
@@ -74,7 +80,7 @@ public interface DAGInterface<V extends Vertex, E extends Edge<V>> extends Seria
/**
* Retrieves the outgoing edges of the given vertex.
* @param v the subject vertex.
- * @return the set of outgoing edges to the vertex.
+ * @return the list of outgoing edges to the vertex.
* Note that the result is never null, ensured by {@link DAGBuilder}.
*/
List<E> getOutgoingEdgesOf(final V v);
@@ -82,7 +88,7 @@ public interface DAGInterface<V extends Vertex, E extends Edge<V>> extends Seria
/**
* Retrieves the outgoing edges of the given vertex.
* @param vertexId the ID of the subject vertex.
- * @return the set of outgoing edges to the vertex.
+ * @return the list of outgoing edges to the vertex.
* Note that the result is never null, ensured by {@link DAGBuilder}.
*/
List<E> getOutgoingEdgesOf(final String vertexId);
diff --git a/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java b/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
index a6cd1c7..af79e72 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
@@ -19,25 +19,31 @@
package org.apache.nemo.common.ir;
import com.fasterxml.jackson.databind.node.ObjectNode;
-import org.apache.nemo.common.PairKeyExtractor;
+import com.google.common.collect.Sets;
+import org.apache.nemo.common.KeyExtractor;
+import org.apache.nemo.common.Pair;
+import org.apache.nemo.common.Util;
+import org.apache.nemo.common.coder.BytesDecoderFactory;
+import org.apache.nemo.common.coder.BytesEncoderFactory;
import org.apache.nemo.common.dag.DAG;
import org.apache.nemo.common.dag.DAGBuilder;
import org.apache.nemo.common.dag.DAGInterface;
+import org.apache.nemo.common.exception.CompileTimeOptimizationException;
import org.apache.nemo.common.exception.IllegalEdgeOperationException;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.*;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.LoopVertex;
+import org.apache.nemo.common.ir.vertex.executionproperty.MessageIdVertexProperty;
import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
import org.apache.nemo.common.ir.vertex.utility.MessageBarrierVertex;
+import org.apache.nemo.common.ir.vertex.utility.SamplingVertex;
import org.apache.nemo.common.ir.vertex.utility.StreamVertex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.concurrent.NotThreadSafe;
-import java.util.List;
-import java.util.Set;
-import java.util.concurrent.atomic.AtomicInteger;
+import java.util.*;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
@@ -53,13 +59,13 @@ import java.util.stream.Collectors;
* All of these methods preserve application semantics.
* - Annotation: setProperty(), getPropertyValue() on each IRVertex/IREdge
* - Reshaping: insert(), delete() on the IRDAG
+ *
+ * TODO #341: Rethink IRDAG insert() signatures
*/
@NotThreadSafe
public final class IRDAG implements DAGInterface<IRVertex, IREdge> {
private static final Logger LOG = LoggerFactory.getLogger(IRDAG.class.getName());
- private final AtomicInteger metricCollectionId;
-
private DAG<IRVertex, IREdge> dagSnapshot; // the DAG that was saved most recently.
private DAG<IRVertex, IREdge> modifiedDAG; // the DAG that is being updated.
@@ -69,7 +75,6 @@ public final class IRDAG implements DAGInterface<IRVertex, IREdge> {
public IRDAG(final DAG<IRVertex, IREdge> originalUserApplicationDAG) {
this.modifiedDAG = originalUserApplicationDAG;
this.dagSnapshot = originalUserApplicationDAG;
- this.metricCollectionId = new AtomicInteger(0);
}
//////////////////////////////////////////////////
@@ -103,15 +108,25 @@ public final class IRDAG implements DAGInterface<IRVertex, IREdge> {
* After: src - edgeToStreamizeWithNewDestination - streamVertex - oneToOneEdge - dst
* (replaces the "Before" relationships)
*
+ * This preserves semantics as the streamVertex simply forwards data elements from the input edge to the output edge.
+ *
* @param streamVertex to insert.
* @param edgeToStreamize to modify.
*/
public void insert(final StreamVertex streamVertex, final IREdge edgeToStreamize) {
+ assertNonExistence(streamVertex);
+
// Create a completely new DAG with the vertex inserted.
- final DAGBuilder builder = new DAGBuilder();
+ final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
+
+ // Integrity check
+ if (edgeToStreamize.getPropertyValue(MessageIdEdgeProperty.class).isPresent()) {
+ throw new CompileTimeOptimizationException(edgeToStreamize.getId() + " has a MessageId, and cannot be removed");
+ }
// Insert the vertex.
- builder.addVertex(streamVertex);
+ final IRVertex vertexToInsert = wrapSamplingVertexIfNeeded(streamVertex, edgeToStreamize.getSrc());
+ builder.addVertex(vertexToInsert);
// Build the new DAG to reflect the new topology.
modifiedDAG.topologicalDo(v -> {
@@ -122,20 +137,35 @@ public final class IRDAG implements DAGInterface<IRVertex, IREdge> {
// MATCH!
// Edge to the streamVertex
- final IREdge edgeToStreamizeWithNewDestination = new IREdge(
+ final IREdge toSV = new IREdge(
edgeToStreamize.getPropertyValue(CommunicationPatternProperty.class).get(),
edgeToStreamize.getSrc(),
- streamVertex);
- edgeToStreamize.copyExecutionPropertiesTo(edgeToStreamizeWithNewDestination);
+ vertexToInsert);
+ edgeToStreamize.copyExecutionPropertiesTo(toSV);
// Edge from the streamVertex.
- final IREdge oneToOneEdge = new IREdge(CommunicationPatternProperty.Value.OneToOne, streamVertex, v);
- oneToOneEdge.setProperty(EncoderProperty.of(edgeToStreamize.getPropertyValue(EncoderProperty.class).get()));
- oneToOneEdge.setProperty(DecoderProperty.of(edgeToStreamize.getPropertyValue(DecoderProperty.class).get()));
+ final IREdge fromSV = new IREdge(CommunicationPatternProperty.Value.OneToOne, vertexToInsert, v);
+ fromSV.setProperty(EncoderProperty.of(edgeToStreamize.getPropertyValue(EncoderProperty.class).get()));
+ fromSV.setProperty(DecoderProperty.of(edgeToStreamize.getPropertyValue(DecoderProperty.class).get()));
+
+ // Future optimizations may want to use the original encoders/compressions.
+ toSV.setPropertySnapshot();
+ fromSV.setPropertySnapshot();
+
+ // Annotations for efficient data transfers - toSV
+ toSV.setPropertyPermanently(DecoderProperty.of(BytesDecoderFactory.of()));
+ toSV.setPropertyPermanently(CompressionProperty.of(CompressionProperty.Value.LZ4));
+ toSV.setPropertyPermanently(DecompressionProperty.of(CompressionProperty.Value.None));
+
+ // Annotations for efficient data transfers - fromSV
+ fromSV.setPropertyPermanently(EncoderProperty.of(BytesEncoderFactory.of()));
+ fromSV.setPropertyPermanently(CompressionProperty.of(CompressionProperty.Value.None));
+ fromSV.setPropertyPermanently(DecompressionProperty.of(CompressionProperty.Value.LZ4));
+ fromSV.setPropertyPermanently(PartitionerProperty.of(PartitionerProperty.Type.DedicatedKeyPerElement));
// Track the new edges.
- builder.connectVertices(edgeToStreamizeWithNewDestination);
- builder.connectVertices(oneToOneEdge);
+ builder.connectVertices(toSV);
+ builder.connectVertices(fromSV);
} else {
// NO MATCH, so simply connect vertices as before.
builder.connectVertices(edge);
@@ -156,77 +186,174 @@ public final class IRDAG implements DAGInterface<IRVertex, IREdge> {
* shuffleEdge - messageAggregatorVertex - broadcastEdge - dst
* (the "Before" relationships are unmodified)
*
+ * This preserves semantics as the results of the inserted message vertices are never consumed by the original IRDAG.
+ *
* @param messageBarrierVertex to insert.
* @param messageAggregatorVertex to insert.
* @param mbvOutputEncoder to use.
* @param mbvOutputDecoder to use.
* @param edgesToGetStatisticsOf to examine.
+ * @param edgesToOptimize to optimize.
*/
public void insert(final MessageBarrierVertex messageBarrierVertex,
final MessageAggregatorVertex messageAggregatorVertex,
final EncoderProperty mbvOutputEncoder,
final DecoderProperty mbvOutputDecoder,
- final Set<IREdge> edgesToGetStatisticsOf) {
+ final Set<IREdge> edgesToGetStatisticsOf,
+ final Set<IREdge> edgesToOptimize) {
+ assertNonExistence(messageBarrierVertex);
+ assertNonExistence(messageAggregatorVertex);
+
if (edgesToGetStatisticsOf.stream().map(edge -> edge.getDst().getId()).collect(Collectors.toSet()).size() != 1) {
- throw new IllegalArgumentException("Not destined to the same vertex: " + edgesToGetStatisticsOf.toString());
+ throw new IllegalArgumentException("Not destined to the same vertex: " + edgesToOptimize.toString());
+ }
+ if (edgesToOptimize.stream().map(edge -> edge.getDst().getId()).collect(Collectors.toSet()).size() != 1) {
+ throw new IllegalArgumentException("Not destined to the same vertex: " + edgesToOptimize.toString());
}
- final IRVertex dst = edgesToGetStatisticsOf.iterator().next().getDst();
// Create a completely new DAG with the vertex inserted.
- final DAGBuilder builder = new DAGBuilder();
+ final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
- // Current metric collection id.
- final int currentMetricCollectionId = metricCollectionId.incrementAndGet();
+ // All of the existing vertices and edges remain intact
+ modifiedDAG.topologicalDo(v -> {
+ builder.addVertex(v);
+ modifiedDAG.getIncomingEdgesOf(v).forEach(builder::connectVertices);
+ });
- // First, add all the vertices.
- modifiedDAG.topologicalDo(v -> builder.addVertex(v));
+ ////////////////////////////////// STEP 1: Insert new vertices and edges (src - mbv - mav - dst)
- // Add a control dependency (no output) from the messageAggregatorVertex to the destination.
+ // From src to mbv
+ final List<IRVertex> mbvList = new ArrayList<>();
+ for (final IREdge edge : edgesToGetStatisticsOf) {
+ final IRVertex mbvToAdd = wrapSamplingVertexIfNeeded(
+ new MessageBarrierVertex<>(messageBarrierVertex.getMessageFunction()), edge.getSrc());
+ builder.addVertex(mbvToAdd);
+ mbvList.add(mbvToAdd);
+
+ final IREdge clone = Util.cloneEdge(CommunicationPatternProperty.Value.OneToOne, edge, edge.getSrc(), mbvToAdd);
+ builder.connectVertices(clone);
+ }
+
+ // Add mav (no need to wrap with a sampling vertex)
builder.addVertex(messageAggregatorVertex);
- final IREdge noDataEdge = new IREdge(CommunicationPatternProperty.Value.BroadCast, messageAggregatorVertex, dst);
- builder.connectVertices(noDataEdge);
- // Add the edges and the messageBarrierVertex.
+ // From mbv to mav
+ for (final IRVertex mbv : mbvList) {
+ final IREdge edgeToMav = edgeBetweenMessageVertices(
+ mbv, messageAggregatorVertex, mbvOutputEncoder, mbvOutputDecoder);
+ builder.connectVertices(edgeToMav);
+ }
+
+ // From mav to dst
+ // Add a control dependency (no output) from the messageAggregatorVertex to the destination.
+ builder.connectVertices(
+ Util.createControlEdge(messageAggregatorVertex, edgesToGetStatisticsOf.iterator().next().getDst()));
+
+ ////////////////////////////////// STEP 2: Annotate the MessageId on optimization target edges
+
modifiedDAG.topologicalDo(v -> {
- for (final IREdge edge : modifiedDAG.getIncomingEdgesOf(v)) {
- if (edgesToGetStatisticsOf.contains(edge)) {
- // MATCH!
- final MessageBarrierVertex mbv = new MessageBarrierVertex<>(messageBarrierVertex.getMessageFunction());
- builder.addVertex(mbv);
-
- // Clone the edgeToGetStatisticsOf
- final IREdge clone = new IREdge(CommunicationPatternProperty.Value.OneToOne, edge.getSrc(), mbv);
- clone.setProperty(EncoderProperty.of(edge.getPropertyValue(EncoderProperty.class).get()));
- clone.setProperty(DecoderProperty.of(edge.getPropertyValue(DecoderProperty.class).get()));
- edge.getPropertyValue(AdditionalOutputTagProperty.class).ifPresent(tag -> {
- clone.setProperty(AdditionalOutputTagProperty.of(tag));
- });
- builder.connectVertices(clone);
-
- // messageBarrierVertex to the messageAggregatorVertex
- final IREdge edgeToABV = edgeBetweenMessageVertices(mbv,
- messageAggregatorVertex, mbvOutputEncoder, mbvOutputDecoder, currentMetricCollectionId);
- builder.connectVertices(edgeToABV);
-
- // The original edge
- // We then insert the vertex with MessageBarrierTransform and vertex with MessageAggregatorTransform
- // between the vertex and incoming vertices.
- final IREdge edgeToOriginalDst =
- new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(), edge.getSrc(), v);
- edge.copyExecutionPropertiesTo(edgeToOriginalDst);
- edgeToOriginalDst.setPropertyPermanently(MessageIdProperty.of(currentMetricCollectionId));
- builder.connectVertices(edgeToOriginalDst);
- } else {
- // NO MATCH, so simply connect vertices as before.
- builder.connectVertices(edge);
+ modifiedDAG.getIncomingEdgesOf(v).forEach(inEdge -> {
+ if (edgesToOptimize.contains(inEdge)) {
+ inEdge.setPropertyPermanently(MessageIdEdgeProperty.of(
+ messageAggregatorVertex.getPropertyValue(MessageIdVertexProperty.class).get()));
}
- }
+ });
});
modifiedDAG = builder.build(); // update the DAG.
}
/**
+ * Inserts a set of samplingVertices that process sampled data.
+ *
+ * This method automatically inserts the following three types of edges.
+ * (1) Edges between samplingVertices to reflect the original relationship
+ * (2) Edges from the original IRDAG to samplingVertices that clone the inEdges of the original vertices
+ * (3) Edges from the samplingVertices to the original IRDAG to respect executeAfterSamplingVertices
+ *
+ * Suppose the caller supplies the following arguments to perform a "sampled run" of vertices {V1, V2},
+ * prior to executing them.
+ * - samplingVertices: {V1', V2'}
+ * - childrenOfSamplingVertices: {V1}
+ *
+ * Before: V1 - oneToOneEdge - V2 - shuffleEdge - V3
+ * After: V1' - oneToOneEdge - V2' - controlEdge - V1 - oneToOneEdge - V2 - shuffleEdge - V3
+ *
+ * This preserves semantics as the original IRDAG remains unchanged and unaffected.
+ *
+ * (Future calls to insert() can add new vertices that connect to sampling vertices. Such new vertices will also be
+ * wrapped with sampling vertices, as new vertices that consume outputs from sampling vertices will process
+ * a subset of data anyways, and no such new vertex will reach the original DAG except via control edges)
+ *
+ * TODO #343: Extend SamplingVertex control edges
+ *
+ * @param samplingVertices to insert.
+ * @param executeAfterSamplingVertices that must be executed after samplingVertices.
+ */
+ public void insert(final Set<SamplingVertex> samplingVertices,
+ final Set<IRVertex> executeAfterSamplingVertices) {
+ samplingVertices.forEach(this::assertNonExistence);
+
+ // Create a completely new DAG with the vertex inserted.
+ final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
+
+ // All of the existing vertices and edges remain intact
+ modifiedDAG.topologicalDo(v -> {
+ builder.addVertex(v);
+ modifiedDAG.getIncomingEdgesOf(v).forEach(builder::connectVertices);
+ });
+
+ // Add the sampling vertices
+ samplingVertices.forEach(builder::addVertex);
+
+ // Get the original vertices
+ final Map<IRVertex, IRVertex> originalToSampling = samplingVertices.stream()
+ .collect(Collectors.toMap(sv -> modifiedDAG.getVertexById(sv.getOriginalVertexId()), Function.identity()));
+ final Set<IREdge> inEdgesOfOriginals = originalToSampling.keySet()
+ .stream()
+ .flatMap(ov -> modifiedDAG.getIncomingEdgesOf(ov).stream())
+ .collect(Collectors.toSet());
+
+ // [EDGE TYPE 1] Between sampling vertices
+ final Set<IREdge> betweenOriginals = inEdgesOfOriginals
+ .stream()
+ .filter(ovInEdge -> originalToSampling.containsKey(ovInEdge.getSrc()))
+ .collect(Collectors.toSet());
+ betweenOriginals.stream().map(boEdge -> Util.cloneEdge(
+ boEdge,
+ originalToSampling.get(boEdge.getSrc()),
+ originalToSampling.get(boEdge.getDst()))).forEach(builder::connectVertices);
+
+ // [EDGE TYPE 2] From original IRDAG to sampling vertices
+ final Set<IREdge> notBetweenOriginals = inEdgesOfOriginals
+ .stream()
+ .filter(ovInEdge -> !originalToSampling.containsKey(ovInEdge.getSrc()))
+ .collect(Collectors.toSet());
+ notBetweenOriginals.stream().map(nboEdge -> {
+ final IREdge cloneEdge = Util.cloneEdge(
+ nboEdge,
+ nboEdge.getSrc(), // sampling vertices consume a subset of original data partitions here
+ originalToSampling.get(nboEdge.getDst()));
+ nboEdge.copyExecutionPropertiesTo(cloneEdge); // exec properties must be exactly the same
+ return cloneEdge;
+ }).forEach(builder::connectVertices);
+
+ // [EDGE TYPE 3] From sampling vertices to vertices that should be executed after
+ final Set<IRVertex> sinks = getSinksWithinVertexSet(modifiedDAG, originalToSampling.keySet())
+ .stream()
+ .map(originalToSampling::get)
+ .collect(Collectors.toSet());
+ for (final IRVertex executeAfter : executeAfterSamplingVertices) {
+ for (final IRVertex sink : sinks) {
+ // Control edge that enforces execution ordering
+ builder.connectVertices(Util.createControlEdge(sink, executeAfter));
+ }
+ }
+
+ modifiedDAG = builder.build(); // update the DAG.
+ }
+
+ /**
* Reshape unsafely, without guarantees on preserving application semantics.
* TODO #330: Refactor Unsafe Reshaping Passes
* @param unsafeReshapingFunction takes as input the underlying DAG, and outputs a reshaped DAG.
@@ -237,25 +364,52 @@ public final class IRDAG implements DAGInterface<IRVertex, IREdge> {
////////////////////////////////////////////////// Private helper methods.
+ /**
+  * Finds the members of vertexSet that have no outgoing edge to another member of vertexSet.
+  *
+  * @param dag to read the outgoing edges from.
+  * @param vertexSet the vertices to examine.
+  * @return the "sinks" of vertexSet, as a view over vertexSet.
+  */
+ private Set<IRVertex> getSinksWithinVertexSet(final DAG<IRVertex, IREdge> dag,
+                                               final Set<IRVertex> vertexSet) {
+   // A member is NOT a sink when at least one of its children is also a member.
+   final Set<IRVertex> hasChildInSet = vertexSet.stream()
+     .filter(candidate -> dag.getOutgoingEdgesOf(candidate)
+       .stream()
+       .map(IREdge::getDst)
+       .anyMatch(vertexSet::contains))
+     .collect(Collectors.toSet());
+   return Sets.difference(vertexSet, hasChildInSet);
+ }
+
+ private IRVertex wrapSamplingVertexIfNeeded(final IRVertex newVertex, final IRVertex existingVertexToConnectWith) {
+ // If the connecting vertex is a sampling vertex, the new vertex must be wrapped inside a sampling vertex too.
+ return existingVertexToConnectWith instanceof SamplingVertex
+ ? new SamplingVertex(newVertex, ((SamplingVertex) existingVertexToConnectWith).getDesiredSampleRate())
+ : newVertex;
+ }
+
+ private void assertNonExistence(final IRVertex v) {
+ if (getVertices().contains(v)) {
+ throw new IllegalArgumentException(v.getId());
+ }
+ }
+
/**
* @param mbv src.
* @param mav dst.
* @param encoder src-dst encoder.
* @param decoder src-dst decoder.
- * @param currentMetricCollectionId of the edge.
* @return the edge.
*/
- private IREdge edgeBetweenMessageVertices(final MessageBarrierVertex mbv,
- final MessageAggregatorVertex mav,
+ private IREdge edgeBetweenMessageVertices(final IRVertex mbv,
+ final IRVertex mav,
final EncoderProperty encoder,
- final DecoderProperty decoder,
- final int currentMetricCollectionId) {
+ final DecoderProperty decoder) {
final IREdge newEdge = new IREdge(CommunicationPatternProperty.Value.Shuffle, mbv, mav);
newEdge.setProperty(DataStoreProperty.of(DataStoreProperty.Value.LocalFileStore));
newEdge.setProperty(DataPersistenceProperty.of(DataPersistenceProperty.Value.Keep));
newEdge.setProperty(DataFlowProperty.of(DataFlowProperty.Value.Push));
- newEdge.setPropertyPermanently(MessageIdProperty.of(currentMetricCollectionId));
- newEdge.setProperty(KeyExtractorProperty.of(new PairKeyExtractor()));
+ final KeyExtractor pairKeyExtractor = (element) -> {
+ if (element instanceof Pair) {
+ return ((Pair) element).left();
+ } else {
+ throw new IllegalStateException(element.toString());
+ }
+ };
+ newEdge.setProperty(KeyExtractorProperty.of(pairKeyExtractor));
newEdge.setPropertyPermanently(encoder);
newEdge.setPropertyPermanently(decoder);
return newEdge;
@@ -322,6 +476,11 @@ public final class IRDAG implements DAGInterface<IRVertex, IREdge> {
}
@Override
+ public List<IREdge> getEdges() {
+   // Always answer from the current (possibly reshaped) underlying DAG.
+   final List<IREdge> currentEdges = modifiedDAG.getEdges();
+   return currentEdges;
+ }
+
+ @Override
public List<IRVertex> getRootVertices() {
return modifiedDAG.getRootVertices();
}
diff --git a/common/src/main/java/org/apache/nemo/common/ir/edge/IREdge.java b/common/src/main/java/org/apache/nemo/common/ir/edge/IREdge.java
index 1aa117b..68e5381 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/edge/IREdge.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/edge/IREdge.java
@@ -30,7 +30,7 @@ import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.commons.lang3.builder.HashCodeBuilder;
import java.io.Serializable;
-import java.util.Optional;
+import java.util.*;
/**
* Physical execution plan of intermediate data movement.
@@ -139,4 +139,17 @@ public final class IREdge extends Edge<IRVertex> {
node.set("executionProperties", executionProperties.asJsonNode());
return node;
}
+
+ /////////// For saving original EPs (e.g., save original encoders/decoders of StreamVertex edges)
+
+ /**
+  * Execution properties captured by the last call to {@link #setPropertySnapshot()}, keyed by property class.
+  */
+ private final Map<Class, EdgeExecutionProperty> snapshot = new HashMap<>();
+
+ /**
+  * Replaces the snapshot with the current execution properties of this edge
+  * (e.g., to remember the original encoders/decoders before they are overwritten).
+  */
+ public void setPropertySnapshot() {
+   snapshot.clear();
+   executionProperties.forEachProperties(p -> snapshot.put(p.getClass(), p));
+ }
+
+ /**
+  * @return a read-only view of the snapshot taken by the last {@link #setPropertySnapshot()} call
+  *         (empty if it was never called). The view reflects later snapshots; it cannot be modified,
+  *         so callers cannot corrupt the saved properties.
+  */
+ public Map<Class, EdgeExecutionProperty> getPropertySnapshot() {
+   return Collections.unmodifiableMap(snapshot);
+ }
}
diff --git a/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MessageIdProperty.java b/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MessageIdEdgeProperty.java
similarity index 78%
copy from common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MessageIdProperty.java
copy to common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MessageIdEdgeProperty.java
index 3bd74f8..b45c85e 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MessageIdProperty.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MessageIdEdgeProperty.java
@@ -21,14 +21,14 @@ package org.apache.nemo.common.ir.edge.executionproperty;
import org.apache.nemo.common.ir.executionproperty.EdgeExecutionProperty;
/**
- * Edges with the same MessageId are subject to the same run-time optimization.
+ * Vertices and edges with the same MessageId are subject to the same run-time optimization.
*/
-public final class MessageIdProperty extends EdgeExecutionProperty<Integer> {
+public final class MessageIdEdgeProperty extends EdgeExecutionProperty<Integer> {
/**
* Constructor.
* @param value value of the execution property.
*/
- private MessageIdProperty(final Integer value) {
+ private MessageIdEdgeProperty(final Integer value) {
super(value);
}
@@ -37,7 +37,7 @@ public final class MessageIdProperty extends EdgeExecutionProperty<Integer> {
* @param value value of the new execution property.
* @return the newly created execution property.
*/
- public static MessageIdProperty of(final Integer value) {
- return new MessageIdProperty(value);
+ public static MessageIdEdgeProperty of(final Integer value) {
+ return new MessageIdEdgeProperty(value);
}
}
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/LoopVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/LoopVertex.java
index f279783..b1ce3ca 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/LoopVertex.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/LoopVertex.java
@@ -21,13 +21,13 @@ package org.apache.nemo.common.ir.vertex;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
+import org.apache.nemo.common.Util;
import org.apache.nemo.common.dag.DAG;
import org.apache.nemo.common.dag.DAGBuilder;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import org.apache.nemo.common.ir.edge.executionproperty.DuplicateEdgeGroupProperty;
import org.apache.nemo.common.ir.edge.executionproperty.DuplicateEdgeGroupPropertyValue;
-import org.apache.nemo.common.util.Util;
import java.io.Serializable;
import java.util.*;
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/SourceVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/SourceVertex.java
index d8f6eee..7ad6a82 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/SourceVertex.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/SourceVertex.java
@@ -28,7 +28,6 @@ import java.util.List;
* @param <O> output type.
*/
public abstract class SourceVertex<O> extends IRVertex {
-
/**
* Constructor for SourceVertex.
*/
@@ -49,6 +48,7 @@ public abstract class SourceVertex<O> extends IRVertex {
public SourceVertex(final SourceVertex that) {
super(that);
}
+
/**
* Gets parallel readables.
*
diff --git a/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MessageIdProperty.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/MessageIdVertexProperty.java
similarity index 69%
rename from common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MessageIdProperty.java
rename to common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/MessageIdVertexProperty.java
index 3bd74f8..9ba842f 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MessageIdProperty.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/MessageIdVertexProperty.java
@@ -16,19 +16,19 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.nemo.common.ir.edge.executionproperty;
+package org.apache.nemo.common.ir.vertex.executionproperty;
-import org.apache.nemo.common.ir.executionproperty.EdgeExecutionProperty;
+import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty;
/**
- * Edges with the same MessageId are subject to the same run-time optimization.
+ * Vertices and edges with the same MessageId belong to the same run-time optimization.
*/
-public final class MessageIdProperty extends EdgeExecutionProperty<Integer> {
+public final class MessageIdVertexProperty extends VertexExecutionProperty<Integer> {
/**
* Constructor.
* @param value value of the execution property.
*/
- private MessageIdProperty(final Integer value) {
+ private MessageIdVertexProperty(final Integer value) {
super(value);
}
@@ -37,7 +37,7 @@ public final class MessageIdProperty extends EdgeExecutionProperty<Integer> {
* @param value value of the new execution property.
* @return the newly created execution property.
*/
- public static MessageIdProperty of(final Integer value) {
- return new MessageIdProperty(value);
+ public static MessageIdVertexProperty of(final Integer value) {
+ return new MessageIdVertexProperty(value);
}
}
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/MessageAggregatorVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/MessageAggregatorVertex.java
index a5a9280..4db54d7 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/MessageAggregatorVertex.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/MessageAggregatorVertex.java
@@ -20,8 +20,12 @@ package org.apache.nemo.common.ir.vertex.utility;
import org.apache.nemo.common.Pair;
import org.apache.nemo.common.ir.vertex.OperatorVertex;
+import org.apache.nemo.common.ir.vertex.executionproperty.MessageIdVertexProperty;
import org.apache.nemo.common.ir.vertex.transform.MessageAggregatorTransform;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
/**
@@ -30,12 +34,17 @@ import java.util.function.BiFunction;
* @param <V> of the input pair.
* @param <O> of the output aggregated message.
*/
-public class MessageAggregatorVertex<K, V, O> extends OperatorVertex {
+public final class MessageAggregatorVertex<K, V, O> extends OperatorVertex {
+ private static final Logger LOG = LoggerFactory.getLogger(MessageAggregatorVertex.class.getName()); // NOTE(review): unused in this class
+
+ private static final AtomicInteger MESSAGE_ID_GENERATOR = new AtomicInteger(0); // static: shared by all instances; first id is 1
+
/**
* Creates the aggregator vertex and permanently tags it with a fresh {@link MessageIdVertexProperty}
* (a new value drawn from a counter shared by all instances).
*
* @param initialState to use.
* @param userFunction for aggregating the messages.
*/
public MessageAggregatorVertex(final O initialState, final BiFunction<Pair<K, V>, O, O> userFunction) {
super(new MessageAggregatorTransform<>(initialState, userFunction));
+ this.setPropertyPermanently(MessageIdVertexProperty.of(MESSAGE_ID_GENERATOR.incrementAndGet())); // next id per instance
}
}
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/SamplingVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/SamplingVertex.java
new file mode 100644
index 0000000..fe54952
--- /dev/null
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/SamplingVertex.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.nemo.common.ir.vertex.utility;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import org.apache.nemo.common.Util;
+import org.apache.nemo.common.ir.edge.IREdge;
+import org.apache.nemo.common.ir.vertex.IRVertex;
+
+/**
+ * Executes the original IRVertex using a subset of input data partitions.
+ */
+public final class SamplingVertex extends IRVertex {
+ private final IRVertex originalVertex;
+ private final IRVertex cloneOfOriginalVertex;
+ private final float desiredSampleRate;
+
+ /**
+ * @param originalVertex to clone.
+ * @param desiredSampleRate percentage of tasks to execute.
+ * The actual sample rate may vary depending on neighboring sampling vertices.
+ */
+ public SamplingVertex(final IRVertex originalVertex, final float desiredSampleRate) {
+ super();
+ if (originalVertex instanceof SamplingVertex) {
+ throw new IllegalArgumentException("Cannot sample again: " + originalVertex.toString());
+ }
+ if (desiredSampleRate > 1 || desiredSampleRate <= 0) {
+ throw new IllegalArgumentException(String.valueOf(desiredSampleRate));
+ }
+ this.originalVertex = originalVertex;
+ this.cloneOfOriginalVertex = originalVertex.getClone();
+ this.desiredSampleRate = desiredSampleRate;
+
+ // Copy execution properties.
+ originalVertex.copyExecutionPropertiesTo(cloneOfOriginalVertex);
+ originalVertex.copyExecutionPropertiesTo(this);
+ }
+
+ /**
+ * @return the id of the original vertex for reference.
+ */
+ public String getOriginalVertexId() {
+ return originalVertex.getId();
+ }
+
+ /**
+ * @return the clone of the original vertex.
+ * This clone is intended to be used during the actual execution, as the sampling vertex itself is not executable
+ * and the original vertex should not be executed again.
+ */
+ public IRVertex getCloneOfOriginalVertex() {
+ return cloneOfOriginalVertex;
+ }
+
+ /**
+ * @return the desired sample rate.
+ */
+ public float getDesiredSampleRate() {
+ return desiredSampleRate;
+ }
+
+ /**
+ * Obtains a clone of an original edge that is attached to this sampling vertex.
+ *
+ * Original edge: src - to - dst
+ * When src == originalVertex, return thisSamplingVertex - to - dst
+ * When dst == originalVertex, return src - to - thisSamplingVertex
+ *
+ * @param originalEdge to clone.
+ * @return a clone of the edge.
+ */
+ public IREdge getCloneOfOriginalEdge(final IREdge originalEdge) {
+ if (originalEdge.getSrc().equals(originalVertex)) {
+ return Util.cloneEdge(originalEdge, this, originalEdge.getDst());
+ } else if (originalEdge.getDst().equals(originalVertex)) {
+ return Util.cloneEdge(originalEdge, originalEdge.getSrc(), this);
+ } else {
+ throw new IllegalArgumentException(originalEdge.getId());
+ }
+ }
+
+ @Override
+ public String toString() {
+ final StringBuilder sb = new StringBuilder();
+ sb.append("SamplingVertex(desiredSampleRate:");
+ sb.append(String.valueOf(desiredSampleRate));
+ sb.append(")[");
+ sb.append(originalVertex);
+ sb.append("]");
+ return sb.toString();
+ }
+
+ @Override
+ public IRVertex getClone() {
+ return new SamplingVertex(originalVertex, desiredSampleRate);
+ }
+
+ @Override
+ public JsonNode getPropertiesAsJsonNode() {
+ return getCloneOfOriginalVertex().getPropertiesAsJsonNode();
+ }
+}
diff --git a/common/src/test/java/org/apache/nemo/common/util/UtilTest.java b/common/src/test/java/org/apache/nemo/common/util/UtilTest.java
index 4e6869f..e46db33 100644
--- a/common/src/test/java/org/apache/nemo/common/util/UtilTest.java
+++ b/common/src/test/java/org/apache/nemo/common/util/UtilTest.java
@@ -22,6 +22,7 @@ import static org.junit.Assert.assertEquals;
import java.util.function.IntPredicate;
+import org.apache.nemo.common.Util;
import org.junit.Test;
public class UtilTest {
diff --git a/compiler/backend/src/main/java/org/apache/nemo/compiler/backend/nemo/NemoPlanRewriter.java b/compiler/backend/src/main/java/org/apache/nemo/compiler/backend/nemo/NemoPlanRewriter.java
index a457de4..7131be0 100644
--- a/compiler/backend/src/main/java/org/apache/nemo/compiler/backend/nemo/NemoPlanRewriter.java
+++ b/compiler/backend/src/main/java/org/apache/nemo/compiler/backend/nemo/NemoPlanRewriter.java
@@ -20,7 +20,7 @@ package org.apache.nemo.compiler.backend.nemo;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.edge.IREdge;
-import org.apache.nemo.common.ir.edge.executionproperty.MessageIdProperty;
+import org.apache.nemo.common.ir.edge.executionproperty.MessageIdEdgeProperty;
import org.apache.nemo.common.ir.executionproperty.ExecutionPropertyMap;
import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty;
import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
@@ -87,8 +87,8 @@ public final class NemoPlanRewriter implements PlanRewriter {
.getVertices()
.stream()
.flatMap(v -> currentIRDAG.getIncomingEdgesOf(v).stream())
- .filter(e -> e.getPropertyValue(MessageIdProperty.class).isPresent()
- && e.getPropertyValue(MessageIdProperty.class).get() == messageId
+ .filter(e -> e.getPropertyValue(MessageIdEdgeProperty.class).isPresent()
+ && e.getPropertyValue(MessageIdEdgeProperty.class).get() == messageId
&& !(e.getDst() instanceof MessageAggregatorVertex))
.collect(Collectors.toSet());
if (examiningEdges.isEmpty()) {
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/LargeShuffleAnnotatingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/LargeShuffleAnnotatingPass.java
index 0958037..b6e4194 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/LargeShuffleAnnotatingPass.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/LargeShuffleAnnotatingPass.java
@@ -18,8 +18,6 @@
*/
package org.apache.nemo.compiler.optimizer.pass.compiletime.annotating;
-import org.apache.nemo.common.coder.BytesDecoderFactory;
-import org.apache.nemo.common.coder.BytesEncoderFactory;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.edge.executionproperty.*;
import org.apache.nemo.common.ir.vertex.executionproperty.ResourceSlotProperty;
@@ -42,9 +40,7 @@ import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;
* Do not encode/compress the byte[]
* Perform a pull-based and on-disk data transfer with the DedicatedKeyPerElement.
*/
-@Annotates({CompressionProperty.class, DataFlowProperty.class, CompressionProperty.class,
- DataPersistenceProperty.class, DataStoreProperty.class, DecoderProperty.class, DecompressionProperty.class,
- EncoderProperty.class, PartitionerProperty.class, ResourceSlotProperty.class})
+@Annotates({DataFlowProperty.class, DataPersistenceProperty.class, DataStoreProperty.class, ResourceSlotProperty.class})
@Requires(CommunicationPatternProperty.class)
public final class LargeShuffleAnnotatingPass extends AnnotatingPass {
/**
@@ -61,11 +57,6 @@ public final class LargeShuffleAnnotatingPass extends AnnotatingPass {
if (edge.getDst().getClass().equals(StreamVertex.class)) {
// CASE #1: To a stream vertex
- // Coder and Compression
- edge.setPropertyPermanently(DecoderProperty.of(BytesDecoderFactory.of()));
- edge.setPropertyPermanently(CompressionProperty.of(CompressionProperty.Value.LZ4));
- edge.setPropertyPermanently(DecompressionProperty.of(CompressionProperty.Value.None));
-
// Data transfers
edge.setPropertyPermanently(DataFlowProperty.of(DataFlowProperty.Value.Push));
edge.setPropertyPermanently(DataPersistenceProperty.of(DataPersistenceProperty.Value.Discard));
@@ -76,16 +67,9 @@ public final class LargeShuffleAnnotatingPass extends AnnotatingPass {
} else if (edge.getSrc().getClass().equals(StreamVertex.class)) {
// CASE #2: From a stream vertex
- // Coder and Compression
- edge.setPropertyPermanently(EncoderProperty.of(BytesEncoderFactory.of()));
- edge.setPropertyPermanently(CompressionProperty.of(CompressionProperty.Value.None));
- edge.setPropertyPermanently(DecompressionProperty.of(CompressionProperty.Value.LZ4));
-
// Data transfers
edge.setPropertyPermanently(DataFlowProperty.of(DataFlowProperty.Value.Pull));
edge.setPropertyPermanently(DataStoreProperty.of(DataStoreProperty.Value.LocalFileStore));
- edge.setPropertyPermanently(
- PartitionerProperty.of(PartitionerProperty.Type.DedicatedKeyPerElement));
} else {
// CASE #3: Unrelated to any stream vertices
edge.setPropertyPermanently(DataFlowProperty.of(DataFlowProperty.Value.Pull));
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SamplingSkewReshapingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SamplingSkewReshapingPass.java
new file mode 100644
index 0000000..8db8085
--- /dev/null
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SamplingSkewReshapingPass.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;
+
+import org.apache.nemo.common.KeyExtractor;
+import org.apache.nemo.common.dag.Edge;
+import org.apache.nemo.common.ir.IRDAG;
+import org.apache.nemo.common.ir.edge.IREdge;
+import org.apache.nemo.common.ir.edge.executionproperty.*;
+import org.apache.nemo.common.ir.vertex.IRVertex;
+import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
+import org.apache.nemo.common.ir.vertex.utility.MessageBarrierVertex;
+import org.apache.nemo.common.ir.vertex.utility.SamplingVertex;
+import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.*;
+import java.util.stream.Collectors;
+
+/**
+ * Optimizes the PartitionSet property of shuffle edges to handle data skews using the SamplingVertex.
+ *
+ * This pass effectively partitions the IRDAG by non-oneToOne edges, clones each subDAG partition using SamplingVertex
+ * to process sampled data, and executes each cloned partition prior to executing the corresponding original partition.
+ *
+ * Suppose the IRDAG is partitioned into three sub-DAG partitions with shuffle dependencies as follows:
+ * P1 - P2 - P3
+ *
+ * Then, this pass will produce something like:
+ * P1' - P1
+ * - P2' - P2 - P3
+ * where Px' consists of SamplingVertex objects that clone the execution of Px.
+ * (P3 is not cloned here because it is a sink partition, and none of the outgoing edges of its vertices needs to be
+ * optimized)
+ *
+ * For each Px' this pass also inserts a MessageBarrierVertex, to use its data statistics for dynamically optimizing
+ * the execution behaviors of Px.
+ */
+@Requires(CommunicationPatternProperty.class)
+public final class SamplingSkewReshapingPass extends ReshapingPass {
+ private static final Logger LOG = LoggerFactory.getLogger(SamplingSkewReshapingPass.class.getName());
+ private static final float SAMPLE_RATE = 0.1f;
+
+ /**
+ * Default constructor.
+ */
+ public SamplingSkewReshapingPass() {
+ super(SamplingSkewReshapingPass.class);
+ }
+
+ @Override
+ public IRDAG apply(final IRDAG dag) {
+ dag.topologicalDo(v -> {
+ for (final IREdge e : dag.getIncomingEdgesOf(v)) {
+ if (CommunicationPatternProperty.Value.Shuffle.equals(
+ e.getPropertyValue(CommunicationPatternProperty.class).get())) {
+ // Compute the partition and its source vertices
+ final IRVertex shuffleWriter = e.getSrc();
+ final Set<IRVertex> partitionAll = recursivelyBuildPartition(shuffleWriter, dag);
+ final Set<IRVertex> partitionSources = partitionAll.stream().filter(vertexInPartition ->
+ !dag.getIncomingEdgesOf(vertexInPartition).stream()
+ .map(Edge::getSrc)
+ .anyMatch(partitionAll::contains)
+ ).collect(Collectors.toSet());
+
+ // Check if the partition is a sink, in which case we do not create sampling vertices
+ final boolean isSinkPartition = partitionAll.stream()
+ .flatMap(vertexInPartition -> dag.getOutgoingEdgesOf(vertexInPartition).stream())
+ .map(Edge::getDst)
+ .allMatch(partitionAll::contains);
+ if (isSinkPartition) {
+ break;
+ }
+
+ // Insert sampling vertices.
+ final Set<SamplingVertex> samplingVertices = partitionAll
+ .stream()
+ .map(vertexInPartition -> new SamplingVertex(vertexInPartition, SAMPLE_RATE))
+ .collect(Collectors.toSet());
+ dag.insert(samplingVertices, partitionSources);
+
+ // Insert the message vertex.
+ // We first obtain a clonedShuffleEdge to analyze the data statistics of the shuffle outputs of
+ // the sampling vertex right before shuffle.
+ final SamplingVertex rightBeforeShuffle = samplingVertices.stream()
+ .filter(sv -> sv.getOriginalVertexId().equals(e.getSrc().getId()))
+ .findFirst()
+ .orElseThrow(() -> new IllegalStateException());
+ final IREdge clonedShuffleEdge = rightBeforeShuffle.getCloneOfOriginalEdge(e);
+
+ final KeyExtractor keyExtractor = e.getPropertyValue(KeyExtractorProperty.class).get();
+ dag.insert(
+ new MessageBarrierVertex<>(SkewHandlingUtil.getDynOptCollector(keyExtractor)),
+ new MessageAggregatorVertex(new HashMap(), SkewHandlingUtil.getDynOptAggregator()),
+ SkewHandlingUtil.getEncoder(e),
+ SkewHandlingUtil.getDecoder(e),
+ new HashSet<>(Arrays.asList(clonedShuffleEdge)), // this works although the clone is not in the dag
+ new HashSet<>(Arrays.asList(e))); // we want to optimize the original edge, not the clone
+ }
+ }
+ });
+
+ return dag;
+ }
+
+ private Set<IRVertex> recursivelyBuildPartition(final IRVertex curVertex, final IRDAG dag) {
+ final Set<IRVertex> unionSet = new HashSet<>();
+ unionSet.add(curVertex);
+ for (final IREdge inEdge : dag.getIncomingEdgesOf(curVertex)) {
+ if (CommunicationPatternProperty.Value.OneToOne
+ .equals(inEdge.getPropertyValue(CommunicationPatternProperty.class).get())
+ && DataStoreProperty.Value.MemoryStore
+ .equals(inEdge.getPropertyValue(DataStoreProperty.class).get())
+ && dag.getIncomingEdgesOf(curVertex).size() == 1) {
+ unionSet.addAll(recursivelyBuildPartition(inEdge.getSrc(), dag));
+ }
+ }
+ return unionSet;
+ }
+}
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewHandlingUtil.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewHandlingUtil.java
new file mode 100644
index 0000000..6231ba4
--- /dev/null
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewHandlingUtil.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;
+
+import org.apache.nemo.common.KeyExtractor;
+import org.apache.nemo.common.Pair;
+import org.apache.nemo.common.coder.*;
+import org.apache.nemo.common.ir.edge.IREdge;
+import org.apache.nemo.common.ir.edge.executionproperty.DecoderProperty;
+import org.apache.nemo.common.ir.edge.executionproperty.EncoderProperty;
+import org.apache.nemo.common.ir.edge.executionproperty.KeyDecoderProperty;
+import org.apache.nemo.common.ir.edge.executionproperty.KeyEncoderProperty;
+
+import java.io.Serializable;
+import java.util.Map;
+import java.util.function.BiFunction;
+
+/**
+ * A utility class for skew handling passes.
+ */
+final class SkewHandlingUtil {
+ private SkewHandlingUtil() {
+ }
+
+ static BiFunction<Object, Map<Object, Long>, Map<Object, Long>> getDynOptCollector(final KeyExtractor keyExtractor) {
+ return (BiFunction<Object, Map<Object, Long>, Map<Object, Long>> & Serializable)
+ (element, dynOptData) -> {
+ Object key = keyExtractor.extractKey(element);
+ if (dynOptData.containsKey(key)) {
+ dynOptData.compute(key, (existingKey, existingCount) -> (long) existingCount + 1L);
+ } else {
+ dynOptData.put(key, 1L);
+ }
+ return dynOptData;
+ };
+ }
+
+ static BiFunction<Pair<Object, Long>, Map<Object, Long>, Map<Object, Long>> getDynOptAggregator() {
+ return (BiFunction<Pair<Object, Long>, Map<Object, Long>, Map<Object, Long>> & Serializable)
+ (element, aggregatedDynOptData) -> {
+ final Object key = element.left();
+ final Long count = element.right();
+ if (aggregatedDynOptData.containsKey(key)) {
+ aggregatedDynOptData.compute(key, (existingKey, accumulatedCount) -> accumulatedCount + count);
+ } else {
+ aggregatedDynOptData.put(key, count);
+ }
+ return aggregatedDynOptData;
+ };
+ }
+
+ static EncoderProperty getEncoder(final IREdge irEdge) {
+ return EncoderProperty.of(PairEncoderFactory
+ .of(irEdge.getPropertyValue(KeyEncoderProperty.class).get(), LongEncoderFactory.of()));
+ }
+
+ static DecoderProperty getDecoder(final IREdge irEdge) {
+ return DecoderProperty.of(PairDecoderFactory
+ .of(irEdge.getPropertyValue(KeyDecoderProperty.class).get(), LongDecoderFactory.of()));
+ }
+}
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
index 62c3c9a..046dd75 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
@@ -19,24 +19,17 @@
package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;
import org.apache.nemo.common.KeyExtractor;
-import org.apache.nemo.common.Pair;
-import org.apache.nemo.common.coder.*;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.*;
import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
import org.apache.nemo.common.ir.vertex.utility.MessageBarrierVertex;
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
-import org.apache.nemo.common.ir.edge.executionproperty.DecoderProperty;
-import org.apache.nemo.common.ir.edge.executionproperty.EncoderProperty;
import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;
-import org.apache.nemo.compiler.optimizer.pass.compiletime.annotating.Annotates;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.Serializable;
import java.util.*;
-import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -45,7 +38,6 @@ import java.util.stream.Collectors;
* We insert a {@link MessageBarrierVertex} for each shuffle edge,
* and aggregate messages for multiple same-destination shuffle edges.
* */
-@Annotates(PartitionerProperty.class)
@Requires(CommunicationPatternProperty.class)
public final class SkewReshapingPass extends ReshapingPass {
private static final Logger LOG = LoggerFactory.getLogger(SkewReshapingPass.class.getName());
@@ -78,43 +70,12 @@ public final class SkewReshapingPass extends ReshapingPass {
// Get the key extractor
final KeyExtractor keyExtractor = representativeEdge.getPropertyValue(KeyExtractorProperty.class).get();
- // For collecting the data
- final BiFunction<Object, Map<Object, Long>, Map<Object, Long>> dynOptDataCollector =
- (BiFunction<Object, Map<Object, Long>, Map<Object, Long>> & Serializable)
- (element, dynOptData) -> {
- Object key = keyExtractor.extractKey(element);
- if (dynOptData.containsKey(key)) {
- dynOptData.compute(key, (existingKey, existingCount) -> (long) existingCount + 1L);
- } else {
- dynOptData.put(key, 1L);
- }
- return dynOptData;
- };
-
- // For aggregating the collected data
- final BiFunction<Pair<Object, Long>, Map<Object, Long>, Map<Object, Long>> dynOptDataAggregator =
- (BiFunction<Pair<Object, Long>, Map<Object, Long>, Map<Object, Long>> & Serializable)
- (element, aggregatedDynOptData) -> {
- final Object key = element.left();
- final Long count = element.right();
- if (aggregatedDynOptData.containsKey(key)) {
- aggregatedDynOptData.compute(key, (existingKey, accumulatedCount) -> accumulatedCount + count);
- } else {
- aggregatedDynOptData.put(key, count);
- }
- return aggregatedDynOptData;
- };
-
- // Coders to use
- final EncoderProperty encoderProperty = EncoderProperty.of(PairEncoderFactory.
- of(representativeEdge.getPropertyValue(KeyEncoderProperty.class).get(), LongEncoderFactory.of()));
- final DecoderProperty decoderProperty = DecoderProperty.of(PairDecoderFactory
- .of(representativeEdge.getPropertyValue(KeyDecoderProperty.class).get(), LongDecoderFactory.of()));
-
// Insert the vertices
- final MessageBarrierVertex mbv = new MessageBarrierVertex<>(dynOptDataCollector);
- final MessageAggregatorVertex mav = new MessageAggregatorVertex(new HashMap(), dynOptDataAggregator);
- dag.insert(mbv, mav, encoderProperty, decoderProperty, shuffleEdgeGroup);
+ final MessageBarrierVertex mbv = new MessageBarrierVertex<>(SkewHandlingUtil.getDynOptCollector(keyExtractor));
+ final MessageAggregatorVertex mav =
+ new MessageAggregatorVertex(new HashMap(), SkewHandlingUtil.getDynOptAggregator());
+ dag.insert(mbv, mav, SkewHandlingUtil.getEncoder(representativeEdge),
+ SkewHandlingUtil.getDecoder(representativeEdge), shuffleEdgeGroup, shuffleEdgeGroup);
}
});
return dag;
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/policy/SamplingLargeShuffleSkewPolicy.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/policy/SamplingLargeShuffleSkewPolicy.java
new file mode 100644
index 0000000..84d9470
--- /dev/null
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/policy/SamplingLargeShuffleSkewPolicy.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.nemo.compiler.optimizer.policy;
+
+import org.apache.nemo.common.ir.IRDAG;
+import org.apache.nemo.compiler.optimizer.pass.compiletime.annotating.DefaultParallelismPass;
+import org.apache.nemo.compiler.optimizer.pass.compiletime.composite.*;
+import org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping.SamplingSkewReshapingPass;
+import org.apache.nemo.compiler.optimizer.pass.runtime.Message;
+import org.apache.nemo.compiler.optimizer.pass.runtime.SkewRunTimePass;
+
+/**
+ * A policy that combines the large shuffle optimization, which batches disk seeks during data shuffle,
+ * with sampling-based data skew dynamic optimization
+ * ({@link SkewRunTimePass} registered together with {@link SamplingSkewReshapingPass}).
+ */
+public final class SamplingLargeShuffleSkewPolicy implements Policy {
+ /**
+ * Builder holding the compile-time and run-time passes of this policy.
+ */
+ public static final PolicyBuilder BUILDER =
+ new PolicyBuilder()
+ .registerCompileTimePass(new DefaultParallelismPass())
+ .registerCompileTimePass(new LargeShuffleCompositePass())
+ .registerRunTimePass(new SkewRunTimePass(), new SamplingSkewReshapingPass())
+ .registerCompileTimePass(new LoopOptimizationCompositePass())
+ .registerCompileTimePass(new DefaultCompositePass());
+
+ private final Policy policy;
+
+ /**
+ * Default constructor: builds the underlying policy from {@link #BUILDER}.
+ */
+ public SamplingLargeShuffleSkewPolicy() {
+ this.policy = BUILDER.build();
+ }
+
+ @Override
+ public IRDAG runCompileTimeOptimization(final IRDAG dag, final String dagDirectory) {
+ return this.policy.runCompileTimeOptimization(dag, dagDirectory);
+ }
+
+ @Override
+ public IRDAG runRunTimeOptimizations(final IRDAG dag, final Message<?> message) {
+ return this.policy.runRunTimeOptimizations(dag, message);
+ }
+}
diff --git a/compiler/test/src/test/java/org/apache/nemo/compiler/backend/nemo/DAGConverterTest.java b/compiler/test/src/test/java/org/apache/nemo/compiler/backend/nemo/DAGConverterTest.java
index 4e0b0bd..95193e1 100644
--- a/compiler/test/src/test/java/org/apache/nemo/compiler/backend/nemo/DAGConverterTest.java
+++ b/compiler/test/src/test/java/org/apache/nemo/compiler/backend/nemo/DAGConverterTest.java
@@ -102,8 +102,8 @@ public final class DAGConverterTest {
assertEquals(physicalDAG.getOutgoingEdgesOf(physicalStage1).size(), 1);
assertEquals(physicalDAG.getOutgoingEdgesOf(physicalStage2).size(), 0);
- assertEquals(3, physicalStage1.getParallelism());
- assertEquals(2, physicalStage2.getParallelism());
+ assertEquals(3, physicalStage1.getTaskIndices().size());
+ assertEquals(2, physicalStage2.getTaskIndices().size());
}
@Test
diff --git a/examples/beam/src/test/java/org/apache/nemo/examples/beam/PerKeyMedianITCase.java b/examples/beam/src/test/java/org/apache/nemo/examples/beam/PerKeyMedianITCase.java
index 39d8d28..7ee9dd6 100644
--- a/examples/beam/src/test/java/org/apache/nemo/examples/beam/PerKeyMedianITCase.java
+++ b/examples/beam/src/test/java/org/apache/nemo/examples/beam/PerKeyMedianITCase.java
@@ -22,6 +22,7 @@ import org.apache.nemo.client.JobLauncher;
import org.apache.nemo.common.test.ArgBuilder;
import org.apache.nemo.common.test.ExampleTestArgs;
import org.apache.nemo.common.test.ExampleTestUtil;
+import org.apache.nemo.compiler.optimizer.policy.SamplingLargeShuffleSkewPolicy;
import org.apache.nemo.examples.beam.policy.DataSkewPolicyParallelismFive;
import org.junit.After;
import org.junit.Before;
@@ -73,4 +74,16 @@ public final class PerKeyMedianITCase {
.addOptimizationPolicy(DataSkewPolicyParallelismFive.class.getCanonicalName())
.build());
}
+
+ /**
+ * Runs PerKeyMedian with {@link SamplingLargeShuffleSkewPolicy}, i.e. the large shuffle optimization
+ * combined with sampling-based data skew dynamic optimization.
+ * @throws Exception exception on the way.
+ */
+ @Test (timeout = ExampleTestArgs.TIMEOUT)
+ public void testLargeShuffleSamplingSkew() throws Exception {
+ final String jobId = PerKeyMedianITCase.class.getSimpleName() + "_LargeShuffleSamplingSkew";
+ final String policyName = SamplingLargeShuffleSkewPolicy.class.getCanonicalName();
+ JobLauncher.main(builder
+ .addJobId(jobId)
+ .addOptimizationPolicy(policyName)
+ .build());
+ }
}
diff --git a/examples/resources/executors/beam_test_one_executor_resources.json b/examples/resources/executors/beam_test_one_executor_resources.json
index 4d6aff4..69f4399 100644
--- a/examples/resources/executors/beam_test_one_executor_resources.json
+++ b/examples/resources/executors/beam_test_one_executor_resources.json
@@ -2,6 +2,6 @@
{
"type": "Transient",
"memory_mb": 512,
- "capacity": 2
+ "capacity": 15
}
]
diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGenerator.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGenerator.java
index 242e747..798edeb 100644
--- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGenerator.java
+++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGenerator.java
@@ -28,6 +28,7 @@ import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty;
import org.apache.nemo.common.ir.vertex.*;
import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
import org.apache.nemo.common.ir.vertex.executionproperty.ScheduleGroupProperty;
+import org.apache.nemo.common.ir.vertex.utility.SamplingVertex;
import org.apache.nemo.conf.JobConf;
import org.apache.nemo.common.dag.DAG;
import org.apache.nemo.common.dag.DAGBuilder;
@@ -42,7 +43,10 @@ import org.slf4j.LoggerFactory;
import javax.inject.Inject;
import java.util.*;
+import java.util.function.BinaryOperator;
import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
/**
* A function that converts an IR DAG to physical DAG.
@@ -80,7 +84,11 @@ public final class PhysicalPlanGenerator implements Function<IRDAG, DAG<Stage, S
handleDuplicateEdgeGroupProperty(dagOfStages);
// Split StageGroup by Pull StageEdges
- splitScheduleGroupByPullStageEdges(dagOfStages);
+ //
+ // TODO #337: IRDAG Unit Tests
+ // Move this test to IRDAG unit tests.
+ //
+ // splitScheduleGroupByPullStageEdges(dagOfStages);
// for debugging purposes.
dagOfStages.storeJSON(dagDirectory, "plan-logical", "logical execution plan");
@@ -136,6 +144,7 @@ public final class PhysicalPlanGenerator implements Function<IRDAG, DAG<Stage, S
final Map<Integer, Stage> stageIdToStageMap = new HashMap<>();
final Map<IRVertex, Integer> vertexToStageIdMap = stagePartitioner.apply(irDAG);
final HashSet<IRVertex> isStagePartitioned = new HashSet<>();
+ final Random random = new Random(hashCode()); // to produce same results for same input IRDAGs
final Map<Integer, Set<IRVertex>> vertexSetForEachStage = new LinkedHashMap<>();
irDAG.topologicalDo(irVertex -> {
@@ -151,10 +160,9 @@ public final class PhysicalPlanGenerator implements Function<IRDAG, DAG<Stage, S
final String stageIdentifier = RuntimeIdManager.generateStageId(stageId);
final ExecutionPropertyMap<VertexExecutionProperty> stageProperties = new ExecutionPropertyMap<>(stageIdentifier);
stagePartitioner.getStageProperties(stageVertices.iterator().next()).forEach(stageProperties::put);
-
final int stageParallelism = stageProperties.get(ParallelismProperty.class)
.orElseThrow(() -> new RuntimeException("Parallelism property must be set for Stage"));
-
+ final List<Integer> taskIndices = getTaskIndicesToExecute(stageVertices, stageParallelism, random);
final DAGBuilder<IRVertex, RuntimeEdge<IRVertex>> stageInternalDAGBuilder = new DAGBuilder<>();
// Prepare vertexIdToReadables
@@ -164,14 +172,16 @@ public final class PhysicalPlanGenerator implements Function<IRDAG, DAG<Stage, S
}
// For each IRVertex,
- for (final IRVertex irVertex : stageVertices) {
+ for (final IRVertex v : stageVertices) {
+ final IRVertex vertexToPutIntoStage = getActualVertexToPutIntoStage(v);
+
// Take care of the readables of a source vertex.
- if (irVertex instanceof SourceVertex && !isStagePartitioned.contains(irVertex)) {
- final SourceVertex sourceVertex = (SourceVertex) irVertex;
+ if (vertexToPutIntoStage instanceof SourceVertex && !isStagePartitioned.contains(vertexToPutIntoStage)) {
+ final SourceVertex sourceVertex = (SourceVertex) vertexToPutIntoStage;
try {
final List<Readable> readables = sourceVertex.getReadables(stageParallelism);
for (int i = 0; i < stageParallelism; i++) {
- vertexIdToReadables.get(i).put(irVertex.getId(), readables.get(i));
+ vertexIdToReadables.get(i).put(vertexToPutIntoStage.getId(), readables.get(i));
}
} catch (final Exception e) {
throw new PhysicalPlanGenerationException(e);
@@ -181,7 +191,7 @@ public final class PhysicalPlanGenerator implements Function<IRDAG, DAG<Stage, S
}
// Add vertex to the stage.
- stageInternalDAGBuilder.addVertex(irVertex);
+ stageInternalDAGBuilder.addVertex(vertexToPutIntoStage);
}
for (final IRVertex dstVertex : stageVertices) {
@@ -194,8 +204,8 @@ public final class PhysicalPlanGenerator implements Function<IRDAG, DAG<Stage, S
stageInternalDAGBuilder.connectVertices(new RuntimeEdge<>(
irEdge.getId(),
irEdge.getExecutionProperties(),
- irEdge.getSrc(),
- irEdge.getDst()));
+ getActualVertexToPutIntoStage(irEdge.getSrc()),
+ getActualVertexToPutIntoStage(irEdge.getDst())));
} else { // edge comes from another stage
interStageEdges.add(irEdge);
}
@@ -205,7 +215,12 @@ public final class PhysicalPlanGenerator implements Function<IRDAG, DAG<Stage, S
if (!stageInternalDAGBuilder.isEmpty()) {
final DAG<IRVertex, RuntimeEdge<IRVertex>> stageInternalDAG
= stageInternalDAGBuilder.buildWithoutSourceSinkCheck();
- final Stage stage = new Stage(stageIdentifier, stageInternalDAG, stageProperties, vertexIdToReadables);
+ final Stage stage = new Stage(
+ stageIdentifier,
+ taskIndices,
+ stageInternalDAG,
+ stageProperties,
+ vertexIdToReadables);
dagOfStagesBuilder.addVertex(stage);
stageIdToStageMap.put(stageId, stage);
}
@@ -225,13 +240,52 @@ public final class PhysicalPlanGenerator implements Function<IRDAG, DAG<Stage, S
dstStage == null ? String.format(" destination stage for %s", interStageEdge.getDst()) : ""));
}
dagOfStagesBuilder.connectVertices(new StageEdge(interStageEdge.getId(), interStageEdge.getExecutionProperties(),
- interStageEdge.getSrc(), interStageEdge.getDst(), srcStage, dstStage));
+ getActualVertexToPutIntoStage(interStageEdge.getSrc()), getActualVertexToPutIntoStage(interStageEdge.getDst()),
+ srcStage, dstStage));
}
return dagOfStagesBuilder.build();
}
/**
+ * Maps a vertex to the vertex that is actually placed into a stage.
+ * Sampling vertices are never put into a stage as-is, because the underlying runtime
+ * only understands Source and Operator vertices: a {@link SamplingVertex} is replaced by the
+ * vertex returned from its getCloneOfOriginalVertex(), and any other vertex is returned unchanged.
+ *
+ * @param irVertex the vertex to put into a stage.
+ * @return the vertex that should actually be put into the stage.
+ */
+ private IRVertex getActualVertexToPutIntoStage(final IRVertex irVertex) {
+ if (!(irVertex instanceof SamplingVertex)) {
+ return irVertex;
+ }
+ final SamplingVertex samplingVertex = (SamplingVertex) irVertex;
+ return samplingVertex.getCloneOfOriginalVertex();
+ }
+
+ /**
+ * Selects the task indices of a stage to execute.
+ * For a stage made of sampling vertices, a random subset of [0, stageParallelism) is selected, whose size is
+ * ceil(stageParallelism * minimum desired sample rate), clamped to [1, stageParallelism].
+ * For a stage made of non-sampling vertices, all indices [0, stageParallelism) are selected.
+ *
+ * @param vertices the vertices of the stage; must be either all sampling vertices or none.
+ * @param stageParallelism the parallelism of the stage.
+ * @param random the source of randomness used to pick the sampled indices.
+ * @return the task indices to execute.
+ * @throws IllegalArgumentException if {@code vertices} is empty or mixes sampling and non-sampling vertices.
+ */
+ private List<Integer> getTaskIndicesToExecute(final Set<IRVertex> vertices,
+ final int stageParallelism,
+ final Random random) {
+ // The set of "is sampling" flags has exactly one element iff the vertices are homogeneous (and non-empty).
+ if (vertices.stream().map(v -> v instanceof SamplingVertex).collect(Collectors.toSet()).size() != 1) {
+ throw new IllegalArgumentException("Must be either all sampling vertices, or none: " + vertices.toString());
+ }
+
+ if (vertices.iterator().next() instanceof SamplingVertex) {
+ // Use min of the desired sample rates
+ final float minSampleRate = vertices.stream()
+ .map(v -> ((SamplingVertex) v).getDesiredSampleRate())
+ .reduce(BinaryOperator.minBy(Float::compareTo))
+ .orElseThrow(() -> new IllegalArgumentException(vertices.toString()));
+
+ // Clamp the sample size: an empty selection would leave the stage with no tasks to run,
+ // and a size above the parallelism (rate > 1) would make subList throw IndexOutOfBoundsException.
+ final int desiredNumOfTaskIndices = (int) Math.ceil(stageParallelism * minSampleRate);
+ final int numOfTaskIndices = Math.min(stageParallelism, Math.max(1, desiredNumOfTaskIndices));
+
+ // Compute and return indices
+ final List<Integer> randomIndices = IntStream.range(0, stageParallelism).boxed().collect(Collectors.toList());
+ Collections.shuffle(randomIndices, random);
+ return new ArrayList<>(randomIndices.subList(0, numOfTaskIndices)); // subList is not serializable.
+ } else {
+ return IntStream.range(0, stageParallelism).boxed().collect(Collectors.toList());
+ }
+ }
+
+ /**
* Integrity check for Stage.
* @param stage to check for
*/
@@ -243,7 +297,7 @@ public final class PhysicalPlanGenerator implements Function<IRDAG, DAG<Stage, S
stage.getIRDAG().getVertices().forEach(irVertex -> {
// Check vertex type.
- if (!(irVertex instanceof SourceVertex
+ if (!(irVertex instanceof SourceVertex
|| irVertex instanceof OperatorVertex)) {
throw new UnsupportedOperationException(irVertex.toString());
}
diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java
index 600b4ad..070ca62 100644
--- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java
+++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java
@@ -39,6 +39,7 @@ import java.util.Optional;
* Stage.
*/
public final class Stage extends Vertex {
+ private final List<Integer> taskIndices;
private final DAG<IRVertex, RuntimeEdge<IRVertex>> irDag;
private final byte[] serializedIRDag;
private final List<Map<String, Readable>> vertexIdToReadables;
@@ -49,15 +50,18 @@ public final class Stage extends Vertex {
* Constructor.
*
* @param stageId ID of the stage.
+ * @param taskIndices indices of the tasks to execute.
* @param irDag the DAG of the task in this stage.
* @param executionProperties set of {@link VertexExecutionProperty} for this stage
* @param vertexIdToReadables the list of maps between vertex ID and {@link Readable}.
*/
public Stage(final String stageId,
+ final List<Integer> taskIndices,
final DAG<IRVertex, RuntimeEdge<IRVertex>> irDag,
final ExecutionPropertyMap<VertexExecutionProperty> executionProperties,
final List<Map<String, Readable>> vertexIdToReadables) {
super(stageId);
+ this.taskIndices = taskIndices;
this.irDag = irDag;
this.serializedIRDag = SerializationUtils.serialize(irDag);
this.executionProperties = executionProperties;
@@ -79,11 +83,20 @@ public final class Stage extends Vertex {
}
/**
- * @return the parallelism
+ * @return task indices of this stage to execute, as an unmodifiable view.
+ * For non-sampling vertices, returns [0, 1, 2, ..., parallelism-1].
+ * For sampling vertices, returns a subset of [0, parallelism) whose size is derived from the sampling rate.
+ * The view prevents callers from modifying this stage's internal list.
+ */
+ public List<Integer> getTaskIndices() {
+ return java.util.Collections.unmodifiableList(taskIndices);
+ }
+
+ /**
+ * @return the parallelism of this stage, i.e. the value of its {@link ParallelismProperty}.
+ * For a sampling stage this may differ from {@code getTaskIndices().size()}.
+ * @throws RuntimeException if the {@link ParallelismProperty} is not set.
 */
 public int getParallelism() {
 return executionProperties.get(ParallelismProperty.class)
- .orElseThrow(() -> new RuntimeException("Parallelism property must be set for Stage"));
+ .orElseThrow(() -> new RuntimeException("Parallelism property must be set for Stage"));
 }
/**
@@ -133,6 +146,7 @@ public final class Stage extends Vertex {
node.put("scheduleGroup", getScheduleGroup());
node.set("irDag", irDag.asJsonNode());
node.put("parallelism", getParallelism());
+ node.put("num of task indices", getTaskIndices().size());
node.set("executionProperties", executionProperties.asJsonNode());
return node;
}
diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/StageEdge.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/StageEdge.java
index f73cc75..04c0f74 100644
--- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/StageEdge.java
+++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/StageEdge.java
@@ -158,14 +158,12 @@ public final class StageEdge extends RuntimeEdge<Stage> {
* @return {@link org.apache.nemo.common.ir.edge.executionproperty.PartitionSetProperty} value.
*/
public List<KeyRange> getKeyRanges() {
- final ArrayList<KeyRange> defaultPartitionSet = new ArrayList<>(getDst().getParallelism());
- for (int taskIdx = 0; taskIdx < getDst().getParallelism(); taskIdx++) {
- defaultPartitionSet.add(taskIdx, HashRange.of(taskIdx, taskIdx + 1));
+ final ArrayList<KeyRange> defaultPartitionSet = new ArrayList<>();
+ for (int taskIndex = 0; taskIndex < getDst().getParallelism(); taskIndex++) {
+ defaultPartitionSet.add(taskIndex, HashRange.of(taskIndex, taskIndex + 1));
}
final List<KeyRange> keyRanges = getExecutionProperties()
.get(PartitionSetProperty.class).orElse(defaultPartitionSet);
- LOG.info("{} -> {} getKeyRanges {}", srcVertex.getId(), dstVertex.getId(), keyRanges);
-
return keyRanges;
}
}
diff --git a/runtime/common/src/test/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGeneratorTest.java b/runtime/common/src/test/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGeneratorTest.java
index 7e6dcbe..40f2c2b 100644
--- a/runtime/common/src/test/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGeneratorTest.java
+++ b/runtime/common/src/test/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGeneratorTest.java
@@ -30,7 +30,6 @@ import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
import org.apache.nemo.common.ir.vertex.executionproperty.ScheduleGroupProperty;
import org.apache.reef.tang.Injector;
import org.apache.reef.tang.Tang;
-import org.junit.Test;
import java.util.Iterator;
@@ -45,8 +44,10 @@ public final class PhysicalPlanGeneratorTest {
/**
* Test splitting ScheduleGroups by Pull StageEdges.
* @throws Exception exceptions on the way
+ *
+ * TODO #337: IRDAG Unit Tests
+ * Move this test to IRDAG unit tests.
*/
- @Test
public void testSplitScheduleGroupByPullStageEdges() throws Exception {
final Injector injector = Tang.Factory.getTang().newInjector();
final PhysicalPlanGenerator physicalPlanGenerator = injector.getInstance(PhysicalPlanGenerator.class);
diff --git a/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/TestUtil.java b/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/TestUtil.java
index 9b752c8..b1832ec 100644
--- a/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/TestUtil.java
+++ b/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/TestUtil.java
@@ -26,9 +26,9 @@ import java.util.List;
public final class TestUtil {
public static List<String> generateTaskIds(final Stage stage) {
- final List<String> result = new ArrayList<>(stage.getParallelism());
+ final List<String> result = new ArrayList<>();
final int first_attempt = 0;
- for (int taskIndex = 0; taskIndex < stage.getParallelism(); taskIndex++) {
+ for (final int taskIndex : stage.getTaskIndices()) {
result.add(RuntimeIdManager.generateTaskId(stage.getId(), taskIndex, first_attempt));
}
return result;
diff --git a/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/datatransfer/DataTransferTest.java b/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/datatransfer/DataTransferTest.java
index 2a3d390..88884eb 100644
--- a/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/datatransfer/DataTransferTest.java
+++ b/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/datatransfer/DataTransferTest.java
@@ -520,7 +520,12 @@ public final class DataTransferTest {
final ExecutionPropertyMap<VertexExecutionProperty> stageExecutionProperty = new ExecutionPropertyMap<>(stageId);
stageExecutionProperty.put(ParallelismProperty.of(PARALLELISM_TEN));
stageExecutionProperty.put(ScheduleGroupProperty.of(0));
- return new Stage(stageId, emptyDag, stageExecutionProperty, Collections.emptyList());
+ return new Stage(
+ stageId,
+ IntStream.range(0, PARALLELISM_TEN).boxed().collect(Collectors.toList()),
+ emptyDag,
+ stageExecutionProperty,
+ Collections.emptyList());
}
/**
diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java
index 88f3686..702e9ea 100644
--- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java
+++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java
@@ -44,7 +44,6 @@ import org.apache.nemo.runtime.common.state.TaskState;
import org.apache.nemo.runtime.common.metric.JobMetric;
import org.apache.nemo.runtime.common.metric.StageMetric;
import org.apache.nemo.runtime.common.metric.TaskMetric;
-import org.apache.nemo.runtime.master.metric.MetricMessageHandler;
import org.apache.nemo.runtime.master.metric.MetricStore;
import org.apache.reef.annotations.audience.DriverSide;
import org.apache.reef.tang.annotations.Parameter;
@@ -77,7 +76,9 @@ public final class PlanStateManager {
*/
private PlanState planState;
private final Map<String, StageState> stageIdToState;
- private final Map<String, List<List<TaskState>>> stageIdToTaskAttemptStates; // sorted by task idx, and then attempt
+
+ // list of attempt states sorted by attempt idx
+ private final Map<String, Map<Integer, List<TaskState>>> stageIdToTaskIdxToAttemptStates;
/**
* Used for speculative cloning. (in the unit of milliseconds - ms)
@@ -101,21 +102,16 @@ public final class PlanStateManager {
* For metrics.
*/
private final String dagDirectory;
- private final MetricMessageHandler metricMessageHandler;
private MetricStore metricStore;
/**
* Constructor.
- *
- * @param metricMessageHandler the metric handler for the plan.
*/
@Inject
- private PlanStateManager(@Parameter(JobConf.DAGDirectory.class) final String dagDirectory,
- final MetricMessageHandler metricMessageHandler) {
- this.metricMessageHandler = metricMessageHandler;
+ private PlanStateManager(@Parameter(JobConf.DAGDirectory.class) final String dagDirectory) {
this.planState = new PlanState();
this.stageIdToState = new HashMap<>();
- this.stageIdToTaskAttemptStates = new HashMap<>();
+ this.stageIdToTaskIdxToAttemptStates = new HashMap<>();
this.finishLock = new ReentrantLock();
this.planFinishedCondition = finishLock.newCondition();
this.dagDirectory = dagDirectory;
@@ -155,11 +151,11 @@ public final class PlanStateManager {
physicalPlan.getStageDAG().topologicalDo(stage -> {
if (!stageIdToState.containsKey(stage.getId())) {
stageIdToState.put(stage.getId(), new StageState());
- stageIdToTaskAttemptStates.put(stage.getId(), new ArrayList<>(stage.getParallelism()));
+ stageIdToTaskIdxToAttemptStates.put(stage.getId(), new HashMap<>());
// for each task idx of this stage
- for (int taskIndex = 0; taskIndex < stage.getParallelism(); taskIndex++) {
- stageIdToTaskAttemptStates.get(stage.getId()).add(new ArrayList<>());
+ for (final int taskIndex : stage.getTaskIndices()) {
+ stageIdToTaskIdxToAttemptStates.get(stage.getId()).put(taskIndex, new ArrayList<>());
// task states will be initialized lazily in getTaskAttemptsToSchedule()
}
}
@@ -184,9 +180,9 @@ public final class PlanStateManager {
// For each task index....
final List<String> taskAttemptsToSchedule = new ArrayList<>();
final Stage stage = physicalPlan.getStageDAG().getVertexById(stageId);
- for (int taskIndex = 0; taskIndex < stage.getParallelism(); taskIndex++) {
+ for (final int taskIndex : stage.getTaskIndices()) {
final List<TaskState> attemptStatesForThisTaskIndex =
- stageIdToTaskAttemptStates.get(stageId).get(taskIndex);
+ stageIdToTaskIdxToAttemptStates.get(stageId).get(taskIndex);
// If one of the attempts is COMPLETE, do not schedule
if (attemptStatesForThisTaskIndex
@@ -247,9 +243,9 @@ public final class PlanStateManager {
final long curTime = System.currentTimeMillis();
final Map<String, Long> result = new HashMap<>();
- final List<List<TaskState>> taskStates = stageIdToTaskAttemptStates.get(stageId);
- for (int taskIndex = 0; taskIndex < taskStates.size(); taskIndex++) {
- final List<TaskState> attemptStates = taskStates.get(taskIndex);
+ final Map<Integer, List<TaskState>> taskIdToState = stageIdToTaskIdxToAttemptStates.get(stageId);
+ for (final int taskIndex : taskIdToState.keySet()) {
+ final List<TaskState> attemptStates = taskIdToState.get(taskIndex);
for (int attempt = 0; attempt < attemptStates.size(); attempt++) {
if (TaskState.State.EXECUTING.equals(attemptStates.get(attempt).getStateMachine().getCurrentState())) {
final String taskId = RuntimeIdManager.generateTaskId(stageId, taskIndex, attempt);
@@ -317,8 +313,8 @@ public final class PlanStateManager {
// Log not-yet-completed tasks for us humans to track progress
final String stageId = RuntimeIdManager.getStageIdFromTaskId(taskId);
- final List<List<TaskState>> taskStatesOfThisStage = stageIdToTaskAttemptStates.get(stageId);
- final long numOfCompletedTaskIndicesInThisStage = taskStatesOfThisStage.stream()
+ final Map<Integer, List<TaskState>> taskStatesOfThisStage = stageIdToTaskIdxToAttemptStates.get(stageId);
+ final long numOfCompletedTaskIndicesInThisStage = taskStatesOfThisStage.values().stream()
.filter(attempts -> {
final List<TaskState.State> states = attempts
.stream()
@@ -357,7 +353,7 @@ public final class PlanStateManager {
case COMPLETE:
case ON_HOLD:
if (numOfCompletedTaskIndicesInThisStage
- == physicalPlan.getStageDAG().getVertexById(stageId).getParallelism()) {
+ == physicalPlan.getStageDAG().getVertexById(stageId).getTaskIndices().size()) {
onStageStateChanged(stageId, StageState.State.COMPLETE);
}
break;
@@ -540,9 +536,9 @@ public final class PlanStateManager {
private Map<String, TaskState.State> getTaskAttemptIdsToItsState(final String stageId) {
final Map<String, TaskState.State> result = new HashMap<>();
- final List<List<TaskState>> taskStates = stageIdToTaskAttemptStates.get(stageId);
- for (int taskIndex = 0; taskIndex < taskStates.size(); taskIndex++) {
- final List<TaskState> attemptStates = taskStates.get(taskIndex);
+ final Map<Integer, List<TaskState>> taskIdToState = stageIdToTaskIdxToAttemptStates.get(stageId);
+ for (final int taskIndex : taskIdToState.keySet()) {
+ final List<TaskState> attemptStates = taskIdToState.get(taskIndex);
for (int attempt = 0; attempt < attemptStates.size(); attempt++) {
result.put(RuntimeIdManager.generateTaskId(stageId, taskIndex, attempt),
(TaskState.State) attemptStates.get(attempt).getStateMachine().getCurrentState());
@@ -552,7 +548,7 @@ public final class PlanStateManager {
}
private TaskState getTaskStateHelper(final String taskId) {
- return stageIdToTaskAttemptStates
+ return stageIdToTaskIdxToAttemptStates
.get(RuntimeIdManager.getStageIdFromTaskId(taskId))
.get(RuntimeIdManager.getIndexFromTaskId(taskId))
.get(RuntimeIdManager.getAttemptFromTaskId(taskId));
@@ -571,7 +567,7 @@ public final class PlanStateManager {
final int attempt = RuntimeIdManager.getAttemptFromTaskId(taskId);
final List<TaskState> otherAttemptsforTheSameTaskIndex =
- new ArrayList<>(stageIdToTaskAttemptStates.get(stageId).get(taskIndex));
+ new ArrayList<>(stageIdToTaskIdxToAttemptStates.get(stageId).get(taskIndex));
otherAttemptsforTheSameTaskIndex.remove(attempt);
return otherAttemptsforTheSameTaskIndex.stream()
diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
index bcf0976..26bbb27 100644
--- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
+++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
@@ -22,9 +22,10 @@ import com.google.common.collect.Sets;
import org.apache.nemo.common.Pair;
import org.apache.nemo.common.dag.DAG;
import org.apache.nemo.common.ir.Readable;
-import org.apache.nemo.common.ir.edge.executionproperty.MessageIdProperty;
+import org.apache.nemo.common.ir.edge.executionproperty.MessageIdEdgeProperty;
import org.apache.nemo.common.ir.vertex.executionproperty.ClonedSchedulingProperty;
import org.apache.nemo.common.ir.vertex.executionproperty.IgnoreSchedulingTempDataReceiverProperty;
+import org.apache.nemo.common.ir.vertex.executionproperty.MessageIdVertexProperty;
import org.apache.nemo.runtime.common.RuntimeIdManager;
import org.apache.nemo.runtime.common.plan.*;
import org.apache.nemo.runtime.common.state.BlockState;
@@ -163,7 +164,7 @@ public final class BatchScheduler implements Scheduler {
private int getMessageId(final Set<StageEdge> stageEdges) {
final Set<Integer> messageIds = stageEdges.stream()
- .map(edge -> edge.getExecutionProperties().get(MessageIdProperty.class).get())
+ .map(edge -> edge.getExecutionProperties().get(MessageIdEdgeProperty.class).get())
.collect(Collectors.toSet());
if (messageIds.size() != 1) {
throw new IllegalArgumentException(stageEdges.toString());
@@ -288,13 +289,12 @@ public final class BatchScheduler implements Scheduler {
stage.getPropertyValue(ClonedSchedulingProperty.class).ifPresent(cloneConf -> {
if (!cloneConf.isUpFrontCloning()) { // Upfront cloning is already handled.
final double fractionToWaitFor = cloneConf.getFractionToWaitFor();
- final int parallelism = stage.getParallelism();
final Object[] completedTaskTimes = planStateManager.getCompletedTaskTimeListMs(stageId).toArray();
// Only after the fraction of the tasks are done...
// Delayed cloning (aggressive)
if (completedTaskTimes.length > 0
- && completedTaskTimes.length >= Math.round(parallelism * fractionToWaitFor)) {
+ && completedTaskTimes.length >= Math.round(stage.getTaskIndices().size() * fractionToWaitFor)) {
Arrays.sort(completedTaskTimes);
final long medianTime = (long) completedTaskTimes[completedTaskTimes.length / 2];
final double medianTimeMultiplier = cloneConf.getMedianTimeMultiplier();
@@ -465,7 +465,7 @@ public final class BatchScheduler implements Scheduler {
/**
* Get the target edges of dynamic optimization.
- * The edges are annotated with {@link MessageIdProperty}, which are outgoing edges of
+ * The edges are annotated with {@link MessageIdEdgeProperty}, which are outgoing edges of
* parents of the stage put on hold.
*
* See {@link org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping.SkewReshapingPass}
@@ -485,23 +485,25 @@ public final class BatchScheduler implements Scheduler {
// Stage put on hold, i.e. stage with vertex containing MessageAggregatorTransform
// should have a parent stage whose outgoing edges contain the target edge of dynamic optimization.
- final List<StageEdge> edgesToStagePutOnHold = stageDag.getIncomingEdgesOf(stagePutOnHold);
- if (edgesToStagePutOnHold.isEmpty()) {
- throw new RuntimeException("No edges toward specified put on hold stage");
+ final List<Integer> messageIds = stagePutOnHold.getIRDAG()
+ .getVertices()
+ .stream()
+ .filter(v -> v.getPropertyValue(MessageIdVertexProperty.class).isPresent())
+ .map(v -> v.getPropertyValue(MessageIdVertexProperty.class).get())
+ .collect(Collectors.toList());
+ if (messageIds.size() != 1) {
+ throw new IllegalStateException("Must be exactly one vertex with the message id: " + messageIds.toString());
}
- final int messageId = edgesToStagePutOnHold.get(0).getPropertyValue(MessageIdProperty.class)
- .orElseThrow(() -> new RuntimeException("No message id for this put on hold stage"));
-
+ final int messageId = messageIds.get(0);
final Set<StageEdge> targetEdges = new HashSet<>();
- // Get edges with identical MessageIdProperty (except the put on hold stage)
+ // Get edges with identical MessageIdEdgeProperty (except the put on hold stage)
for (final Stage stage : stageDag.getVertices()) {
final Set<StageEdge> targetEdgesFound = stageDag.getOutgoingEdgesOf(stage).stream()
.filter(candidateEdge -> {
final Optional<Integer> candidateMCId =
- candidateEdge.getPropertyValue(MessageIdProperty.class);
- return candidateMCId.isPresent() && candidateMCId.get().equals(messageId)
- && !edgesToStagePutOnHold.contains(candidateEdge);
+ candidateEdge.getPropertyValue(MessageIdEdgeProperty.class);
+ return candidateMCId.isPresent() && candidateMCId.get().equals(messageId);
})
.collect(Collectors.toSet());
targetEdges.addAll(targetEdgesFound);