You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nemo.apache.org by sa...@apache.org on 2018/08/31 11:39:54 UTC
[incubator-nemo] branch master updated: [NEMO-15] Run Spark ALS in
distributed mode (#113)
This is an automated email from the ASF dual-hosted git repository.
sanha pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-nemo.git
The following commit(s) were added to refs/heads/master by this push:
new 954d92f [NEMO-15] Run Spark ALS in distributed mode (#113)
954d92f is described below
commit 954d92fa81b0767305bf6ddc43f63a8470dd28b5
Author: John Yang <jo...@gmail.com>
AuthorDate: Fri Aug 31 20:39:51 2018 +0900
[NEMO-15] Run Spark ALS in distributed mode (#113)
JIRA: [NEMO-15: Run Spark ALS in distributed mode](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-15)
**Major changes:**
- Removes 'sideinput' from the IRDAG and instead introduces BroadcastVariableIdProperty
- BroadcastManagerMaster/BroadcastManagerWorker: components that manage broadcast variables on the master and executor sides
- New RPC messages for exchanging broadcast variables from the client to the driver, and between the driver and executors
**Minor changes to note:**
- Uses explicit Scala lambdas (update_ms, update_us) to work around serialization errors
- Avoid duplicate PlanState completions which can be caused by cloned executions
**Tests for the changes:**
- SparkScala#testALS
**Other comments:**
- This PR also closes: [NEMO-14] Re-investigate Broadcast and SideInputs
Closes #113
---
bin/run_spark.sh | 0
.../main/java/edu/snu/nemo/client/JobLauncher.java | 15 +-
.../edu/snu/nemo/client/ClientEndpointTest.java | 2 +-
.../java/edu/snu/nemo/common/dag/DAGBuilder.java | 5 +-
.../java/edu/snu/nemo/common/ir/IdManager.java | 1 +
.../java/edu/snu/nemo/common/ir/edge/IREdge.java | 25 --
.../BroadcastVariableIdProperty.java | 43 +++
.../edu/snu/nemo/common/ir/vertex/LoopVertex.java | 10 +-
.../nemo/common/ir/vertex/transform/Transform.java | 11 +-
.../snu/nemo/common/coder/CoderFactoryTest.java | 3 +-
.../compiler/frontend/beam/PipelineTranslator.java | 28 +-
.../beam/transform/CreateViewTransform.java | 9 -
.../frontend/beam/transform/DoTransform.java | 23 +-
.../frontend/spark/SparkBroadcastVariables.java | 54 ++++
.../frontend/spark/core/SparkBroadcast.java | 46 ++++
.../compiler/frontend/spark/core/SparkContext.java | 12 +
.../frontend/spark/core/SparkFrontendUtils.java | 28 +-
.../compiler/frontend/spark/core/rdd/RDD.scala | 10 +-
.../snu/nemo/compiler/optimizer/NemoOptimizer.java | 3 +-
.../reshaping/LargeShuffleRelayReshapingPass.java | 4 +-
.../compiletime/reshaping/LoopExtractionPass.java | 17 +-
.../compiletime/reshaping/LoopOptimizations.java | 6 +-
.../compiletime/reshaping/SkewReshapingPass.java | 4 +-
.../edu/snu/nemo/examples/spark/SparkALS.scala | 11 +-
.../edu/snu/nemo/examples/spark/SparkScala.java | 10 +
.../common/message/ncs/NcsMessageEnvironment.java | 6 +
.../runtime/common/plan/PhysicalPlanGenerator.java | 5 +-
.../snu/nemo/runtime/common/plan/RuntimeEdge.java | 13 +-
.../snu/nemo/runtime/common/plan/StageEdge.java | 22 +-
runtime/common/src/main/proto/ControlMessage.proto | 15 ++
.../main/java/edu/snu/nemo/driver/NemoDriver.java | 12 +-
.../edu/snu/nemo/runtime/executor/Executor.java | 7 +-
.../runtime/executor/TransformContextImpl.java | 20 +-
.../executor/data/BroadcastManagerWorker.java | 142 ++++++++++
.../runtime/executor/datatransfer/InputReader.java | 4 -
.../nemo/runtime/executor/task/DataFetcher.java | 16 +-
.../executor/task/ParentTaskDataFetcher.java | 7 +-
.../executor/task/SourceVertexDataFetcher.java | 5 +-
.../nemo/runtime/executor/task/TaskExecutor.java | 300 ++++++++++-----------
.../nemo/runtime/executor/task/VertexHarness.java | 36 +--
.../runtime/executor/TransformContextImplTest.java | 24 +-
.../executor/datatransfer/DataTransferTest.java | 9 +-
.../executor/task/ParentTaskDataFetcherTest.java | 3 +-
.../runtime/executor/task/TaskExecutorTest.java | 126 +++++----
.../nemo/runtime/master/BlockManagerMaster.java | 5 +-
.../runtime/master/BroadcastManagerMaster.java | 50 ++++
.../edu/snu/nemo/runtime/master/PlanAppender.java | 3 +-
.../snu/nemo/runtime/master/PlanStateManager.java | 4 +-
.../edu/snu/nemo/runtime/master/RuntimeMaster.java | 22 ++
.../SkewnessAwareSchedulingConstraintTest.java | 2 +-
50 files changed, 794 insertions(+), 444 deletions(-)
diff --git a/bin/run_spark.sh b/bin/run_spark.sh
old mode 100644
new mode 100755
diff --git a/client/src/main/java/edu/snu/nemo/client/JobLauncher.java b/client/src/main/java/edu/snu/nemo/client/JobLauncher.java
index 68a4422..bf43332 100644
--- a/client/src/main/java/edu/snu/nemo/client/JobLauncher.java
+++ b/client/src/main/java/edu/snu/nemo/client/JobLauncher.java
@@ -16,6 +16,7 @@
package edu.snu.nemo.client;
import com.google.common.annotations.VisibleForTesting;
+import com.google.protobuf.ByteString;
import edu.snu.nemo.common.dag.DAG;
import edu.snu.nemo.conf.JobConf;
import edu.snu.nemo.driver.NemoDriver;
@@ -42,14 +43,13 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
+import java.io.Serializable;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
-import java.util.ArrayList;
-import java.util.Base64;
-import java.util.List;
+import java.util.*;
import java.util.concurrent.CountDownLatch;
/**
@@ -171,6 +171,14 @@ public final class JobLauncher {
*/
// When modifying the signature of this method, see CompilerTestUtil#compileDAG and make corresponding changes
public static void launchDAG(final DAG dag) {
+ launchDAG(dag, Collections.emptyMap());
+ }
+
+ /**
+ * @param dag the application DAG.
+ * @param broadcastVariables broadcast variables (can be empty).
+ */
+ public static void launchDAG(final DAG dag, final Map<Serializable, Object> broadcastVariables) {
// Wait until the driver is ready.
try {
LOG.info("Waiting for the driver to be ready");
@@ -188,6 +196,7 @@ public final class JobLauncher {
.setType(ControlMessage.ClientToDriverMessageType.LaunchDAG)
.setLaunchDAG(ControlMessage.LaunchDAGMessage.newBuilder()
.setDag(serializedDAG)
+ .setBroadcastVars(ByteString.copyFrom(SerializationUtils.serialize((Serializable) broadcastVariables)))
.build())
.build());
diff --git a/client/src/test/java/edu/snu/nemo/client/ClientEndpointTest.java b/client/src/test/java/edu/snu/nemo/client/ClientEndpointTest.java
index 4f9212b..b033dd8 100644
--- a/client/src/test/java/edu/snu/nemo/client/ClientEndpointTest.java
+++ b/client/src/test/java/edu/snu/nemo/client/ClientEndpointTest.java
@@ -43,7 +43,7 @@ import static org.mockito.Mockito.when;
public class ClientEndpointTest {
private static final int MAX_SCHEDULE_ATTEMPT = 2;
- @Test(timeout = 3000)
+ @Test(timeout = 10000)
public void testState() throws Exception {
// Create a simple client endpoint that returns given job state.
final StateTranslator stateTranslator = mock(StateTranslator.class);
diff --git a/common/src/main/java/edu/snu/nemo/common/dag/DAGBuilder.java b/common/src/main/java/edu/snu/nemo/common/dag/DAGBuilder.java
index 7625cde..f9b266e 100644
--- a/common/src/main/java/edu/snu/nemo/common/dag/DAGBuilder.java
+++ b/common/src/main/java/edu/snu/nemo/common/dag/DAGBuilder.java
@@ -17,6 +17,7 @@ package edu.snu.nemo.common.dag;
import edu.snu.nemo.common.exception.CompileTimeOptimizationException;
import edu.snu.nemo.common.ir.edge.IREdge;
+import edu.snu.nemo.common.ir.edge.executionproperty.BroadcastVariableIdProperty;
import edu.snu.nemo.common.ir.edge.executionproperty.DataFlowProperty;
import edu.snu.nemo.common.ir.edge.executionproperty.MetricCollectionProperty;
import edu.snu.nemo.common.ir.vertex.IRVertex;
@@ -258,11 +259,11 @@ public final class DAGBuilder<V extends Vertex, E extends Edge<V>> implements Se
private void executionPropertyCheck() {
// SideInput is not compatible with Push
vertices.forEach(v -> incomingEdges.get(v).stream().filter(e -> e instanceof IREdge).map(e -> (IREdge) e)
- .filter(e -> Boolean.TRUE.equals(e.isSideInput()))
+ .filter(e -> e.getPropertyValue(BroadcastVariableIdProperty.class).isPresent())
.filter(e -> DataFlowProperty.Value.Push.equals(e.getPropertyValue(DataFlowProperty.class).get()))
.forEach(e -> {
throw new CompileTimeOptimizationException("DAG execution property check: "
- + "SideInput edge is not compatible with push" + e.getId());
+ + "Broadcast edge is not compatible with push" + e.getId());
}));
// DataSizeMetricCollection is not compatible with Push (All data have to be stored before the data collection)
vertices.forEach(v -> incomingEdges.get(v).stream().filter(e -> e instanceof IREdge).map(e -> (IREdge) e)
diff --git a/common/src/main/java/edu/snu/nemo/common/ir/IdManager.java b/common/src/main/java/edu/snu/nemo/common/ir/IdManager.java
index 3a183ae..9fd96fa 100644
--- a/common/src/main/java/edu/snu/nemo/common/ir/IdManager.java
+++ b/common/src/main/java/edu/snu/nemo/common/ir/IdManager.java
@@ -37,6 +37,7 @@ public final class IdManager {
public static String newVertexId() {
return "vertex" + (isDriver ? "(d)" : "") + vertexId.getAndIncrement();
}
+
/**
* @return a new edge ID.
*/
diff --git a/common/src/main/java/edu/snu/nemo/common/ir/edge/IREdge.java b/common/src/main/java/edu/snu/nemo/common/ir/edge/IREdge.java
index d2c75e3..3cd1bd6 100644
--- a/common/src/main/java/edu/snu/nemo/common/ir/edge/IREdge.java
+++ b/common/src/main/java/edu/snu/nemo/common/ir/edge/IREdge.java
@@ -32,11 +32,9 @@ import java.util.Optional;
*/
public final class IREdge extends Edge<IRVertex> {
private final ExecutionPropertyMap<EdgeExecutionProperty> executionProperties;
- private final Boolean isSideInput;
/**
* Constructor of IREdge.
- * This constructor assumes that this edge is not for a side input.
*
* @param commPattern data communication pattern type of the edge.
* @param src source vertex.
@@ -45,23 +43,7 @@ public final class IREdge extends Edge<IRVertex> {
public IREdge(final CommunicationPatternProperty.Value commPattern,
final IRVertex src,
final IRVertex dst) {
- this(commPattern, src, dst, false);
- }
-
- /**
- * Constructor of IREdge.
- *
- * @param commPattern data communication pattern type of the edge.
- * @param src source vertex.
- * @param dst destination vertex.
- * @param isSideInput flag for whether or not the edge is a sideInput.
- */
- public IREdge(final CommunicationPatternProperty.Value commPattern,
- final IRVertex src,
- final IRVertex dst,
- final Boolean isSideInput) {
super(IdManager.newEdgeId(), src, dst);
- this.isSideInput = isSideInput;
this.executionProperties = ExecutionPropertyMap.of(this, commPattern);
}
@@ -105,13 +87,6 @@ public final class IREdge extends Edge<IRVertex> {
}
/**
- * @return whether or not the edge is a side input edge.
- */
- public Boolean isSideInput() {
- return isSideInput;
- }
-
- /**
* @param edge edge to compare.
* @return whether or not the edge has the same itinerary
*/
diff --git a/common/src/main/java/edu/snu/nemo/common/ir/edge/executionproperty/BroadcastVariableIdProperty.java b/common/src/main/java/edu/snu/nemo/common/ir/edge/executionproperty/BroadcastVariableIdProperty.java
new file mode 100644
index 0000000..0461763
--- /dev/null
+++ b/common/src/main/java/edu/snu/nemo/common/ir/edge/executionproperty/BroadcastVariableIdProperty.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.common.ir.edge.executionproperty;
+
+import edu.snu.nemo.common.ir.executionproperty.EdgeExecutionProperty;
+
+import java.io.Serializable;
+
+/**
+ * Edge execution property holding the id of the broadcast variable that the edge fetches.
+ */
+public final class BroadcastVariableIdProperty extends EdgeExecutionProperty<Serializable> {
+
+  /**
+   * Private constructor; obtain instances through {@link #of(Serializable)}.
+   *
+   * @param id the broadcast variable id.
+   */
+  private BroadcastVariableIdProperty(final Serializable id) {
+    super(id);
+  }
+
+  /**
+   * Creates a new property wrapping the given broadcast variable id.
+   *
+   * @param value the broadcast variable id.
+   * @return the newly created execution property.
+   */
+  public static BroadcastVariableIdProperty of(final Serializable value) {
+    return new BroadcastVariableIdProperty(value);
+  }
+}
diff --git a/common/src/main/java/edu/snu/nemo/common/ir/vertex/LoopVertex.java b/common/src/main/java/edu/snu/nemo/common/ir/vertex/LoopVertex.java
index 7d25aee..e07320e 100644
--- a/common/src/main/java/edu/snu/nemo/common/ir/vertex/LoopVertex.java
+++ b/common/src/main/java/edu/snu/nemo/common/ir/vertex/LoopVertex.java
@@ -220,8 +220,8 @@ public final class LoopVertex extends IRVertex {
dagBuilder.addVertex(newIrVertex, dagToAdd);
dagToAdd.getIncomingEdgesOf(irVertex).forEach(edge -> {
final IRVertex newSrc = originalToNewIRVertex.get(edge.getSrc());
- final IREdge newIrEdge = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
- newSrc, newIrVertex, edge.isSideInput());
+ final IREdge newIrEdge =
+ new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(), newSrc, newIrVertex);
edge.copyExecutionPropertiesTo(newIrEdge);
dagBuilder.connectVertices(newIrEdge);
});
@@ -230,7 +230,7 @@ public final class LoopVertex extends IRVertex {
// process DAG incoming edges.
getDagIncomingEdges().forEach((dstVertex, irEdges) -> irEdges.forEach(edge -> {
final IREdge newIrEdge = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
- edge.getSrc(), originalToNewIRVertex.get(dstVertex), edge.isSideInput());
+ edge.getSrc(), originalToNewIRVertex.get(dstVertex));
edge.copyExecutionPropertiesTo(newIrEdge);
dagBuilder.connectVertices(newIrEdge);
}));
@@ -239,7 +239,7 @@ public final class LoopVertex extends IRVertex {
// if termination condition met, we process the DAG outgoing edge.
getDagOutgoingEdges().forEach((srcVertex, irEdges) -> irEdges.forEach(edge -> {
final IREdge newIrEdge = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
- originalToNewIRVertex.get(srcVertex), edge.getDst(), edge.isSideInput());
+ originalToNewIRVertex.get(srcVertex), edge.getDst());
edge.copyExecutionPropertiesTo(newIrEdge);
dagBuilder.addVertex(edge.getDst()).connectVertices(newIrEdge);
}));
@@ -250,7 +250,7 @@ public final class LoopVertex extends IRVertex {
this.nonIterativeIncomingEdges.forEach((dstVertex, irEdges) -> irEdges.forEach(this::addDagIncomingEdge));
this.iterativeIncomingEdges.forEach((dstVertex, irEdges) -> irEdges.forEach(edge -> {
final IREdge newIrEdge = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
- originalToNewIRVertex.get(edge.getSrc()), dstVertex, edge.isSideInput());
+ originalToNewIRVertex.get(edge.getSrc()), dstVertex);
edge.copyExecutionPropertiesTo(newIrEdge);
this.addDagIncomingEdge(newIrEdge);
}));
diff --git a/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/Transform.java b/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/Transform.java
index 47fa6c8..b7bb26c 100644
--- a/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/Transform.java
+++ b/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/Transform.java
@@ -47,20 +47,13 @@ public interface Transform<I, O> extends Serializable {
void close();
/**
- * @return tag
- */
- default Object getTag() {
- return null;
- }
-
- /**
* Context of the transform.
*/
interface Context extends Serializable {
/**
- * @return sideInputs.
+ * @return the broadcast variable.
*/
- Map getSideInputs();
+ Object getBroadcastVariable(Serializable id);
Map<String, String> getTagToAdditionalChildren();
/**
diff --git a/common/src/test/java/edu/snu/nemo/common/coder/CoderFactoryTest.java b/common/src/test/java/edu/snu/nemo/common/coder/CoderFactoryTest.java
index 078b7f9..9f836da 100644
--- a/common/src/test/java/edu/snu/nemo/common/coder/CoderFactoryTest.java
+++ b/common/src/test/java/edu/snu/nemo/common/coder/CoderFactoryTest.java
@@ -16,7 +16,6 @@
package edu.snu.nemo.common.coder;
-import edu.snu.nemo.common.ContextImpl;
import org.junit.Assert;
import org.junit.Test;
@@ -24,7 +23,7 @@ import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
/**
- * Tests {@link ContextImpl}.
+ * Tests coder factories.
*/
public class CoderFactoryTest {
diff --git a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/PipelineTranslator.java b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/PipelineTranslator.java
index 7e34ca2..a744ae8 100644
--- a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/PipelineTranslator.java
+++ b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/PipelineTranslator.java
@@ -104,7 +104,7 @@ public final class PipelineTranslator
final Read.Bounded<?> transform) {
final IRVertex vertex = new BeamBoundedSourceVertex<>(transform.getSource());
ctx.addVertex(vertex);
- transformVertex.getNode().getInputs().values().forEach(input -> ctx.addEdgeTo(vertex, input, false));
+ transformVertex.getNode().getInputs().values().forEach(input -> ctx.addEdgeTo(vertex, input));
transformVertex.getNode().getOutputs().values().forEach(output -> ctx.registerMainOutputFrom(vertex, output));
}
@@ -117,8 +117,8 @@ public final class PipelineTranslator
ctx.addVertex(vertex);
transformVertex.getNode().getInputs().values().stream()
.filter(input -> !transform.getAdditionalInputs().values().contains(input))
- .forEach(input -> ctx.addEdgeTo(vertex, input, false));
- transform.getSideInputs().forEach(input -> ctx.addEdgeTo(vertex, input, true));
+ .forEach(input -> ctx.addEdgeTo(vertex, input));
+ transform.getSideInputs().forEach(input -> ctx.addEdgeTo(vertex, input));
transformVertex.getNode().getOutputs().values().forEach(output -> ctx.registerMainOutputFrom(vertex, output));
}
@@ -131,8 +131,8 @@ public final class PipelineTranslator
ctx.addVertex(vertex);
transformVertex.getNode().getInputs().values().stream()
.filter(input -> !transform.getAdditionalInputs().values().contains(input))
- .forEach(input -> ctx.addEdgeTo(vertex, input, false));
- transform.getSideInputs().forEach(input -> ctx.addEdgeTo(vertex, input, true));
+ .forEach(input -> ctx.addEdgeTo(vertex, input));
+ transform.getSideInputs().forEach(input -> ctx.addEdgeTo(vertex, input));
transformVertex.getNode().getOutputs().entrySet().stream()
.filter(pValueWithTupleTag -> pValueWithTupleTag.getKey().equals(transform.getMainOutputTag()))
.forEach(pValueWithTupleTag -> ctx.registerMainOutputFrom(vertex, pValueWithTupleTag.getValue()));
@@ -148,7 +148,7 @@ public final class PipelineTranslator
final GroupByKey<?, ?> transform) {
final IRVertex vertex = new OperatorVertex(new GroupByKeyTransform());
ctx.addVertex(vertex);
- transformVertex.getNode().getInputs().values().forEach(input -> ctx.addEdgeTo(vertex, input, false));
+ transformVertex.getNode().getInputs().values().forEach(input -> ctx.addEdgeTo(vertex, input));
transformVertex.getNode().getOutputs().values().forEach(output -> ctx.registerMainOutputFrom(vertex, output));
}
@@ -166,7 +166,7 @@ public final class PipelineTranslator
}
final IRVertex vertex = new OperatorVertex(new WindowTransform(windowFn));
ctx.addVertex(vertex);
- transformVertex.getNode().getInputs().values().forEach(input -> ctx.addEdgeTo(vertex, input, false));
+ transformVertex.getNode().getInputs().values().forEach(input -> ctx.addEdgeTo(vertex, input));
transformVertex.getNode().getOutputs().values().forEach(output -> ctx.registerMainOutputFrom(vertex, output));
}
@@ -176,7 +176,7 @@ public final class PipelineTranslator
final View.CreatePCollectionView<?, ?> transform) {
final IRVertex vertex = new OperatorVertex(new CreateViewTransform<>(transform.getView()));
ctx.addVertex(vertex);
- transformVertex.getNode().getInputs().values().forEach(input -> ctx.addEdgeTo(vertex, input, false));
+ transformVertex.getNode().getInputs().values().forEach(input -> ctx.addEdgeTo(vertex, input));
ctx.registerMainOutputFrom(vertex, transform.getView());
transformVertex.getNode().getOutputs().values().forEach(output -> ctx.registerMainOutputFrom(vertex, output));
}
@@ -187,7 +187,7 @@ public final class PipelineTranslator
final Flatten.PCollections<?> transform) {
final IRVertex vertex = new OperatorVertex(new FlattenTransform());
ctx.addVertex(vertex);
- transformVertex.getNode().getInputs().values().forEach(input -> ctx.addEdgeTo(vertex, input, false));
+ transformVertex.getNode().getInputs().values().forEach(input -> ctx.addEdgeTo(vertex, input));
transformVertex.getNode().getOutputs().values().forEach(output -> ctx.registerMainOutputFrom(vertex, output));
}
@@ -232,7 +232,7 @@ public final class PipelineTranslator
final IRVertex groupByKey = new OperatorVertex(new GroupByKeyTransform());
ctx.addVertex(groupByKey);
last.getNode().getOutputs().values().forEach(outputFromCombiner
- -> ctx.addEdgeTo(groupByKey, outputFromCombiner, false));
+ -> ctx.addEdgeTo(groupByKey, outputFromCombiner));
first.getNode().getOutputs().values()
.forEach(outputFromGroupByKey -> ctx.registerMainOutputFrom(groupByKey, outputFromGroupByKey));
@@ -400,9 +400,8 @@ public final class PipelineTranslator
*
* @param dst the destination IR vertex.
* @param input the {@link PValue} {@code dst} consumes
- * @param isSideInput whether it is sideInput or not.
*/
- private void addEdgeTo(final IRVertex dst, final PValue input, final boolean isSideInput) {
+ private void addEdgeTo(final IRVertex dst, final PValue input) {
final IRVertex src = pValueToProducer.get(input);
if (src == null) {
try {
@@ -418,7 +417,7 @@ public final class PipelineTranslator
throw new RuntimeException(String.format("%s have failed to determine communication pattern "
+ "for an edge from %s to %s", communicationPatternSelector, src, dst));
}
- final IREdge edge = new IREdge(communicationPattern, src, dst, isSideInput);
+ final IREdge edge = new IREdge(communicationPattern, src, dst);
final Coder<?> coder;
if (input instanceof PCollection) {
coder = ((PCollection) input).getCoder();
@@ -436,6 +435,9 @@ public final class PipelineTranslator
if (pValueToTag.containsKey(input)) {
edge.setProperty(AdditionalOutputTagProperty.of(pValueToTag.get(input).getId()));
}
+ if (input instanceof PCollectionView) {
+ edge.setProperty(BroadcastVariableIdProperty.of((PCollectionView) input));
+ }
edge.setProperty(KeyExtractorProperty.of(new BeamKeyExtractor()));
builder.connectVertices(edge);
}
diff --git a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/CreateViewTransform.java b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/CreateViewTransform.java
index 059de81..b4ab6c6 100644
--- a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/CreateViewTransform.java
+++ b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/CreateViewTransform.java
@@ -61,15 +61,6 @@ public final class CreateViewTransform<I, O> implements Transform<I, O> {
}
}
- /**
- * get the Tag of the Transform.
- * @return the PCollectionView of the transform.
- */
- @Override
- public PCollectionView getTag() {
- return this.pCollectionView;
- }
-
@Override
public void close() {
final Object view = viewFn.apply(multiView);
diff --git a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/DoTransform.java b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/DoTransform.java
index 8b1ff0f..529e83b 100644
--- a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/DoTransform.java
+++ b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/DoTransform.java
@@ -76,8 +76,7 @@ public final class DoTransform<I, O> implements Transform<I, O> {
this.outputCollector = oc;
this.startBundleContext = new StartBundleContext(doFn, serializedOptions);
this.finishBundleContext = new FinishBundleContext(doFn, outputCollector, serializedOptions);
- this.processContext = new ProcessContext(doFn, outputCollector,
- context.getSideInputs(), context.getTagToAdditionalChildren(), serializedOptions);
+ this.processContext = new ProcessContext(doFn, outputCollector, context, serializedOptions);
this.invoker = DoFnInvokers.invokerFor(doFn);
invoker.invokeSetup();
invoker.invokeStartBundle(startBundleContext);
@@ -196,29 +195,27 @@ public final class DoTransform<I, O> implements Transform<I, O> {
implements DoFnInvoker.ArgumentProvider<I, O> {
private I input;
private final OutputCollector<O> outputCollector;
- private final Map sideInputs;
private final Map<String, String> additionalOutputs;
+ private final Context context;
private final ObjectMapper mapper;
private final PipelineOptions options;
/**
* ProcessContext Constructor.
*
- * @param fn Dofn.
- * @param outputCollector OutputCollector.
- * @param sideInputs Map for SideInputs.
- * @param additionalOutputs Map for TaggedOutputs.
- * @param serializedOptions Options, serialized.
+ * @param fn Dofn.
+ * @param outputCollector OutputCollector.
+ * @param context Context.
+ * @param serializedOptions Options, serialized.
*/
ProcessContext(final DoFn<I, O> fn,
final OutputCollector<O> outputCollector,
- final Map sideInputs,
- final Map<String, String> additionalOutputs,
+ final Context context,
final String serializedOptions) {
fn.super();
this.outputCollector = outputCollector;
- this.sideInputs = sideInputs;
- this.additionalOutputs = additionalOutputs;
+ this.context = context;
+ this.additionalOutputs = context.getTagToAdditionalChildren();
this.mapper = new ObjectMapper();
try {
this.options = mapper.readValue(serializedOptions, PipelineOptions.class);
@@ -248,7 +245,7 @@ public final class DoTransform<I, O> implements Transform<I, O> {
@Override
public <T> T sideInput(final PCollectionView<T> view) {
- return (T) sideInputs.get(view);
+ return (T) context.getBroadcastVariable(view);
}
@Override
diff --git a/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/SparkBroadcastVariables.java b/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/SparkBroadcastVariables.java
new file mode 100644
index 0000000..bcfdcc8
--- /dev/null
+++ b/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/SparkBroadcastVariables.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.compiler.frontend.spark;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicLong;
+
+/**
+ * Broadcast variables of Spark, registered under generated ids.
+ * Safe for concurrent registration: ids come from an {@link AtomicLong} and are stored in a concurrent map.
+ */
+public final class SparkBroadcastVariables {
+  private static final Logger LOG = LoggerFactory.getLogger(SparkBroadcastVariables.class.getName());
+  private static final AtomicLong ID_GENERATOR = new AtomicLong(0);
+  // ConcurrentHashMap (not HashMap) so that concurrent register() calls cannot corrupt the registry.
+  private static final Map<Serializable, Object> ID_TO_VARIABLE =
+      new java.util.concurrent.ConcurrentHashMap<>();
+
+  private SparkBroadcastVariables() {
+  }
+
+  /**
+   * Registers a variable under a freshly generated id.
+   *
+   * @param variable data; must be non-null (the concurrent map rejects null values).
+   * @return the id of the variable.
+   * @throws NullPointerException if {@code variable} is null.
+   */
+  public static long register(final Object variable) {
+    // Check before generating an id so a rejected call does not consume one.
+    java.util.Objects.requireNonNull(variable, "variable");
+    final long id = ID_GENERATOR.getAndIncrement();
+    ID_TO_VARIABLE.put(id, variable);
+    LOG.info("Registered Spark broadcast variable with id {}", id);
+    return id;
+  }
+
+  /**
+   * Returns a snapshot of all registered variables.
+   * A copy is returned so callers (e.g. the job launcher, which serializes it) neither iterate the
+   * registry while it is being modified nor mutate it; {@link HashMap} is {@link Serializable}.
+   *
+   * @return a new map from ids to variables.
+   */
+  public static Map<Serializable, Object> getAll() {
+    return new HashMap<>(ID_TO_VARIABLE);
+  }
+}
diff --git a/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/core/SparkBroadcast.java b/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/core/SparkBroadcast.java
new file mode 100644
index 0000000..511f78b
--- /dev/null
+++ b/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/core/SparkBroadcast.java
@@ -0,0 +1,46 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.compiler.frontend.spark.core;
+
+import edu.snu.nemo.runtime.executor.data.BroadcastManagerWorker;
+import scala.reflect.ClassTag$;
+
+/**
+ * A Spark broadcast handle whose value is looked up by id through {@link BroadcastManagerWorker}.
+ *
+ * @param <T> type of the broadcast data.
+ */
+public final class SparkBroadcast<T> extends org.apache.spark.broadcast.Broadcast<T> {
+  /** Id under which the broadcast value was registered. */
+  private final long broadcastId;
+
+  /**
+   * @param tag       id of the registered broadcast variable.
+   * @param classType runtime class used to build the ClassTag required by the Spark superclass.
+   */
+  SparkBroadcast(final long tag, final Class<T> classType) {
+    super(tag, ClassTag$.MODULE$.apply(classType));
+    broadcastId = tag;
+  }
+
+  @Override
+  public T getValue() {
+    final Object value = BroadcastManagerWorker.getStaticReference().get(broadcastId);
+    return (T) value;
+  }
+
+  /** Unpersisting is not supported by this implementation. */
+  @Override
+  public void doUnpersist(final boolean blocking) {
+    throw new UnsupportedOperationException();
+  }
+
+  /** Destroying is not supported by this implementation. */
+  @Override
+  public void doDestroy(final boolean blocking) {
+    throw new UnsupportedOperationException();
+  }
+}
diff --git a/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/core/SparkContext.java b/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/core/SparkContext.java
index 60f43a9..7922064 100644
--- a/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/core/SparkContext.java
+++ b/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/core/SparkContext.java
@@ -15,9 +15,13 @@
*/
package edu.snu.nemo.compiler.frontend.spark.core;
+import edu.snu.nemo.compiler.frontend.spark.SparkBroadcastVariables;
import edu.snu.nemo.compiler.frontend.spark.core.rdd.JavaRDD;
import edu.snu.nemo.compiler.frontend.spark.core.rdd.RDD;
import org.apache.spark.SparkConf;
+import org.apache.spark.broadcast.Broadcast;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import scala.collection.Seq;
import scala.reflect.ClassTag;
@@ -27,6 +31,7 @@ import java.util.List;
* Spark context wrapper for in Nemo.
*/
public final class SparkContext extends org.apache.spark.SparkContext {
+ private static final Logger LOG = LoggerFactory.getLogger(SparkContext.class.getName());
private final org.apache.spark.SparkContext sparkContext;
/**
@@ -61,4 +66,11 @@ public final class SparkContext extends org.apache.spark.SparkContext {
final List<T> javaList = scala.collection.JavaConversions.seqAsJavaList(seq);
return JavaRDD.of(this.sparkContext, javaList, numSlices).rdd();
}
+
+  /**
+   * Registers {@code data} as a Nemo broadcast variable and returns a handle to it.
+   * The class is taken from {@code evidence} (the caller's declared type {@code T}) rather than from
+   * {@code data.getClass()}, so the handle's ClassTag matches the static type and {@code data} is
+   * never dereferenced here.
+   *
+   * @param data     the value to broadcast.
+   * @param evidence the ClassTag of {@code T} supplied by the caller.
+   * @param <T>      the type of the broadcast value.
+   * @return a {@link SparkBroadcast} whose value is resolved by id when read.
+   */
+  @Override
+  public <T> Broadcast<T> broadcast(final T data,
+                                    final ClassTag<T> evidence) {
+    final long id = SparkBroadcastVariables.register(data);
+    return new SparkBroadcast<>(id, (Class<T>) evidence.runtimeClass());
+  }
}
diff --git a/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/core/SparkFrontendUtils.java b/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/core/SparkFrontendUtils.java
index 54de1a6..75d22e0 100644
--- a/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/core/SparkFrontendUtils.java
+++ b/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/core/SparkFrontendUtils.java
@@ -26,6 +26,7 @@ import edu.snu.nemo.common.ir.edge.executionproperty.KeyExtractorProperty;
import edu.snu.nemo.common.ir.vertex.IRVertex;
import edu.snu.nemo.common.ir.vertex.LoopVertex;
import edu.snu.nemo.common.ir.vertex.OperatorVertex;
+import edu.snu.nemo.compiler.frontend.spark.SparkBroadcastVariables;
import edu.snu.nemo.compiler.frontend.spark.SparkKeyExtractor;
import edu.snu.nemo.compiler.frontend.spark.coder.SparkDecoderFactory;
import edu.snu.nemo.compiler.frontend.spark.coder.SparkEncoderFactory;
@@ -36,14 +37,15 @@ import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
-import org.apache.spark.serializer.JavaSerializer;
-import org.apache.spark.serializer.KryoSerializer;
-import org.apache.spark.serializer.Serializer;
+import org.apache.spark.serializer.*;
import scala.Function1;
import scala.Tuple2;
import scala.collection.JavaConverters;
import scala.collection.TraversableOnce;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
+import java.nio.ByteBuffer;
import java.util.*;
/**
@@ -100,7 +102,7 @@ public final class SparkFrontendUtils {
builder.connectVertices(newEdge);
// launch DAG
- JobLauncher.launchDAG(builder.build());
+ JobLauncher.launchDAG(builder.build(), SparkBroadcastVariables.getAll());
return (List<T>) JobLauncher.getCollectedData();
}
@@ -126,16 +128,36 @@ public final class SparkFrontendUtils {
/**
* Converts a {@link Function1} to a corresponding {@link Function}.
*
+ * Here, we use the Spark 'JavaSerializer' to facilitate debugging in the future.
+ * TODO #205: RDD Closure with Broadcast Variables Serialization Bug
+ *
* @param scalaFunction the scala function to convert.
* @param <I> the type of input.
* @param <O> the type of output.
* @return the converted Java function.
*/
public static <I, O> Function<I, O> toJavaFunction(final Function1<I, O> scalaFunction) {
+ // This 'JavaSerializer' from Spark provides human-readable NotSerializableException stack traces,
+ // which can be useful when addressing this problem.
+ // Other toJavaFunction can also use this serializer when debugging.
+ final ClassTag<Function1<I, O>> classTag = ClassTag$.MODULE$.apply(scalaFunction.getClass());
+ final ByteBuffer serializedBuffer = new JavaSerializer().newInstance().serialize(scalaFunction, classTag);
+ // Copy only the serialized bytes: array() exposes the whole backing array, which may be larger.
+ final byte[] serializedFunction = new byte[serializedBuffer.remaining()];
+ serializedBuffer.get(serializedFunction);
+
return new Function<I, O>() {
+ // Transient per-instance cache: only the serialized bytes above may ever be shipped, never this closure.
+ private transient Function1<I, O> deserializedFunction;
+
@Override
public O call(final I v1) throws Exception {
- return scalaFunction.apply(v1);
+ if (deserializedFunction == null) {
+ // TODO #205: RDD Closure with Broadcast Variables Serialization Bug
+ final SerializerInstance js = new JavaSerializer().newInstance();
+ deserializedFunction = js.deserialize(ByteBuffer.wrap(serializedFunction), classTag);
+ }
+ return deserializedFunction.apply(v1);
}
};
}
diff --git a/compiler/frontend/spark/src/main/scala/edu/snu/nemo/compiler/frontend/spark/core/rdd/RDD.scala b/compiler/frontend/spark/src/main/scala/edu/snu/nemo/compiler/frontend/spark/core/rdd/RDD.scala
index a3d75a5..c253f17 100644
--- a/compiler/frontend/spark/src/main/scala/edu/snu/nemo/compiler/frontend/spark/core/rdd/RDD.scala
+++ b/compiler/frontend/spark/src/main/scala/edu/snu/nemo/compiler/frontend/spark/core/rdd/RDD.scala
@@ -26,10 +26,11 @@ import edu.snu.nemo.common.ir.executionproperty.EdgeExecutionProperty
import edu.snu.nemo.common.ir.vertex.executionproperty.IgnoreSchedulingTempDataReceiverProperty
import edu.snu.nemo.common.ir.vertex.{IRVertex, LoopVertex, OperatorVertex}
import edu.snu.nemo.common.test.EmptyComponents.EmptyTransform
-import edu.snu.nemo.compiler.frontend.spark.SparkKeyExtractor
+import edu.snu.nemo.compiler.frontend.spark.{SparkBroadcastVariables, SparkKeyExtractor}
import edu.snu.nemo.compiler.frontend.spark.coder.{SparkDecoderFactory, SparkEncoderFactory}
import edu.snu.nemo.compiler.frontend.spark.core.SparkFrontendUtils
import edu.snu.nemo.compiler.frontend.spark.transform._
+import org.apache.commons.lang.SerializationUtils
import org.apache.hadoop.io.WritableFactory
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.spark.api.java.function.{FlatMapFunction, Function, Function2}
@@ -37,9 +38,10 @@ import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.rdd.{AsyncRDDActions, DoubleRDDFunctions, OrderedRDDFunctions, PartitionCoalescer, SequenceFileRDDFunctions}
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.{Dependency, Partition, Partitioner, SparkContext, TaskContext}
+import org.apache.spark.{Dependency, Partition, Partitioner, SparkContext, SparkEnv, TaskContext}
import org.slf4j.LoggerFactory
+import scala.collection.JavaConverters._
import scala.language.implicitConversions
import scala.reflect.ClassTag
@@ -130,7 +132,7 @@ final class RDD[T: ClassTag] protected[rdd] (
* all the data is loaded into the driver's memory.
*/
override def collect(): Array[T] =
- collectAsList().toArray().asInstanceOf[Array[T]]
+ collectAsList().asScala.toArray
/////////////// TRANSFORMATIONS ///////////////
@@ -226,7 +228,7 @@ final class RDD[T: ClassTag] protected[rdd] (
newEdge.setProperty(keyExtractorProperty)
builder.connectVertices(newEdge)
- JobLauncher.launchDAG(builder.build)
+ JobLauncher.launchDAG(builder.build, SparkBroadcastVariables.getAll)
}
/////////////// CACHING ///////////////
diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/NemoOptimizer.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/NemoOptimizer.java
index ba4e11c..3d290e1 100644
--- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/NemoOptimizer.java
+++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/NemoOptimizer.java
@@ -191,8 +191,7 @@ public final class NemoOptimizer implements Optimizer {
edge.getPropertyValue(CommunicationPatternProperty.class)
.orElseThrow(() -> new RuntimeException("No communication pattern on an ir edge")),
cachedDataRelayVertex,
- irVertex,
- edge.isSideInput());
+ irVertex);
edge.copyExecutionPropertiesTo(newEdge);
newEdge.setProperty(CacheIDProperty.of(cacheId.get()));
builder.connectVertices(newEdge);
diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LargeShuffleRelayReshapingPass.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LargeShuffleRelayReshapingPass.java
index fa2afee..9918b39 100644
--- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LargeShuffleRelayReshapingPass.java
+++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LargeShuffleRelayReshapingPass.java
@@ -60,8 +60,8 @@ public final class LargeShuffleRelayReshapingPass extends ReshapingPass {
final OperatorVertex iFileMergerVertex = new OperatorVertex(new RelayTransform());
builder.addVertex(iFileMergerVertex);
- final IREdge newEdgeToMerger = new IREdge(CommunicationPatternProperty.Value.Shuffle,
- edge.getSrc(), iFileMergerVertex, edge.isSideInput());
+ final IREdge newEdgeToMerger =
+ new IREdge(CommunicationPatternProperty.Value.Shuffle, edge.getSrc(), iFileMergerVertex);
edge.copyExecutionPropertiesTo(newEdgeToMerger);
final IREdge newEdgeFromMerger = new IREdge(CommunicationPatternProperty.Value.OneToOne,
iFileMergerVertex, v);
diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopExtractionPass.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopExtractionPass.java
index e11d712..0992539 100644
--- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopExtractionPass.java
+++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopExtractionPass.java
@@ -100,9 +100,8 @@ public final class LoopExtractionPass extends ReshapingPass {
// connecting with a loop: loop -> operator.
final LoopVertex srcLoopVertex = dag.getAssignedLoopVertexOf(irEdge.getSrc());
srcLoopVertex.addDagOutgoingEdge(irEdge);
- final IREdge edgeFromLoop =
- new IREdge(irEdge.getPropertyValue(CommunicationPatternProperty.class).get(),
- srcLoopVertex, operatorVertex, irEdge.isSideInput());
+ final IREdge edgeFromLoop = new IREdge(
+ irEdge.getPropertyValue(CommunicationPatternProperty.class).get(), srcLoopVertex, operatorVertex);
irEdge.copyExecutionPropertiesTo(edgeFromLoop);
builder.connectVertices(edgeFromLoop);
srcLoopVertex.mapEdgeWithLoop(edgeFromLoop, irEdge);
@@ -150,7 +149,7 @@ public final class LoopExtractionPass extends ReshapingPass {
} else { // loop -> loop connection
assignedLoopVertex.addDagIncomingEdge(irEdge);
final IREdge edgeToLoop = new IREdge(irEdge.getPropertyValue(CommunicationPatternProperty.class).get(),
- srcLoopVertex, assignedLoopVertex, irEdge.isSideInput());
+ srcLoopVertex, assignedLoopVertex);
irEdge.copyExecutionPropertiesTo(edgeToLoop);
builder.connectVertices(edgeToLoop);
assignedLoopVertex.mapEdgeWithLoop(edgeToLoop, irEdge);
@@ -158,7 +157,7 @@ public final class LoopExtractionPass extends ReshapingPass {
} else { // operator -> loop
assignedLoopVertex.addDagIncomingEdge(irEdge);
final IREdge edgeToLoop = new IREdge(irEdge.getPropertyValue(CommunicationPatternProperty.class).get(),
- irEdge.getSrc(), assignedLoopVertex, irEdge.isSideInput());
+ irEdge.getSrc(), assignedLoopVertex);
irEdge.copyExecutionPropertiesTo(edgeToLoop);
builder.connectVertices(edgeToLoop);
assignedLoopVertex.mapEdgeWithLoop(edgeToLoop, irEdge);
@@ -229,13 +228,13 @@ public final class LoopExtractionPass extends ReshapingPass {
// add the new IREdge to the iterative incoming edges list.
final IREdge newIrEdge = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
- equivalentSrcVertex, equivalentDstVertex, edge.isSideInput());
+ equivalentSrcVertex, equivalentDstVertex);
edge.copyExecutionPropertiesTo(newIrEdge);
finalRootLoopVertex.addIterativeIncomingEdge(newIrEdge);
} else {
// src is from outside the previous loop. vertex outside previous loop -> DAG.
final IREdge newIrEdge = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
- srcVertex, equivalentDstVertex, edge.isSideInput());
+ srcVertex, equivalentDstVertex);
edge.copyExecutionPropertiesTo(newIrEdge);
finalRootLoopVertex.addNonIterativeIncomingEdge(newIrEdge);
}
@@ -248,7 +247,7 @@ public final class LoopExtractionPass extends ReshapingPass {
final IRVertex equivalentSrcVertex = equivalentVertices.get(srcVertex);
final IREdge newIrEdge = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
- equivalentSrcVertex, dstVertex, edge.isSideInput());
+ equivalentSrcVertex, dstVertex);
edge.copyExecutionPropertiesTo(newIrEdge);
finalRootLoopVertex.addDagOutgoingEdge(newIrEdge);
finalRootLoopVertex.mapEdgeWithLoop(loopVertex.getEdgeWithLoop(edge), newIrEdge);
@@ -293,7 +292,7 @@ public final class LoopExtractionPass extends ReshapingPass {
builder.connectVertices(edge);
} else {
final IREdge newIrEdge = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
- firstEquivalentVertex, irVertex, edge.isSideInput());
+ firstEquivalentVertex, irVertex);
edge.copyExecutionPropertiesTo(newIrEdge);
builder.connectVertices(newIrEdge);
}
diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopOptimizations.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopOptimizations.java
index 9d9f8bc..ce496df 100644
--- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopOptimizations.java
+++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopOptimizations.java
@@ -163,7 +163,7 @@ public final class LoopOptimizations {
inEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(irEdge -> {
if (builder.contains(irEdge.getSrc())) {
final IREdge newIREdge = new IREdge(irEdge.getPropertyValue(CommunicationPatternProperty.class)
- .get(), irEdge.getSrc(), newLoopVertex, irEdge.isSideInput());
+ .get(), irEdge.getSrc(), newLoopVertex);
irEdge.copyExecutionPropertiesTo(newIREdge);
builder.connectVertices(newIREdge);
}
@@ -172,7 +172,7 @@ public final class LoopOptimizations {
outEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(irEdge -> {
if (builder.contains(irEdge.getDst())) {
final IREdge newIREdge = new IREdge(irEdge.getPropertyValue(CommunicationPatternProperty.class)
- .get(), newLoopVertex, irEdge.getDst(), irEdge.isSideInput());
+ .get(), newLoopVertex, irEdge.getDst());
irEdge.copyExecutionPropertiesTo(newIREdge);
builder.connectVertices(newIREdge);
}
@@ -291,7 +291,7 @@ public final class LoopOptimizations {
.forEach(edge -> {
edgesToRemove.add(edge);
final IREdge newEdge = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
- candidate.getKey(), edge.getDst(), edge.isSideInput());
+ candidate.getKey(), edge.getDst());
newEdge.setProperty(EncoderProperty.of(edge.getPropertyValue(EncoderProperty.class).get()));
newEdge.setProperty(DecoderProperty.of(edge.getPropertyValue(DecoderProperty.class).get()));
edgesToAdd.add(newEdge);
diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
index a8763bd..571476d 100644
--- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
+++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
@@ -71,8 +71,8 @@ public final class SkewReshapingPass extends ReshapingPass {
newEdge.setProperty(EncoderProperty.of(edge.getPropertyValue(EncoderProperty.class).get()));
newEdge.setProperty(DecoderProperty.of(edge.getPropertyValue(DecoderProperty.class).get()));
- final IREdge edgeToGbK = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
- metricCollectionBarrierVertex, v, edge.isSideInput());
+ final IREdge edgeToGbK = new IREdge(
+ edge.getPropertyValue(CommunicationPatternProperty.class).get(), metricCollectionBarrierVertex, v);
edge.copyExecutionPropertiesTo(edgeToGbK);
builder.connectVertices(newEdge);
builder.connectVertices(edgeToGbK);
diff --git a/examples/spark/src/main/scala/edu/snu/nemo/examples/spark/SparkALS.scala b/examples/spark/src/main/scala/edu/snu/nemo/examples/spark/SparkALS.scala
index 2c15769..9b3b1ed 100644
--- a/examples/spark/src/main/scala/edu/snu/nemo/examples/spark/SparkALS.scala
+++ b/examples/spark/src/main/scala/edu/snu/nemo/examples/spark/SparkALS.scala
@@ -19,7 +19,9 @@
package edu.snu.nemo.examples.spark;
import edu.snu.nemo.compiler.frontend.spark.sql.SparkSession
+import org.apache.commons.lang.SerializationUtils
import org.apache.commons.math3.linear._
+import org.apache.spark.SparkEnv
/**
* Alternating least squares matrix factorization.
@@ -126,14 +128,19 @@ object SparkALS {
val Rc = sc.broadcast(R)
var msb = sc.broadcast(ms)
var usb = sc.broadcast(us)
+
+ // TODO #205: RDD Closure with Broadcast Variables Serialization Bug
+ val update_ms = (i : Int) => update(i, msb.value(i), usb.value, Rc.value)
+ val update_us = (i : Int) => update(i, usb.value(i), msb.value, Rc.value.transpose())
+
for (iter <- 1 to ITERATIONS) {
println(s"Iteration $iter:")
ms = sc.parallelize(0 until M, slices)
- .map(i => update(i, msb.value(i), usb.value, Rc.value))
+ .map(update_ms)
.collect()
msb = sc.broadcast(ms) // Re-broadcast ms because it was updated
us = sc.parallelize(0 until U, slices)
- .map(i => update(i, usb.value(i), msb.value, Rc.value.transpose()))
+ .map(update_us)
.collect()
usb = sc.broadcast(us) // Re-broadcast us because it was updated
println("RMSE = " + rmse(R, ms, us))
diff --git a/examples/spark/src/test/java/edu/snu/nemo/examples/spark/SparkScala.java b/examples/spark/src/test/java/edu/snu/nemo/examples/spark/SparkScala.java
index 65e3bf6..fdc9e6d 100644
--- a/examples/spark/src/test/java/edu/snu/nemo/examples/spark/SparkScala.java
+++ b/examples/spark/src/test/java/edu/snu/nemo/examples/spark/SparkScala.java
@@ -104,4 +104,14 @@ public final class SparkScala {
ExampleTestUtil.deleteOutputFile(fileBasePath, outputFileName2);
}
}
+
+ @Test(timeout = TIMEOUT)
+ public void testALS() throws Exception {
+ JobLauncher.main(builder
+ .addJobId(SparkALS.class.getSimpleName() + "_test")
+ .addUserMain(SparkALS.class.getCanonicalName())
+ .addUserArgs("100") // TODO #202: Bug with empty string user_args
+ .addOptimizationPolicy(DefaultPolicy.class.getCanonicalName())
+ .build());
+ }
}
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageEnvironment.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageEnvironment.java
index a5062fe..fc05ad8 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageEnvironment.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageEnvironment.java
@@ -205,8 +205,10 @@ public final class NcsMessageEnvironment implements MessageEnvironment {
case MetricFlushed:
return MessageType.Send;
case RequestBlockLocation:
+ case RequestBroadcastVariable:
return MessageType.Request;
case BlockLocationInfo:
+ case InMasterBroadcastVariable:
return MessageType.Reply;
default:
throw new IllegalArgumentException(controlMessage.toString());
@@ -217,6 +219,8 @@ public final class NcsMessageEnvironment implements MessageEnvironment {
switch (controlMessage.getType()) {
case RequestBlockLocation:
return controlMessage.getRequestBlockLocationMsg().getExecutorId();
+ case RequestBroadcastVariable:
+ return controlMessage.getRequestbroadcastVariableMsg().getExecutorId();
default:
throw new IllegalArgumentException(controlMessage.toString());
}
@@ -226,6 +230,8 @@ public final class NcsMessageEnvironment implements MessageEnvironment {
switch (controlMessage.getType()) {
case BlockLocationInfo:
return controlMessage.getBlockLocationInfoMsg().getRequestId();
+ case InMasterBroadcastVariable:
+ return controlMessage.getBroadcastVariableMsg().getRequestId();
default:
throw new IllegalArgumentException(controlMessage.toString());
}
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
index 598b253..b1fa4fb 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
@@ -193,8 +193,7 @@ public final class PhysicalPlanGenerator implements Function<DAG<IRVertex, IREdg
irEdge.getId(),
irEdge.getExecutionProperties(),
irEdge.getSrc(),
- irEdge.getDst(),
- irEdge.isSideInput()));
+ irEdge.getDst()));
} else { // edge comes from another stage
interStageEdges.add(irEdge);
}
@@ -226,7 +225,7 @@ public final class PhysicalPlanGenerator implements Function<DAG<IRVertex, IREdg
dstStage == null ? String.format(" destination stage for %s", interStageEdge.getDst()) : ""));
}
dagOfStagesBuilder.connectVertices(new StageEdge(interStageEdge.getId(), interStageEdge.getExecutionProperties(),
- interStageEdge.getSrc(), interStageEdge.getDst(), srcStage, dstStage, interStageEdge.isSideInput()));
+ interStageEdge.getSrc(), interStageEdge.getDst(), srcStage, dstStage));
}
return dagOfStagesBuilder.build();
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/RuntimeEdge.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/RuntimeEdge.java
index 01a548a..05406a9 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/RuntimeEdge.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/RuntimeEdge.java
@@ -29,7 +29,6 @@ import java.util.Optional;
*/
public class RuntimeEdge<V extends Vertex> extends Edge<V> {
private final ExecutionPropertyMap<EdgeExecutionProperty> executionProperties;
- private final Boolean isSideInput;
/**
* Constructs the edge given the below parameters.
@@ -38,16 +37,13 @@ public class RuntimeEdge<V extends Vertex> extends Edge<V> {
* @param executionProperties to control the data flow on this edge.
* @param src the source vertex.
* @param dst the destination vertex.
- * @param isSideInput Whether or not the RuntimeEdge is a side input edge.
*/
public RuntimeEdge(final String runtimeEdgeId,
final ExecutionPropertyMap<EdgeExecutionProperty> executionProperties,
final V src,
- final V dst,
- final Boolean isSideInput) {
+ final V dst) {
super(runtimeEdgeId, src, dst);
this.executionProperties = executionProperties;
- this.isSideInput = isSideInput;
}
/**
@@ -70,13 +66,6 @@ public class RuntimeEdge<V extends Vertex> extends Edge<V> {
}
/**
- * @return whether or not the RuntimeEdge is a side input edge.
- */
- public final Boolean isSideInput() {
- return isSideInput;
- }
-
- /**
* @return JSON representation of additional properties
*/
@Override
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StageEdge.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StageEdge.java
index ced7564..97e4658 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StageEdge.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StageEdge.java
@@ -69,17 +69,15 @@ public final class StageEdge extends RuntimeEdge<Stage> {
* @param dstVertex destination IRVertex in the dstStage of this edge.
* @param srcStage source stage.
* @param dstStage destination stage.
- * @param isSideInput whether or not the edge is a sideInput edge.
*/
@VisibleForTesting
public StageEdge(final String runtimeEdgeId,
- final ExecutionPropertyMap<EdgeExecutionProperty> edgeProperties,
- final IRVertex srcVertex,
- final IRVertex dstVertex,
- final Stage srcStage,
- final Stage dstStage,
- final Boolean isSideInput) {
- super(runtimeEdgeId, edgeProperties, srcStage, dstStage, isSideInput);
+ final ExecutionPropertyMap<EdgeExecutionProperty> edgeProperties,
+ final IRVertex srcVertex,
+ final IRVertex dstVertex,
+ final Stage srcStage,
+ final Stage dstStage) {
+ super(runtimeEdgeId, edgeProperties, srcStage, dstStage);
this.srcVertex = srcVertex;
this.dstVertex = dstVertex;
// Initialize the key range of each dst task.
@@ -88,11 +86,11 @@ public final class StageEdge extends RuntimeEdge<Stage> {
taskIdxToKeyRange.put(taskIdx, HashRange.of(taskIdx, taskIdx + 1, false));
}
this.dataCommunicationPatternValue = edgeProperties.get(CommunicationPatternProperty.class)
- .orElseThrow(() -> new RuntimeException(String.format(
- "CommunicationPatternProperty not set for %s", runtimeEdgeId)));
+ .orElseThrow(() -> new RuntimeException(String.format(
+ "CommunicationPatternProperty not set for %s", runtimeEdgeId)));
this.dataFlowModelValue = edgeProperties.get(DataFlowProperty.class)
- .orElseThrow(() -> new RuntimeException(String.format(
- "DataFlowProperty not set for %s", runtimeEdgeId)));
+ .orElseThrow(() -> new RuntimeException(String.format(
+ "DataFlowProperty not set for %s", runtimeEdgeId)));
}
/**
diff --git a/runtime/common/src/main/proto/ControlMessage.proto b/runtime/common/src/main/proto/ControlMessage.proto
index 9e29e43..3c6bb8e 100644
--- a/runtime/common/src/main/proto/ControlMessage.proto
+++ b/runtime/common/src/main/proto/ControlMessage.proto
@@ -31,6 +31,7 @@ message ClientToDriverMessage {
message LaunchDAGMessage {
required string dag = 1;
+ optional bytes broadcastVars = 2;
}
message DataCollectMessage {
@@ -61,6 +62,8 @@ enum MessageType {
MetricMessageReceived = 8;
RequestMetricFlush = 9;
MetricFlushed = 10;
+ RequestBroadcastVariable = 11;
+ InMasterBroadcastVariable = 12;
}
message Message {
@@ -77,6 +80,8 @@ message Message {
optional ContainerFailedMsg containerFailedMsg = 11;
optional MetricMsg metricMsg = 12;
optional DataCollectMessage dataCollected = 13;
+ optional RequestBroadcastVariableMessage requestbroadcastVariableMsg = 14;
+ optional InMasterBroadcastVariableMessage broadcastVariableMsg = 15;
}
// Messages from Master to Executors
@@ -192,3 +197,13 @@ message Metric {
required string metricField = 3;
required bytes metricValue = 4;
}
+
+message RequestBroadcastVariableMessage {
+ required string executorId = 1;
+ required bytes broadcastId = 2;
+}
+
+message InMasterBroadcastVariableMessage {
+ required int64 requestId = 1; // To find the matching request msg
+ required bytes variable = 2;
+}
diff --git a/runtime/driver/src/main/java/edu/snu/nemo/driver/NemoDriver.java b/runtime/driver/src/main/java/edu/snu/nemo/driver/NemoDriver.java
index b36fae6..4f22ac6 100644
--- a/runtime/driver/src/main/java/edu/snu/nemo/driver/NemoDriver.java
+++ b/runtime/driver/src/main/java/edu/snu/nemo/driver/NemoDriver.java
@@ -22,7 +22,9 @@ import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.comm.ControlMessage;
import edu.snu.nemo.runtime.common.message.MessageParameters;
import edu.snu.nemo.runtime.master.ClientRPC;
+import edu.snu.nemo.runtime.master.BroadcastManagerMaster;
import edu.snu.nemo.runtime.master.RuntimeMaster;
+import org.apache.commons.lang3.SerializationUtils;
import org.apache.commons.lang3.concurrent.BasicThreadFactory;
import org.apache.reef.annotations.audience.DriverSide;
import org.apache.reef.driver.client.JobMessageObserver;
@@ -49,6 +51,8 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.inject.Inject;
+import java.io.Serializable;
+import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.logging.LogManager;
@@ -104,8 +108,16 @@ public final class NemoDriver {
this.clientRPC = clientRPC;
// TODO #69: Support job-wide execution property
ResourceSitePass.setBandwidthSpecificationString(bandwidthString);
- clientRPC.registerHandler(ControlMessage.ClientToDriverMessageType.LaunchDAG, message ->
- startSchedulingUserDAG(message.getLaunchDAG().getDag()));
+ clientRPC.registerHandler(ControlMessage.ClientToDriverMessageType.LaunchDAG, message -> {
+ // Register broadcast variables before scheduling, so tasks cannot request a variable not yet registered.
+ // broadcastVars is an optional field: deserializing an absent (empty) payload would throw.
+ if (message.getLaunchDAG().hasBroadcastVars()) {
+ final Map<Serializable, Object> broadcastVars =
+ SerializationUtils.deserialize(message.getLaunchDAG().getBroadcastVars().toByteArray());
+ BroadcastManagerMaster.registerBroadcastVariablesFromClient(broadcastVars);
+ }
+ startSchedulingUserDAG(message.getLaunchDAG().getDag());
+ });
clientRPC.registerHandler(ControlMessage.ClientToDriverMessageType.DriverShutdown, message -> shutdown());
// Send DriverStarted message to the client
clientRPC.send(ControlMessage.DriverToClientMessage.newBuilder()
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/Executor.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/Executor.java
index b4bdcd7..23978dc 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/Executor.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/Executor.java
@@ -33,6 +33,7 @@ import edu.snu.nemo.runtime.common.message.MessageListener;
import edu.snu.nemo.runtime.common.message.PersistentConnectionToMasterMap;
import edu.snu.nemo.runtime.common.plan.RuntimeEdge;
import edu.snu.nemo.runtime.common.plan.Task;
+import edu.snu.nemo.runtime.executor.data.BroadcastManagerWorker;
import edu.snu.nemo.runtime.executor.data.SerializerManager;
import edu.snu.nemo.runtime.executor.datatransfer.DataTransferFactory;
import edu.snu.nemo.runtime.executor.task.TaskExecutor;
@@ -69,6 +70,8 @@ public final class Executor {
*/
private final DataTransferFactory dataTransferFactory;
+ private final BroadcastManagerWorker broadcastManagerWorker;
+
private final PersistentConnectionToMasterMap persistentConnectionToMasterMap;
private final MetricMessageSender metricMessageSender;
@@ -79,6 +82,7 @@ public final class Executor {
final MessageEnvironment messageEnvironment,
final SerializerManager serializerManager,
final DataTransferFactory dataTransferFactory,
+ final BroadcastManagerWorker broadcastManagerWorker,
final MetricManagerWorker metricMessageSender) {
this.executorId = executorId;
this.executorService = Executors.newCachedThreadPool(new BasicThreadFactory.Builder()
@@ -87,6 +91,7 @@ public final class Executor {
this.persistentConnectionToMasterMap = persistentConnectionToMasterMap;
this.serializerManager = serializerManager;
this.dataTransferFactory = dataTransferFactory;
+ this.broadcastManagerWorker = broadcastManagerWorker;
this.metricMessageSender = metricMessageSender;
messageEnvironment.setupListener(MessageEnvironment.EXECUTOR_MESSAGE_LISTENER_ID, new ExecutorMessageReceiver());
}
@@ -130,7 +135,7 @@ public final class Executor {
e.getPropertyValue(DecompressionProperty.class).orElse(null)));
});
- new TaskExecutor(task, irDag, taskStateManager, dataTransferFactory,
+ new TaskExecutor(task, irDag, taskStateManager, dataTransferFactory, broadcastManagerWorker,
metricMessageSender, persistentConnectionToMasterMap).execute();
} catch (final Exception e) {
persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send(
diff --git a/common/src/main/java/edu/snu/nemo/common/ContextImpl.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/TransformContextImpl.java
similarity index 68%
rename from common/src/main/java/edu/snu/nemo/common/ContextImpl.java
rename to runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/TransformContextImpl.java
index df5809f..eeace7c 100644
--- a/common/src/main/java/edu/snu/nemo/common/ContextImpl.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/TransformContextImpl.java
@@ -13,36 +13,38 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package edu.snu.nemo.common;
+package edu.snu.nemo.runtime.executor;
import edu.snu.nemo.common.ir.vertex.transform.Transform;
+import edu.snu.nemo.runtime.executor.data.BroadcastManagerWorker;
+import java.io.Serializable;
import java.util.Map;
import java.util.Optional;
/**
* Transform Context Implementation.
*/
-public final class ContextImpl implements Transform.Context {
- private final Map sideInputs;
+public final class TransformContextImpl implements Transform.Context {
+ private final BroadcastManagerWorker broadcastManagerWorker;
private final Map<String, String> tagToAdditionalChildren;
private String data;
/**
* Constructor of Context Implementation.
- * @param sideInputs side inputs.
+ * @param broadcastManagerWorker for broadcast variables.
* @param tagToAdditionalChildren tag id to additional vertices id map.
*/
- public ContextImpl(final Map sideInputs,
- final Map<String, String> tagToAdditionalChildren) {
- this.sideInputs = sideInputs;
+ public TransformContextImpl(final BroadcastManagerWorker broadcastManagerWorker,
+ final Map<String, String> tagToAdditionalChildren) {
+ this.broadcastManagerWorker = broadcastManagerWorker;
this.tagToAdditionalChildren = tagToAdditionalChildren;
this.data = null;
}
@Override
- public Map getSideInputs() {
- return this.sideInputs;
+ public Object getBroadcastVariable(final Serializable tag) {
+ return broadcastManagerWorker.get(tag);
}
@Override
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/BroadcastManagerWorker.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/BroadcastManagerWorker.java
new file mode 100644
index 0000000..e90427b
--- /dev/null
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/BroadcastManagerWorker.java
@@ -0,0 +1,188 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.runtime.executor.data;
+
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.CacheLoader;
+import com.google.common.cache.LoadingCache;
+import com.google.protobuf.ByteString;
+import edu.snu.nemo.conf.JobConf;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
+import edu.snu.nemo.runtime.common.comm.ControlMessage;
+import edu.snu.nemo.runtime.common.message.MessageEnvironment;
+import edu.snu.nemo.runtime.common.message.PersistentConnectionToMasterMap;
+import edu.snu.nemo.runtime.executor.datatransfer.InputReader;
+import net.jcip.annotations.ThreadSafe;
+import org.apache.commons.lang.SerializationUtils;
+import org.apache.reef.tang.annotations.Parameter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.inject.Inject;
+import java.io.Serializable;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Used by tasks to get/fetch (probably remote) broadcast variables.
+ *
+ * <p>A variable is read through the {@link InputReader} registered for its id, if any
+ * (see {@link #registerInputReader(Serializable, InputReader)}); otherwise it is requested from the master.
+ * Loaded values are kept in a size-bounded cache whose entries expire a fixed time after being loaded.</p>
+ */
+@ThreadSafe
+public final class BroadcastManagerWorker {
+  private static final Logger LOG = LoggerFactory.getLogger(BroadcastManagerWorker.class.getName());
+
+  /** Maximum number of broadcast variables held in the cache at once. */
+  private static final long MAX_CACHED_VARIABLES = 100;
+  /** Minutes after loading at which a cached broadcast variable expires. */
+  private static final long CACHE_EXPIRY_MINUTES = 10;
+
+  // volatile: written by the constructor and read via getStaticReference() by code outside Tang,
+  // possibly on other threads; without it those threads may observe a stale or partially-built instance.
+  private static volatile BroadcastManagerWorker staticReference;
+
+  private final String executorId;
+  private final PersistentConnectionToMasterMap toMaster;
+  private final ConcurrentHashMap<Serializable, InputReader> idToReader;
+  private final LoadingCache<Serializable, Object> idToVariableCache;
+
+  /**
+   * Initializes the cache for broadcast variables.
+   * This cache handles concurrent cache operations by multiple threads, and is able to fetch data from
+   * remote executors or the master.
+   *
+   * @param executorId of the executor.
+   * @param toMaster   connection.
+   */
+  @Inject
+  private BroadcastManagerWorker(@Parameter(JobConf.ExecutorId.class) final String executorId,
+                                 final PersistentConnectionToMasterMap toMaster) {
+    this.executorId = executorId;
+    this.toMaster = toMaster;
+    this.idToReader = new ConcurrentHashMap<>();
+    this.idToVariableCache = CacheBuilder.newBuilder()
+        .maximumSize(MAX_CACHED_VARIABLES)
+        .expireAfterWrite(CACHE_EXPIRY_MINUTES, TimeUnit.MINUTES)
+        .build(
+            new CacheLoader<Serializable, Object>() {
+              @Override
+              public Object load(final Serializable id) throws Exception {
+                LOG.info("Start to load broadcast {}", id.toString());
+                // Prefer the registered reader (variable stored as a block); otherwise ask the master.
+                final InputReader inputReader = idToReader.get(id);
+                if (inputReader != null) {
+                  return loadFromInputReader(id, inputReader);
+                } else {
+                  return loadFromMaster(id);
+                }
+              }
+            });
+    // Publish last, so getStaticReference() never hands out an instance whose fields are not yet set.
+    staticReference = this;
+  }
+
+  /**
+   * Reads a broadcast variable that resides in an executor as a block, through its registered reader.
+   *
+   * @param id          of the broadcast variable.
+   * @param inputReader registered for the variable.
+   * @return the single element read.
+   * @throws Exception if reading fails, or the reader does not yield exactly one iterator holding exactly
+   *                   one element.
+   */
+  private Object loadFromInputReader(final Serializable id, final InputReader inputReader) throws Exception {
+    final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> iterators = inputReader.read();
+    if (iterators.size() != 1) {
+      throw new IllegalStateException(id.toString());
+    }
+    final DataUtil.IteratorWithNumBytes iterator = iterators.get(0).get();
+    if (!iterator.hasNext()) {
+      throw new IllegalStateException(id.toString() + " (no element) " + iterator.toString());
+    }
+    final Object result = iterator.next();
+    if (iterator.hasNext()) {
+      throw new IllegalStateException(id.toString() + " (more than single element) " + iterator.toString());
+    }
+    return result;
+  }
+
+  /**
+   * Requests a broadcast variable from the master and waits for the response.
+   *
+   * @param id of the broadcast variable.
+   * @return the deserialized variable contained in the master's response.
+   * @throws Exception if sending the request or waiting for the response fails.
+   */
+  private Object loadFromMaster(final Serializable id) throws Exception {
+    final CompletableFuture<ControlMessage.Message> responseFromMasterFuture = toMaster
+        .getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).request(
+            ControlMessage.Message.newBuilder()
+                .setId(RuntimeIdManager.generateMessageId())
+                .setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID)
+                .setType(ControlMessage.MessageType.RequestBroadcastVariable)
+                .setRequestbroadcastVariableMsg(
+                    ControlMessage.RequestBroadcastVariableMessage.newBuilder()
+                        .setExecutorId(executorId)
+                        .setBroadcastId(ByteString.copyFrom(SerializationUtils.serialize(id)))
+                        .build())
+                .build());
+    // NOTE(review): get() has no timeout; if the future is never completed, this loader waits indefinitely.
+    return SerializationUtils.deserialize(
+        responseFromMasterFuture.get().getBroadcastVariableMsg().getVariable().toByteArray());
+  }
+
+  /**
+   * Registers the reader for a broadcast variable that can be read by an input reader
+   * (i.e., the variable is expressed as an IREdge, and resides in an executor as a block).
+   * Registering again for the same id replaces the previous reader; an already-cached value is not reloaded.
+   *
+   * @param id          of the broadcast variable.
+   * @param inputReader that reads the block holding the variable.
+   */
+  public void registerInputReader(final Serializable id,
+                                  final InputReader inputReader) {
+    this.idToReader.put(id, inputReader);
+  }
+
+  /**
+   * Get the variable with the id, loading it into the cache on a miss.
+   *
+   * @param id of the variable.
+   * @return the variable.
+   * @throws IllegalStateException if loading the variable fails with a checked exception.
+   */
+  public Object get(final Serializable id) {
+    try {
+      return idToVariableCache.get(id);
+    } catch (ExecutionException e) {
+      // TODO #207: Handle broadcast variable fetch exceptions
+      throw new IllegalStateException(e);
+    }
+  }
+
+  /**
+   * @return the static reference for those that do not use TANG and cannot access the singleton object,
+   *         or {@code null} if no instance has been constructed yet.
+   */
+  public static BroadcastManagerWorker getStaticReference() {
+    return staticReference;
+  }
+}
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/InputReader.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/InputReader.java
index 7637660..0d3c0d4 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/InputReader.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/InputReader.java
@@ -152,10 +152,6 @@ public final class InputReader extends DataTransfer {
return srcVertex;
}
- public boolean isSideInputReader() {
- return Boolean.TRUE.equals(runtimeEdge.isSideInput());
- }
-
/**
* Get the parallelism of the source task.
*
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/DataFetcher.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/DataFetcher.java
index bb80e1a..54b0902 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/DataFetcher.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/DataFetcher.java
@@ -25,17 +25,11 @@ import java.io.IOException;
abstract class DataFetcher {
private final IRVertex dataSource;
private final VertexHarness child;
- private final boolean isToSideInput;
- private final boolean isFromSideInput;
+ // Broadcast inputs have no DataFetcher: TaskExecutor registers their readers with BroadcastManagerWorker.
DataFetcher(final IRVertex dataSource,
- final VertexHarness child,
- final boolean isFromSideInput,
- final boolean isToSideInput) {
+ final VertexHarness child) {
this.dataSource = dataSource;
this.child = child;
- this.isToSideInput = isToSideInput;
- this.isFromSideInput = isFromSideInput;
}
/**
@@ -53,12 +47,4 @@ abstract class DataFetcher {
public IRVertex getDataSource() {
return dataSource;
}
-
- boolean isFromSideInput() {
- return isFromSideInput;
- }
-
- boolean isToSideInput() {
- return isToSideInput;
- }
}
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/ParentTaskDataFetcher.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/ParentTaskDataFetcher.java
index 0193bae..298dd0f 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/ParentTaskDataFetcher.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/ParentTaskDataFetcher.java
@@ -46,11 +46,8 @@ class ParentTaskDataFetcher extends DataFetcher {
private long serBytes = 0;
private long encodedBytes = 0;
- ParentTaskDataFetcher(final IRVertex dataSource,
- final InputReader readerForParentTask,
- final VertexHarness child,
- final boolean isToSideInput) {
- super(dataSource, child, readerForParentTask.isSideInputReader(), isToSideInput);
+ ParentTaskDataFetcher(final IRVertex dataSource, final InputReader readerForParentTask, final VertexHarness child) {
+ super(dataSource, child);
this.readersForParentTask = readerForParentTask;
this.firstFetch = true;
this.currentIteratorIndex = 0;
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/SourceVertexDataFetcher.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/SourceVertexDataFetcher.java
index 425cb46..343a0c9 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/SourceVertexDataFetcher.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/SourceVertexDataFetcher.java
@@ -34,9 +34,8 @@ class SourceVertexDataFetcher extends DataFetcher {
SourceVertexDataFetcher(final IRVertex dataSource,
final Readable readable,
- final VertexHarness child,
- final boolean isToSideInput) {
- super(dataSource, child, false, isToSideInput);
+ final VertexHarness child) {
+ super(dataSource, child);
this.readable = readable;
}
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
index 3b8b364..4b156d9 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
@@ -16,11 +16,11 @@
package edu.snu.nemo.runtime.executor.task;
import com.google.common.collect.Lists;
-import edu.snu.nemo.common.ContextImpl;
import edu.snu.nemo.common.Pair;
import edu.snu.nemo.common.dag.DAG;
import edu.snu.nemo.common.ir.Readable;
import edu.snu.nemo.common.ir.edge.executionproperty.AdditionalOutputTagProperty;
+import edu.snu.nemo.common.ir.edge.executionproperty.BroadcastVariableIdProperty;
import edu.snu.nemo.common.ir.vertex.*;
import edu.snu.nemo.common.ir.vertex.transform.Transform;
import edu.snu.nemo.runtime.common.RuntimeIdManager;
@@ -33,6 +33,8 @@ import edu.snu.nemo.runtime.common.plan.RuntimeEdge;
import edu.snu.nemo.runtime.common.state.TaskState;
import edu.snu.nemo.runtime.executor.MetricMessageSender;
import edu.snu.nemo.runtime.executor.TaskStateManager;
+import edu.snu.nemo.runtime.executor.TransformContextImpl;
+import edu.snu.nemo.runtime.executor.data.BroadcastManagerWorker;
import edu.snu.nemo.runtime.executor.datatransfer.*;
import java.io.IOException;
@@ -59,9 +61,9 @@ public final class TaskExecutor {
private boolean isExecuted;
private final String taskId;
private final TaskStateManager taskStateManager;
- private final List<DataFetcher> dataFetchers;
+ // Fetchers for sources and non-broadcast parent edges only; broadcast readers go to broadcastManagerWorker.
+ private final List<DataFetcher> nonBroadcastDataFetchers;
+ private final BroadcastManagerWorker broadcastManagerWorker;
private final List<VertexHarness> sortedHarnesses;
- private final Map sideInputMap;
// Metrics information
private long boundedSourceReadTime = 0;
@@ -77,22 +79,25 @@ public final class TaskExecutor {
/**
* Constructor.
*
- * @param task Task with information needed during execution.
- * @param irVertexDag A DAG of vertices.
- * @param taskStateManager State manager for this Task.
- * @param dataTransferFactory For reading from/writing to data to other tasks.
- * @param metricMessageSender For sending metric with execution stats to Master.
+ * @param task Task with information needed during execution.
+ * @param irVertexDag A DAG of vertices.
+ * @param taskStateManager State manager for this Task.
+ * @param dataTransferFactory For reading from/writing to data to other tasks.
+ * @param broadcastManagerWorker For registering broadcast readers and reading broadcast variables.
+ * @param metricMessageSender For sending metric with execution stats to Master.
+ * @param persistentConnectionToMasterMap For sending messages, e.g. collected data, to Master.
*/
public TaskExecutor(final Task task,
final DAG<IRVertex, RuntimeEdge<IRVertex>> irVertexDag,
final TaskStateManager taskStateManager,
final DataTransferFactory dataTransferFactory,
+ final BroadcastManagerWorker broadcastManagerWorker,
final MetricMessageSender metricMessageSender,
final PersistentConnectionToMasterMap persistentConnectionToMasterMap) {
// Essential information
this.isExecuted = false;
this.taskId = task.getTaskId();
this.taskStateManager = taskStateManager;
+ this.broadcastManagerWorker = broadcastManagerWorker;
// Metric sender
this.metricMessageSender = metricMessageSender;
@@ -104,9 +109,8 @@ public final class TaskExecutor {
this.persistentConnectionToMasterMap = persistentConnectionToMasterMap;
// Prepare data structures
- this.sideInputMap = new HashMap();
final Pair<List<DataFetcher>, List<VertexHarness>> pair = prepare(task, irVertexDag, dataTransferFactory);
- this.dataFetchers = pair.left();
+ this.nonBroadcastDataFetchers = pair.left();
this.sortedHarnesses = pair.right();
}
@@ -140,7 +144,7 @@ public final class TaskExecutor {
final List<IRVertex> reverseTopologicallySorted = Lists.reverse(irVertexDag.getTopologicalSort());
// Create a harness for each vertex
- final List<DataFetcher> dataFetcherList = new ArrayList<>();
+ final List<DataFetcher> nonBroadcastDataFetcherList = new ArrayList<>();
final Map<String, VertexHarness> vertexIdToHarness = new HashMap<>();
reverseTopologicallySorted.forEach(irVertex -> {
final List<VertexHarness> children = getChildrenHarnesses(irVertex, irVertexDag, vertexIdToHarness);
@@ -149,58 +153,73 @@ public final class TaskExecutor {
throw new IllegalStateException(irVertex.toString());
}
- final List<Boolean> isToSideInputs = children.stream()
- .map(VertexHarness::getIRVertex)
- .map(childVertex -> irVertexDag.getEdgeBetween(irVertex.getId(), childVertex.getId()))
- .map(RuntimeEdge::isSideInput)
- .collect(Collectors.toList());
-
+ // Prepare data WRITE
+ // Child-task writes
final Map<String, String> additionalOutputMap =
- getAdditionalOutputMap(irVertex, task.getTaskOutgoingEdges(), irVertexDag);
- final List<Boolean> isToAdditionalTagOutputs = children.stream()
- .map(harness -> harness.getIRVertex().getId())
- .map(additionalOutputMap::containsValue)
- .collect(Collectors.toList());
-
- // Handle writes
- // Main output children task writes
+ getAdditionalOutputMap(irVertex, task.getTaskOutgoingEdges(), irVertexDag);
final List<OutputWriter> mainChildrenTaskWriters = getMainChildrenTaskWriters(
- irVertex, task.getTaskOutgoingEdges(), dataTransferFactory, additionalOutputMap);
- // Additional output children task writes
+ irVertex, task.getTaskOutgoingEdges(), dataTransferFactory, additionalOutputMap);
final Map<String, OutputWriter> additionalChildrenTaskWriters = getAdditionalChildrenTaskWriters(
- irVertex, task.getTaskOutgoingEdges(), dataTransferFactory, additionalOutputMap);
- // Find all main vertices and additional vertices
+ irVertex, task.getTaskOutgoingEdges(), dataTransferFactory, additionalOutputMap);
+ // Intra-task writes
final List<String> additionalOutputVertices = new ArrayList<>(additionalOutputMap.values());
final Set<String> mainChildren =
- getMainOutputVertices(irVertex, irVertexDag, task.getTaskOutgoingEdges(), additionalOutputVertices);
+ getMainOutputVertices(irVertex, irVertexDag, task.getTaskOutgoingEdges(), additionalOutputVertices);
final OutputCollectorImpl oci = new OutputCollectorImpl(mainChildren, additionalOutputVertices);
+ final List<Boolean> isToAdditionalTagOutputs = children.stream()
+ .map(harness -> harness.getIRVertex().getId())
+ .map(additionalOutputMap::containsValue)
+ .collect(Collectors.toList());
- // intra-vertex writes
- final VertexHarness vertexHarness = new VertexHarness(irVertex, oci, children,
- isToSideInputs, isToAdditionalTagOutputs, mainChildrenTaskWriters, additionalChildrenTaskWriters,
- new ContextImpl(sideInputMap, additionalOutputMap));
+ // Create VERTEX HARNESS
+ final VertexHarness vertexHarness = new VertexHarness(
+ irVertex, oci, children, isToAdditionalTagOutputs, mainChildrenTaskWriters, additionalChildrenTaskWriters,
+ new TransformContextImpl(broadcastManagerWorker, additionalOutputMap));
prepareTransform(vertexHarness);
vertexIdToHarness.put(irVertex.getId(), vertexHarness);
- // Handle reads
- final boolean isToSideInput = isToSideInputs.stream().anyMatch(bool -> bool);
+ // Prepare data READ
+ // Source read
if (irVertex instanceof SourceVertex) {
- dataFetcherList.add(new SourceVertexDataFetcher(
- irVertex, sourceReader.get(), vertexHarness, isToSideInput)); // Source vertex read
+ // Source vertex read
+ nonBroadcastDataFetcherList.add(new SourceVertexDataFetcher(irVertex, sourceReader.get(), vertexHarness));
+ }
+ // Parent-task read (broadcasts)
+ final List<StageEdge> inEdgesForThisVertex = task.getTaskIncomingEdges()
+ .stream()
+ .filter(inEdge -> inEdge.getDstIRVertex().getId().equals(irVertex.getId()))
+ .collect(Collectors.toList());
+ final List<StageEdge> broadcastInEdges = inEdgesForThisVertex
+ .stream()
+ .filter(stageEdge -> stageEdge.getPropertyValue(BroadcastVariableIdProperty.class).isPresent())
+ .collect(Collectors.toList());
+ final List<InputReader> broadcastReaders =
+ getParentTaskReaders(taskIndex, broadcastInEdges, dataTransferFactory);
+ if (broadcastInEdges.size() != broadcastReaders.size()) {
+ throw new IllegalStateException(broadcastInEdges.toString() + ", " + broadcastReaders.toString());
}
- final List<InputReader> parentTaskReaders =
- getParentTaskReaders(taskIndex, irVertex, task.getTaskIncomingEdges(), dataTransferFactory);
- parentTaskReaders.forEach(parentTaskReader ->
- dataFetcherList.add(new ParentTaskDataFetcher(parentTaskReader.getSrcIrVertex(), parentTaskReader,
- vertexHarness, isToSideInput))); // Parent-task read
+ // Broadcast readers are not consumed by this task; BroadcastManagerWorker reads one only when
+ // its variable is requested and not already cached.
+ for (int i = 0; i < broadcastInEdges.size(); i++) {
+ final StageEdge inEdge = broadcastInEdges.get(i);
+ broadcastManagerWorker.registerInputReader(
+ inEdge.getPropertyValue(BroadcastVariableIdProperty.class)
+ .orElseThrow(() -> new IllegalStateException(inEdge.toString())),
+ broadcastReaders.get(i));
+ }
+ // Parent-task read (non-broadcasts)
+ final List<StageEdge> nonBroadcastInEdges = new ArrayList<>(inEdgesForThisVertex);
+ nonBroadcastInEdges.removeAll(broadcastInEdges);
+ final List<InputReader> nonBroadcastReaders =
+ getParentTaskReaders(taskIndex, nonBroadcastInEdges, dataTransferFactory);
+ nonBroadcastReaders.forEach(parentTaskReader -> nonBroadcastDataFetcherList.add(
+ new ParentTaskDataFetcher(parentTaskReader.getSrcIrVertex(), parentTaskReader, vertexHarness)));
});
final List<VertexHarness> sortedHarnessList = irVertexDag.getTopologicalSort()
- .stream()
- .map(vertex -> vertexIdToHarness.get(vertex.getId()))
- .collect(Collectors.toList());
+ .stream()
+ .map(vertex -> vertexIdToHarness.get(vertex.getId()))
+ .collect(Collectors.toList());
- return Pair.of(dataFetcherList, sortedHarnessList);
+ return Pair.of(nonBroadcastDataFetcherList, sortedHarnessList);
}
/**
@@ -232,7 +251,7 @@ public final class TaskExecutor {
// Recursively process all of the additional output elements.
vertexHarness.getContext().getTagToAdditionalChildren().values().forEach(tag -> {
outputCollector.iterateTag(tag).forEach(
- element -> handleAdditionalOutputElement(vertexHarness, element, tag)); // Recursion
+ element -> handleAdditionalOutputElement(vertexHarness, element, tag)); // Recursion
outputCollector.clearTag(tag);
});
}
@@ -252,9 +271,8 @@ public final class TaskExecutor {
/**
* The task is executed in the following two phases.
- * - Phase 1: Consume task-external side-input data
- * - Phase 2: Consume task-external input data
- * - Phase 3: Finalize task-internal states and data elements
+ * - Phase 1: Consume task-external input data (non-broadcasts)
+ * - Phase 2: Finalize task-internal states and data elements
+ * Broadcast variables are not read in these phases; transforms obtain them on demand through their context.
*/
private void doExecute() {
// Housekeeping stuff
@@ -264,39 +282,21 @@ public final class TaskExecutor {
LOG.info("{} started", taskId);
taskStateManager.onTaskStateChanged(TaskState.State.EXECUTING, Optional.empty(), Optional.empty());
- // Phase 1: Consume task-external side-input related data.
- final Map<Boolean, List<DataFetcher>> sideInputRelated = dataFetchers.stream()
- .collect(Collectors.partitioningBy(fetcher -> fetcher.isFromSideInput() || fetcher.isToSideInput()));
- if (!handleDataFetchers(sideInputRelated.get(true))) {
- return;
- }
- final Set<VertexHarness> finalizeLater = sideInputRelated.get(false).stream()
- .map(DataFetcher::getChild)
- .flatMap(vertex -> getAllReachables(vertex).stream())
- .collect(Collectors.toSet());
- for (final VertexHarness vertexHarness : sortedHarnesses) {
- if (!finalizeLater.contains(vertexHarness)) {
- finalizeVertex(vertexHarness); // finalize early to materialize intra-task side inputs.
- }
- }
-
- // Phase 2: Consume task-external input data.
- if (!handleDataFetchers(sideInputRelated.get(false))) {
+ // Phase 1: Consume task-external input data. (non-broadcasts)
+ if (!handleDataFetchers(nonBroadcastDataFetchers)) {
return;
}
metricMessageSender.send("TaskMetric", taskId,
- "boundedSourceReadTime", SerializationUtils.serialize(boundedSourceReadTime));
+ "boundedSourceReadTime", SerializationUtils.serialize(boundedSourceReadTime));
metricMessageSender.send("TaskMetric", taskId,
- "serializedReadBytes", SerializationUtils.serialize(serializedReadBytes));
+ "serializedReadBytes", SerializationUtils.serialize(serializedReadBytes));
metricMessageSender.send("TaskMetric", taskId,
- "encodedReadBytes", SerializationUtils.serialize(encodedReadBytes));
+ "encodedReadBytes", SerializationUtils.serialize(encodedReadBytes));
- // Phase 3: Finalize task-internal states and elements
+ // Phase 2: Finalize task-internal states and elements
for (final VertexHarness vertexHarness : sortedHarnesses) {
- if (finalizeLater.contains(vertexHarness)) {
- finalizeVertex(vertexHarness);
- }
+ finalizeVertex(vertexHarness);
}
if (idOfVertexPutOnHold == null) {
@@ -304,24 +304,12 @@ public final class TaskExecutor {
LOG.info("{} completed", taskId);
} else {
taskStateManager.onTaskStateChanged(TaskState.State.ON_HOLD,
- Optional.of(idOfVertexPutOnHold),
- Optional.empty());
+ Optional.of(idOfVertexPutOnHold),
+ Optional.empty());
LOG.info("{} on hold", taskId);
}
}
- private List<VertexHarness> getAllReachables(final VertexHarness src) {
- final List<VertexHarness> result = new ArrayList<>();
- result.add(src);
- result.addAll(src.getNonSideInputChildren().stream()
- .flatMap(child -> getAllReachables(child).stream()).collect(Collectors.toList()));
- result.addAll(src.getSideInputChildren().stream()
- .flatMap(child -> getAllReachables(child).stream()).collect(Collectors.toList()));
- result.addAll(src.getAdditionalTagOutputChildren().values().stream()
- .flatMap(child -> getAllReachables(child).stream()).collect(Collectors.toList()));
- return result;
- }
-
private void finalizeVertex(final VertexHarness vertexHarness) {
closeTransform(vertexHarness);
final OutputCollectorImpl outputCollector = vertexHarness.getOutputCollector();
@@ -333,7 +321,7 @@ public final class TaskExecutor {
// handle additional tagged outputs
vertexHarness.getAdditionalTagOutputChildren().keySet().forEach(tag -> {
outputCollector.iterateTag(tag).forEach(
- element -> handleAdditionalOutputElement(vertexHarness, element, tag)); // Recursion
+ element -> handleAdditionalOutputElement(vertexHarness, element, tag)); // Recursion
outputCollector.clearTag(tag);
});
finalizeOutputWriters(vertexHarness);
@@ -342,27 +330,19 @@ public final class TaskExecutor {
private void handleMainOutputElement(final VertexHarness harness, final Object element) {
// writes to children tasks
harness.getWritersToMainChildrenTasks().forEach(outputWriter -> outputWriter.write(element));
- // writes to side input children tasks
- if (!harness.getSideInputChildren().isEmpty()) {
- sideInputMap.put(((OperatorVertex) harness.getIRVertex()).getTransform().getTag(), element);
- }
// process elements in the next vertices within a task
- harness.getNonSideInputChildren().forEach(child -> processElementRecursively(child, element));
+ harness.getMainTagChildren().forEach(child -> processElementRecursively(child, element));
}
private void handleAdditionalOutputElement(final VertexHarness harness, final Object element, final String tag) {
// writes to additional children tasks
harness.getWritersToAdditionalChildrenTasks().entrySet().stream()
- .filter(kv -> kv.getKey().equals(tag))
- .forEach(kv -> kv.getValue().write(element));
- // writes to side input children tasks
- if (!harness.getSideInputChildren().isEmpty()) {
- sideInputMap.put(((OperatorVertex) harness.getIRVertex()).getTransform().getTag(), element);
- }
+ .filter(kv -> kv.getKey().equals(tag))
+ .forEach(kv -> kv.getValue().write(element));
// process elements in the next vertices within a task
harness.getAdditionalTagOutputChildren().entrySet().stream()
- .filter(kv -> kv.getKey().equals(tag))
- .forEach(kv -> processElementRecursively(kv.getValue(), element));
+ .filter(kv -> kv.getKey().equals(tag))
+ .forEach(kv -> processElementRecursively(kv.getValue(), element));
}
/**
@@ -393,17 +373,13 @@ public final class TaskExecutor {
} catch (IOException e) {
// IOException means that this task should be retried.
taskStateManager.onTaskStateChanged(TaskState.State.SHOULD_RETRY,
- Optional.empty(), Optional.of(TaskState.RecoverableTaskFailureCause.INPUT_READ_FAILURE));
+ Optional.empty(), Optional.of(TaskState.RecoverableTaskFailureCause.INPUT_READ_FAILURE));
LOG.error("{} Execution Failed (Recoverable: input read failure)! Exception: {}", taskId, e);
return false;
}
// Successfully fetched an element
- if (dataFetcher.isFromSideInput()) {
- sideInputMap.put(((OperatorVertex) dataFetcher.getDataSource()).getTransform().getTag(), element);
- } else {
- processElementRecursively(dataFetcher.getChild(), element);
- }
+ // Only non-broadcast fetchers reach here, so every fetched element flows to the fetcher's child.
+ processElementRecursively(dataFetcher.getChild(), element);
}
// Remove the finished fetcher from the list
@@ -423,20 +399,20 @@ public final class TaskExecutor {
// Add all intra-task additional tags to additional output map.
irVertexDag.getOutgoingEdgesOf(irVertex.getId())
- .stream()
- .filter(edge -> edge.getPropertyValue(AdditionalOutputTagProperty.class).isPresent())
- .map(edge ->
- Pair.of(edge.getPropertyValue(AdditionalOutputTagProperty.class).get(), edge.getDst().getId()))
- .forEach(pair -> additionalOutputMap.put(pair.left(), pair.right()));
+ .stream()
+ .filter(edge -> edge.getPropertyValue(AdditionalOutputTagProperty.class).isPresent())
+ .map(edge ->
+ Pair.of(edge.getPropertyValue(AdditionalOutputTagProperty.class).get(), edge.getDst().getId()))
+ .forEach(pair -> additionalOutputMap.put(pair.left(), pair.right()));
// Add all inter-task additional tags to additional output map.
outEdgesToChildrenTasks
- .stream()
- .filter(edge -> edge.getSrcIRVertex().getId().equals(irVertex.getId()))
- .filter(edge -> edge.getPropertyValue(AdditionalOutputTagProperty.class).isPresent())
- .map(edge ->
- Pair.of(edge.getPropertyValue(AdditionalOutputTagProperty.class).get(), edge.getDstIRVertex().getId()))
- .forEach(pair -> additionalOutputMap.put(pair.left(), pair.right()));
+ .stream()
+ .filter(edge -> edge.getSrcIRVertex().getId().equals(irVertex.getId()))
+ .filter(edge -> edge.getPropertyValue(AdditionalOutputTagProperty.class).isPresent())
+ .map(edge ->
+ Pair.of(edge.getPropertyValue(AdditionalOutputTagProperty.class).get(), edge.getDstIRVertex().getId()))
+ .forEach(pair -> additionalOutputMap.put(pair.left(), pair.right()));
return additionalOutputMap;
}
@@ -455,38 +431,36 @@ public final class TaskExecutor {
}
private List<InputReader> getParentTaskReaders(final int taskIndex,
- final IRVertex irVertex,
final List<StageEdge> inEdgesFromParentTasks,
final DataTransferFactory dataTransferFactory) {
+ // Creates one reader per given edge, in the same order. No destination filtering happens here any more:
+ // callers must pass only the edges whose destination is the vertex being prepared.
return inEdgesFromParentTasks
- .stream()
- .filter(inEdge -> inEdge.getDstIRVertex().getId().equals(irVertex.getId()))
- .map(inEdgeForThisVertex -> dataTransferFactory
- .createReader(taskIndex, inEdgeForThisVertex.getSrcIRVertex(), inEdgeForThisVertex))
- .collect(Collectors.toList());
+ .stream()
+ .map(inEdgeForThisVertex -> dataTransferFactory
+ .createReader(taskIndex, inEdgeForThisVertex.getSrcIRVertex(), inEdgeForThisVertex))
+ .collect(Collectors.toList());
}
private Set<String> getMainOutputVertices(final IRVertex irVertex,
- final DAG<IRVertex, RuntimeEdge<IRVertex>> irVertexDag,
- final List<StageEdge> outEdgesToChildrenTasks,
- final List<String> additionalOutputVertices) {
+ final DAG<IRVertex, RuntimeEdge<IRVertex>> irVertexDag,
+ final List<StageEdge> outEdgesToChildrenTasks,
+ final List<String> additionalOutputVertices) {
// all intra-task children vertices id
final List<String> outputVertices = irVertexDag.getOutgoingEdgesOf(irVertex).stream()
- .filter(edge -> edge.getSrc().getId().equals(irVertex.getId()))
- .map(edge -> edge.getDst().getId())
- .collect(Collectors.toList());
+ .filter(edge -> edge.getSrc().getId().equals(irVertex.getId()))
+ .map(edge -> edge.getDst().getId())
+ .collect(Collectors.toList());
// all inter-task children vertices id
outputVertices
- .addAll(outEdgesToChildrenTasks.stream()
- .filter(edge -> edge.getSrcIRVertex().getId().equals(irVertex.getId()))
- .map(edge -> edge.getDstIRVertex().getId())
- .collect(Collectors.toList()));
+ .addAll(outEdgesToChildrenTasks.stream()
+ .filter(edge -> edge.getSrcIRVertex().getId().equals(irVertex.getId()))
+ .map(edge -> edge.getDstIRVertex().getId())
+ .collect(Collectors.toList()));
// return vertices that are not marked as additional tagged outputs
return new HashSet<>(outputVertices.stream()
- .filter(vertexId -> !additionalOutputVertices.contains(vertexId))
- .collect(Collectors.toList()));
+ .filter(vertexId -> !additionalOutputVertices.contains(vertexId))
+ .collect(Collectors.toList()));
}
/**
@@ -503,12 +477,12 @@ public final class TaskExecutor {
final DataTransferFactory dataTransferFactory,
final Map<String, String> taggedOutputs) {
return outEdgesToChildrenTasks
- .stream()
- .filter(outEdge -> outEdge.getSrcIRVertex().getId().equals(irVertex.getId()))
- .filter(outEdge -> !taggedOutputs.containsValue(outEdge.getDstIRVertex().getId()))
- .map(outEdgeForThisVertex -> dataTransferFactory
- .createWriter(taskId, outEdgeForThisVertex.getDstIRVertex(), outEdgeForThisVertex))
- .collect(Collectors.toList());
+ .stream()
+ .filter(outEdge -> outEdge.getSrcIRVertex().getId().equals(irVertex.getId()))
+ .filter(outEdge -> !taggedOutputs.containsValue(outEdge.getDstIRVertex().getId()))
+ .map(outEdgeForThisVertex -> dataTransferFactory
+ .createWriter(taskId, outEdgeForThisVertex.getDstIRVertex(), outEdgeForThisVertex))
+ .collect(Collectors.toList());
}
/**
@@ -527,12 +501,12 @@ public final class TaskExecutor {
final Map<String, OutputWriter> additionalChildrenTaskWriters = new HashMap<>();
outEdgesToChildrenTasks
- .stream()
- .filter(outEdge -> outEdge.getSrcIRVertex().getId().equals(irVertex.getId()))
- .filter(outEdge -> taggedOutputs.containsValue(outEdge.getDstIRVertex().getId()))
- .forEach(outEdgeForThisVertex ->
- additionalChildrenTaskWriters.put(outEdgeForThisVertex.getDstIRVertex().getId(),
- dataTransferFactory.createWriter(taskId, outEdgeForThisVertex.getDstIRVertex(), outEdgeForThisVertex)));
+ .stream()
+ .filter(outEdge -> outEdge.getSrcIRVertex().getId().equals(irVertex.getId()))
+ .filter(outEdge -> taggedOutputs.containsValue(outEdge.getDstIRVertex().getId()))
+ .forEach(outEdgeForThisVertex ->
+ additionalChildrenTaskWriters.put(outEdgeForThisVertex.getDstIRVertex().getId(),
+ dataTransferFactory.createWriter(taskId, outEdgeForThisVertex.getDstIRVertex(), outEdgeForThisVertex)));
return additionalChildrenTaskWriters;
}
@@ -541,10 +515,10 @@ public final class TaskExecutor {
final DAG<IRVertex, RuntimeEdge<IRVertex>> irVertexDag,
final Map<String, VertexHarness> vertexIdToHarness) {
final List<VertexHarness> childrenHandlers = irVertexDag.getChildren(irVertex.getId())
- .stream()
- .map(IRVertex::getId)
- .map(vertexIdToHarness::get)
- .collect(Collectors.toList());
+ .stream()
+ .map(IRVertex::getId)
+ .map(vertexIdToHarness::get)
+ .collect(Collectors.toList());
if (childrenHandlers.stream().anyMatch(harness -> harness == null)) {
// Sanity check: there shouldn't be a null harness.
throw new IllegalStateException(childrenHandlers.toString());
@@ -569,13 +543,13 @@ public final class TaskExecutor {
transform.close();
}
vertexHarness.getContext().getSerializedData().ifPresent(data ->
- persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send(
- ControlMessage.Message.newBuilder()
- .setId(RuntimeIdManager.generateMessageId())
- .setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID)
- .setType(ControlMessage.MessageType.ExecutorDataCollected)
- .setDataCollected(ControlMessage.DataCollectMessage.newBuilder().setData(data).build())
- .build()));
+ persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send(
+ ControlMessage.Message.newBuilder()
+ .setId(RuntimeIdManager.generateMessageId())
+ .setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID)
+ .setType(ControlMessage.MessageType.ExecutorDataCollected)
+ .setDataCollected(ControlMessage.DataCollectMessage.newBuilder().setData(data).build())
+ .build()));
}
////////////////////////////////////////////// Misc
@@ -614,6 +588,6 @@ public final class TaskExecutor {
totalWrittenBytes += writtenBytes;
}
metricMessageSender.send("TaskMetric", taskId,
- "writtenBytes", SerializationUtils.serialize(totalWrittenBytes));
+ "writtenBytes", SerializationUtils.serialize(totalWrittenBytes));
}
}
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java
index 502325f..2ad8868 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java
@@ -26,17 +26,16 @@ import java.util.List;
import java.util.Map;
/**
- * Captures the relationship between a non-source IRVertex's outputCollector, and children vertices.
+ * Captures the relationship between a non-source IRVertex's outputCollector and its children vertices.
*/
final class VertexHarness {
// IRVertex and transform-specific information
private final IRVertex irVertex;
private final OutputCollectorImpl outputCollector;
private final Transform.Context context;
+ private final List<VertexHarness> mainTagChildren;
// These lists can be empty
- private final List<VertexHarness> sideInputChildren;
- private final List<VertexHarness> nonSideInputChildren;
private final Map<String, VertexHarness> additionalTagOutputChildren;
private final List<OutputWriter> writersToMainChildrenTasks;
private final Map<String, OutputWriter> writersToAdditionalChildrenTasks;
@@ -44,35 +43,29 @@ final class VertexHarness {
VertexHarness(final IRVertex irVertex,
final OutputCollectorImpl outputCollector,
final List<VertexHarness> children,
- final List<Boolean> isSideInputs,
final List<Boolean> isAdditionalTagOutputs,
final List<OutputWriter> writersToMainChildrenTasks,
final Map<String, OutputWriter> writersToAdditionalChildrenTasks,
final Transform.Context context) {
this.irVertex = irVertex;
this.outputCollector = outputCollector;
- if (children.size() != isSideInputs.size() || children.size() != isAdditionalTagOutputs.size()) {
+ if (children.size() != isAdditionalTagOutputs.size()) {
throw new IllegalStateException(irVertex.toString());
}
final Map<String, String> taggedOutputMap = context.getTagToAdditionalChildren();
- final List<VertexHarness> sides = new ArrayList<>();
- final List<VertexHarness> nonSides = new ArrayList<>();
final Map<String, VertexHarness> tagged = new HashMap<>();
for (int i = 0; i < children.size(); i++) {
final VertexHarness child = children.get(i);
- if (isSideInputs.get(i)) {
- sides.add(child);
- } else if (isAdditionalTagOutputs.get(i)) {
+ if (isAdditionalTagOutputs.get(i)) {
taggedOutputMap.entrySet().stream()
.filter(kv -> child.getIRVertex().getId().equals(kv.getValue()))
.forEach(kv -> tagged.put(kv.getValue(), child));
- } else {
- nonSides.add(child);
}
}
- this.sideInputChildren = sides;
- this.nonSideInputChildren = nonSides;
this.additionalTagOutputChildren = tagged;
+ final List<VertexHarness> mainTagChildrenTmp = new ArrayList<>(children);
+ mainTagChildrenTmp.removeAll(additionalTagOutputChildren.values());
+ this.mainTagChildren = mainTagChildrenTmp;
this.writersToMainChildrenTasks = writersToMainChildrenTasks;
this.writersToAdditionalChildrenTasks = writersToAdditionalChildrenTasks;
this.context = context;
@@ -93,21 +86,14 @@ final class VertexHarness {
}
/**
- * @return list of non-sideinput children. (empty if none exists)
+ * @return mainTagChildren harnesses.
*/
- List<VertexHarness> getNonSideInputChildren() {
- return nonSideInputChildren;
+ List<VertexHarness> getMainTagChildren() {
+ return mainTagChildren;
}
/**
- * @return list of sideinput children. (empty if none exists)
- */
- List<VertexHarness> getSideInputChildren() {
- return sideInputChildren;
- }
-
- /**
- * @return map of tagged output children. (empty if none exists)
+   * @return map of additional tagged output children, keyed by child vertex id. (empty if none exists)
*/
public Map<String, VertexHarness> getAdditionalTagOutputChildren() {
return additionalTagOutputChildren;
diff --git a/common/src/test/java/edu/snu/nemo/common/ContextImplTest.java b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/TransformContextImplTest.java
similarity index 61%
rename from common/src/test/java/edu/snu/nemo/common/ContextImplTest.java
rename to runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/TransformContextImplTest.java
index 149cdd1..b2f90e6 100644
--- a/common/src/test/java/edu/snu/nemo/common/ContextImplTest.java
+++ b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/TransformContextImplTest.java
@@ -14,35 +14,43 @@
* limitations under the License.
*/
-package edu.snu.nemo.common;
+package edu.snu.nemo.runtime.executor;
import edu.snu.nemo.common.ir.vertex.transform.Transform;
+import edu.snu.nemo.runtime.executor.data.BroadcastManagerWorker;
import org.junit.Before;
import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
import java.util.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
/**
- * Tests {@link ContextImpl}.
+ * Tests {@link TransformContextImpl}.
*/
-public class ContextImplTest {
+@RunWith(PowerMockRunner.class)
+@PrepareForTest({BroadcastManagerWorker.class})
+public class TransformContextImplTest {
private Transform.Context context;
- private final Map sideInputs = new HashMap();
- private final Map<String, String> taggedOutputs = new HashMap();
+ private final Map<String, String> taggedOutputs = new HashMap<>();
@Before
public void setUp() {
- sideInputs.put("a", "b");
- this.context = new ContextImpl(sideInputs, taggedOutputs);
+ final BroadcastManagerWorker broadcastManagerWorker = mock(BroadcastManagerWorker.class);
+ when(broadcastManagerWorker.get("a")).thenReturn("b");
+ this.context = new TransformContextImpl(broadcastManagerWorker, taggedOutputs);
}
@Test
public void testContextImpl() {
- assertEquals(this.sideInputs, this.context.getSideInputs());
+ assertEquals("b", this.context.getBroadcastVariable("a"));
assertEquals(this.taggedOutputs, this.context.getTagToAdditionalChildren());
final String sampleText = "test_text";
diff --git a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/datatransfer/DataTransferTest.java b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/datatransfer/DataTransferTest.java
index 8c1ccd3..0d0a69b 100644
--- a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/datatransfer/DataTransferTest.java
+++ b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/datatransfer/DataTransferTest.java
@@ -322,8 +322,7 @@ public final class DataTransferTest {
final IRVertex dstMockVertex = mock(IRVertex.class);
final Stage srcStage = setupStages("srcStage" + testIndex);
final Stage dstStage = setupStages("dstStage" + testIndex);
- dummyEdge = new StageEdge(edgeId, edgeProperties, srcMockVertex, dstMockVertex,
- srcStage, dstStage, false);
+ dummyEdge = new StageEdge(edgeId, edgeProperties, srcMockVertex, dstMockVertex, srcStage, dstStage);
// Initialize states in Master
TestUtil.generateTaskIds(srcStage).forEach(srcTaskId -> {
@@ -414,11 +413,9 @@ public final class DataTransferTest {
final IRVertex dstMockVertex = mock(IRVertex.class);
final Stage srcStage = setupStages("srcStage" + testIndex);
final Stage dstStage = setupStages("dstStage" + testIndex);
- dummyEdge = new StageEdge(edgeId, edgeProperties, srcMockVertex, dstMockVertex,
- srcStage, dstStage, false);
+ dummyEdge = new StageEdge(edgeId, edgeProperties, srcMockVertex, dstMockVertex, srcStage, dstStage);
final IRVertex dstMockVertex2 = mock(IRVertex.class);
- dummyEdge2 = new StageEdge(edgeId2, edgeProperties, srcMockVertex, dstMockVertex2,
- srcStage, dstStage, false);
+ dummyEdge2 = new StageEdge(edgeId2, edgeProperties, srcMockVertex, dstMockVertex2, srcStage, dstStage);
// Initialize states in Master
TestUtil.generateTaskIds(srcStage).forEach(srcTaskId -> {
final String blockId = RuntimeIdManager.generateBlockId(edgeId, srcTaskId);
diff --git a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/ParentTaskDataFetcherTest.java b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/ParentTaskDataFetcherTest.java
index 47c2d19..12e5124 100644
--- a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/ParentTaskDataFetcherTest.java
+++ b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/ParentTaskDataFetcherTest.java
@@ -118,8 +118,7 @@ public final class ParentTaskDataFetcherTest {
return new ParentTaskDataFetcher(
mock(IRVertex.class),
readerForParentTask, // This is the only argument that affects the behavior of ParentTaskDataFetcher
- mock(VertexHarness.class),
- false);
+ mock(VertexHarness.class));
}
private InputReader generateInputReader(final CompletableFuture completableFuture) {
diff --git a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java
index 4889c06..b1a1450 100644
--- a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java
+++ b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java
@@ -22,6 +22,7 @@ import edu.snu.nemo.common.dag.DAGBuilder;
import edu.snu.nemo.common.ir.Readable;
import edu.snu.nemo.common.ir.edge.IREdge;
import edu.snu.nemo.common.ir.edge.executionproperty.AdditionalOutputTagProperty;
+import edu.snu.nemo.common.ir.edge.executionproperty.BroadcastVariableIdProperty;
import edu.snu.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import edu.snu.nemo.common.ir.edge.executionproperty.DataStoreProperty;
import edu.snu.nemo.common.ir.executionproperty.EdgeExecutionProperty;
@@ -39,6 +40,7 @@ import edu.snu.nemo.runtime.common.plan.StageEdge;
import edu.snu.nemo.runtime.common.plan.RuntimeEdge;
import edu.snu.nemo.runtime.executor.MetricMessageSender;
import edu.snu.nemo.runtime.executor.TaskStateManager;
+import edu.snu.nemo.runtime.executor.data.BroadcastManagerWorker;
import edu.snu.nemo.runtime.executor.data.DataUtil;
import edu.snu.nemo.runtime.executor.datatransfer.DataTransferFactory;
import edu.snu.nemo.runtime.executor.datatransfer.InputReader;
@@ -52,6 +54,7 @@ import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import java.io.IOException;
+import java.io.Serializable;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
@@ -68,7 +71,7 @@ import static org.mockito.Mockito.*;
* Tests {@link TaskExecutor}.
*/
@RunWith(PowerMockRunner.class)
-@PrepareForTest({InputReader.class, OutputWriter.class, DataTransferFactory.class,
+@PrepareForTest({InputReader.class, OutputWriter.class, DataTransferFactory.class, BroadcastManagerWorker.class,
TaskStateManager.class, StageEdge.class, PersistentConnectionToMasterMap.class, Stage.class, IREdge.class})
public final class TaskExecutorTest {
private static final AtomicInteger RUNTIME_EDGE_ID = new AtomicInteger(0);
@@ -81,6 +84,7 @@ public final class TaskExecutorTest {
private List<Integer> elements;
private Map<String, List> runtimeEdgeToOutputData;
private DataTransferFactory dataTransferFactory;
+ private BroadcastManagerWorker broadcastManagerWorker;
private TaskStateManager taskStateManager;
private MetricMessageSender metricMessageSender;
private PersistentConnectionToMasterMap persistentConnectionToMasterMap;
@@ -111,6 +115,7 @@ public final class TaskExecutorTest {
doNothing().when(metricMessageSender).close();
persistentConnectionToMasterMap = mock(PersistentConnectionToMasterMap.class);
+ broadcastManagerWorker = mock(BroadcastManagerWorker.class);
}
private boolean checkEqualElements(final List<Integer> left, final List<Integer> right) {
@@ -156,8 +161,7 @@ public final class TaskExecutorTest {
vertexIdToReadable);
// Execute the task.
- final TaskExecutor taskExecutor = new TaskExecutor(
- task, taskDag, taskStateManager, dataTransferFactory, metricMessageSender, persistentConnectionToMasterMap);
+ final TaskExecutor taskExecutor = getTaskExecutor(task, taskDag);
taskExecutor.execute();
// Check the output.
@@ -186,8 +190,7 @@ public final class TaskExecutorTest {
Collections.emptyMap());
// Execute the task.
- final TaskExecutor taskExecutor = new TaskExecutor(
- task, taskDag, taskStateManager, dataTransferFactory, metricMessageSender, persistentConnectionToMasterMap);
+ final TaskExecutor taskExecutor = getTaskExecutor(task, taskDag);
taskExecutor.execute();
// Check the output.
@@ -207,10 +210,11 @@ public final class TaskExecutorTest {
final IRVertex operatorIRVertex1 = new OperatorVertex(new RelayTransform());
final IRVertex operatorIRVertex2 = new OperatorVertex(new RelayTransform());
+ final String edgeId = "edge";
final DAG<IRVertex, RuntimeEdge<IRVertex>> taskDag = new DAGBuilder<IRVertex, RuntimeEdge<IRVertex>>()
.addVertex(operatorIRVertex1)
.addVertex(operatorIRVertex2)
- .connectVertices(createEdge(operatorIRVertex1, operatorIRVertex2, false))
+ .connectVertices(createEdge(operatorIRVertex1, operatorIRVertex2, edgeId))
.buildWithoutSourceSinkCheck();
final StageEdge taskOutEdge = mockStageEdgeFrom(operatorIRVertex2);
@@ -224,8 +228,7 @@ public final class TaskExecutorTest {
Collections.emptyMap());
// Execute the task.
- final TaskExecutor taskExecutor = new TaskExecutor(
- task, taskDag, taskStateManager, dataTransferFactory, metricMessageSender, persistentConnectionToMasterMap);
+ final TaskExecutor taskExecutor = getTaskExecutor(task, taskDag);
taskExecutor.execute();
// Check the output.
@@ -233,38 +236,43 @@ public final class TaskExecutorTest {
}
@Test(timeout=5000)
- public void testTwoOperatorsWithSideInput() throws Exception {
- final Object tag = new Object();
+ public void testTwoOperatorsWithBroadcastVariable() {
final Transform singleListTransform = new CreateSingleListTransform();
- final IRVertex operatorIRVertex1 = new OperatorVertex(singleListTransform);
- final IRVertex operatorIRVertex2 = new OperatorVertex(new SideInputPairTransform(singleListTransform.getTag()));
+ final long broadcastId = 0;
+ final IRVertex operatorIRVertex1 = new OperatorVertex(new RelayTransform());
+ final IRVertex operatorIRVertex2 = new OperatorVertex(new BroadcastVariablePairingTransform(broadcastId));
+
+ final String edgeId = "edge";
final DAG<IRVertex, RuntimeEdge<IRVertex>> taskDag = new DAGBuilder<IRVertex, RuntimeEdge<IRVertex>>()
.addVertex(operatorIRVertex1)
.addVertex(operatorIRVertex2)
- .connectVertices(createEdge(operatorIRVertex1, operatorIRVertex2, true))
+ .connectVertices(createEdge(operatorIRVertex1, operatorIRVertex2, edgeId))
.buildWithoutSourceSinkCheck();
final StageEdge taskOutEdge = mockStageEdgeFrom(operatorIRVertex2);
+
+ final StageEdge broadcastInEdge = mockBroadcastVariableStageEdgeTo(
+ new OperatorVertex(singleListTransform), operatorIRVertex2, broadcastId, elements);
+
final Task task = new Task(
"testSourceVertexDataFetching",
generateTaskId(),
TASK_EXECUTION_PROPERTY_MAP,
new byte[0],
- Arrays.asList(mockStageEdgeTo(operatorIRVertex1), mockStageEdgeTo(operatorIRVertex2)),
+ Arrays.asList(mockStageEdgeTo(operatorIRVertex1), broadcastInEdge),
Collections.singletonList(taskOutEdge),
Collections.emptyMap());
// Execute the task.
- final TaskExecutor taskExecutor = new TaskExecutor(
- task, taskDag, taskStateManager, dataTransferFactory, metricMessageSender, persistentConnectionToMasterMap);
+ final TaskExecutor taskExecutor = getTaskExecutor(task, taskDag);
taskExecutor.execute();
// Check the output.
final List<Pair<List<Integer>, Integer>> pairs = runtimeEdgeToOutputData.get(taskOutEdge.getId());
final List<Integer> values = pairs.stream().map(Pair::right).collect(Collectors.toList());
assertTrue(checkEqualElements(elements, values));
- assertTrue(pairs.stream().map(Pair::left).allMatch(sideInput -> checkEqualElements(sideInput, values)));
+ assertTrue(pairs.stream().map(Pair::left).allMatch(broadcastVar -> checkEqualElements(broadcastVar, values)));
}
/**
@@ -283,9 +291,9 @@ public final class TaskExecutorTest {
final IRVertex bonusVertex1 = new OperatorVertex(new RelayTransform());
final IRVertex bonusVertex2 = new OperatorVertex(new RelayTransform());
- final RuntimeEdge<IRVertex> edge1 = createEdge(routerVertex, mainVertex, false, "edge-1");
- final RuntimeEdge<IRVertex> edge2 = createEdge(routerVertex, bonusVertex1, false, "edge-2");
- final RuntimeEdge<IRVertex> edge3 = createEdge(routerVertex, bonusVertex2, false, "edge-3");
+ final RuntimeEdge<IRVertex> edge1 = createEdge(routerVertex, mainVertex, "edge-1");
+ final RuntimeEdge<IRVertex> edge2 = createEdge(routerVertex, bonusVertex1, "edge-2");
+ final RuntimeEdge<IRVertex> edge3 = createEdge(routerVertex, bonusVertex2, "edge-3");
edge2.getExecutionProperties().put(AdditionalOutputTagProperty.of("bonus1"));
edge3.getExecutionProperties().put(AdditionalOutputTagProperty.of("bonus2"));
@@ -314,14 +322,14 @@ public final class TaskExecutorTest {
Collections.emptyMap());
// Execute the task.
- final TaskExecutor taskExecutor = new TaskExecutor(
- task, taskDag, taskStateManager, dataTransferFactory, metricMessageSender, persistentConnectionToMasterMap);
+ final TaskExecutor taskExecutor = getTaskExecutor(task, taskDag);
taskExecutor.execute();
// Check the output.
final List<Integer> mainOutputs = runtimeEdgeToOutputData.get(outEdge1.getId());
final List<Integer> bonusOutputs1 = runtimeEdgeToOutputData.get(outEdge2.getId());
final List<Integer> bonusOutputs2 = runtimeEdgeToOutputData.get(outEdge3.getId());
+
List<Integer> even = elements.stream().filter(i -> i % 2 == 0).collect(Collectors.toList());
List<Integer> odd = elements.stream().filter(i -> i % 2 != 0).collect(Collectors.toList());
assertTrue(checkEqualElements(even, mainOutputs));
@@ -331,21 +339,10 @@ public final class TaskExecutorTest {
private RuntimeEdge<IRVertex> createEdge(final IRVertex src,
final IRVertex dst,
- final boolean isSideInput) {
- final String runtimeIREdgeId = "Runtime edge between operator tasks";
- ExecutionPropertyMap<EdgeExecutionProperty> edgeProperties = new ExecutionPropertyMap<>(runtimeIREdgeId);
- edgeProperties.put(DataStoreProperty.of(DataStoreProperty.Value.MemoryStore));
- return new RuntimeEdge<>(runtimeIREdgeId, edgeProperties, src, dst, isSideInput);
-
- }
-
- private RuntimeEdge<IRVertex> createEdge(final IRVertex src,
- final IRVertex dst,
- final boolean isSideInput,
final String runtimeIREdgeId) {
ExecutionPropertyMap<EdgeExecutionProperty> edgeProperties = new ExecutionPropertyMap<>(runtimeIREdgeId);
edgeProperties.put(DataStoreProperty.of(DataStoreProperty.Value.MemoryStore));
- return new RuntimeEdge<>(runtimeIREdgeId, edgeProperties, src, dst, isSideInput);
+ return new RuntimeEdge<>(runtimeIREdgeId, edgeProperties, src, dst);
}
@@ -355,18 +352,35 @@ public final class TaskExecutorTest {
irVertex,
new OperatorVertex(new RelayTransform()),
mock(Stage.class),
- mock(Stage.class),
- false);
+ mock(Stage.class));
}
private StageEdge mockStageEdgeTo(final IRVertex irVertex) {
+ final ExecutionPropertyMap executionPropertyMap =
+ ExecutionPropertyMap.of(mock(IREdge.class), CommunicationPatternProperty.Value.OneToOne);
return new StageEdge("runtime outgoing edge id",
- ExecutionPropertyMap.of(mock(IREdge.class), CommunicationPatternProperty.Value.OneToOne),
- new OperatorVertex(new RelayTransform()),
- irVertex,
- mock(Stage.class),
- mock(Stage.class),
- false);
+ executionPropertyMap,
+ new OperatorVertex(new RelayTransform()),
+ irVertex,
+ mock(Stage.class),
+ mock(Stage.class));
+ }
+
+  private StageEdge mockBroadcastVariableStageEdgeTo(final IRVertex srcVertex,
+                                                     final IRVertex dstVertex,
+                                                     final Serializable broadcastVariableId,
+                                                     final Object broadcastVariable) {
+    // Make the mocked broadcast worker resolve this id to the given variable.
+    when(broadcastManagerWorker.get(broadcastVariableId)).thenReturn(broadcastVariable);
+
+    // Tag the edge with the id of the broadcast variable it stands for.
+    final ExecutionPropertyMap edgeProperties =
+        ExecutionPropertyMap.of(mock(IREdge.class), CommunicationPatternProperty.Value.OneToOne);
+    edgeProperties.put(BroadcastVariableIdProperty.of(broadcastVariableId));
+
+    final Stage srcStage = mock(Stage.class);
+    final Stage dstStage = mock(Stage.class);
+    return new StageEdge("runtime outgoing edge id", edgeProperties, srcVertex, dstVertex, srcStage, dstStage);
+  }
/**
@@ -384,8 +398,9 @@ public final class TaskExecutorTest {
.iterator())));
}
final InputReader inputReader = mock(InputReader.class);
+ final IRVertex srcVertex = (IRVertex) invocationOnMock.getArgument(1);
+ when(inputReader.getSrcIrVertex()).thenReturn(srcVertex);
when(inputReader.read()).thenReturn(inputFutures);
- when(inputReader.isSideInputReader()).thenReturn(false);
when(inputReader.getSourceParallelism()).thenReturn(SOURCE_PARALLELISM);
return inputReader;
}
@@ -445,7 +460,6 @@ public final class TaskExecutorTest {
private class CreateSingleListTransform<T> implements Transform<T, List<T>> {
private List<T> list;
private OutputCollector<List<T>> outputCollector;
- private final Object tag = new Object();
@Override
public void prepare(final Context context, final OutputCollector<List<T>> outputCollector) {
@@ -462,24 +476,19 @@ public final class TaskExecutorTest {
public void close() {
outputCollector.emit(list);
}
-
- @Override
- public Object getTag() {
- return tag;
- }
}
/**
- * Pairs data element with a side input.
+ * Pairs data element with a broadcast variable.
* @param <T> input/output type.
*/
- private class SideInputPairTransform<T> implements Transform<T, T> {
- private final Object sideInputTag;
+ private class BroadcastVariablePairingTransform<T> implements Transform<T, T> {
+ private final Serializable broadcastVariableId;
private Context context;
private OutputCollector<T> outputCollector;
- public SideInputPairTransform(final Object sideInputTag) {
- this.sideInputTag = sideInputTag;
+ public BroadcastVariablePairingTransform(final Serializable broadcastVariableId) {
+ this.broadcastVariableId = broadcastVariableId;
}
@Override
@@ -490,8 +499,8 @@ public final class TaskExecutorTest {
@Override
public void onData(final Object element) {
- final Object sideInput = context.getSideInputs().get(sideInputTag);
- outputCollector.emit((T) Pair.of(sideInput, element));
+ final Object broadcastVariable = context.getBroadcastVariable(broadcastVariableId);
+ outputCollector.emit((T) Pair.of(broadcastVariable, element));
}
@Override
@@ -542,4 +551,9 @@ public final class TaskExecutorTest {
IntStream.range(start, end).forEach(number -> numList.add(number));
return numList;
}
+
+  private TaskExecutor getTaskExecutor(final Task task, final DAG<IRVertex, RuntimeEdge<IRVertex>> taskDag) {
+    return new TaskExecutor(task, taskDag, taskStateManager, dataTransferFactory, // per-test fixture fields
+        broadcastManagerWorker, metricMessageSender, persistentConnectionToMasterMap);
+  }
}
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/BlockManagerMaster.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/BlockManagerMaster.java
index 05b7467..052a644 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/BlockManagerMaster.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/BlockManagerMaster.java
@@ -72,13 +72,12 @@ public final class BlockManagerMaster {
/**
* Constructor.
- *
* @param masterMessageEnvironment the message environment.
*/
@Inject
private BlockManagerMaster(final MessageEnvironment masterMessageEnvironment) {
masterMessageEnvironment.setupListener(MessageEnvironment.BLOCK_MANAGER_MASTER_MESSAGE_LISTENER_ID,
- new PartitionManagerMasterControlMessageReceiver());
+ new BlockManagerMasterControlMessageReceiver());
this.blockIdWildcardToMetadataSet = new HashMap<>();
this.producerTaskIdToBlockIds = new HashMap<>();
this.lock = new ReentrantReadWriteLock();
@@ -337,7 +336,7 @@ public final class BlockManagerMaster {
/**
* Handler for control messages received.
*/
- public final class PartitionManagerMasterControlMessageReceiver implements MessageListener<ControlMessage.Message> {
+ public final class BlockManagerMasterControlMessageReceiver implements MessageListener<ControlMessage.Message> {
@Override
public void onMessage(final ControlMessage.Message message) {
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/BroadcastManagerMaster.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/BroadcastManagerMaster.java
new file mode 100644
index 0000000..77c6bbc
--- /dev/null
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/BroadcastManagerMaster.java
@@ -0,0 +1,50 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.runtime.master;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Broadcast variables saved in the master; synchronized, as registration and lookups may use different threads.
+ */
+public final class BroadcastManagerMaster {
+  private static final Logger LOG = LoggerFactory.getLogger(BroadcastManagerMaster.class.getName());
+  private static final Map<Serializable, Object> ID_TO_VARIABLE = new HashMap<>();
+
+  private BroadcastManagerMaster() {
+  }
+
+  /**
+   * @param variables from the client; a variable whose id is already registered is overwritten.
+   */
+  public static synchronized void registerBroadcastVariablesFromClient(final Map<Serializable, Object> variables) {
+    LOG.info("Registered broadcast variable ids {} sent from the client", variables.keySet());
+    ID_TO_VARIABLE.putAll(variables);
+  }
+
+  /**
+   * @param id of the broadcast variable.
+   * @return the requested broadcast variable, or null if none is registered under the id.
+   */
+  public static synchronized Object getBroadcastVariable(final Serializable id) {
+    return ID_TO_VARIABLE.get(id);
+  }
+}
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/PlanAppender.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/PlanAppender.java
index da6ddf1..20b065b 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/PlanAppender.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/PlanAppender.java
@@ -111,8 +111,7 @@ public final class PlanAppender {
cachedEdge.getSrcIRVertex(),
vertexStagePair.left(),
cachedEdge.getSrc(),
- vertexStagePair.right(),
- cachedEdge.isSideInput());
+ vertexStagePair.right());
physicalDAGBuilder.connectVertices(newEdge);
final DuplicateEdgeGroupPropertyValue duplicateEdgeGroupPropertyValue =
cachedEdge.getPropertyValue(DuplicateEdgeGroupProperty.class)
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/PlanStateManager.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/PlanStateManager.java
index 5017311..9450b58 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/PlanStateManager.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/PlanStateManager.java
@@ -393,7 +393,9 @@ public final class PlanStateManager {
// Change plan state if needed
final boolean allStagesCompleted = stageIdToState.values().stream().allMatch(state ->
state.getStateMachine().getCurrentState().equals(StageState.State.COMPLETE));
- if (allStagesCompleted) {
+
+    // Avoid a duplicate plan COMPLETE transition, which cloned task executions can otherwise trigger.
+ if (allStagesCompleted && !PlanState.State.COMPLETE.equals(getPlanState())) {
onPlanStateChanged(PlanState.State.COMPLETE);
}
}
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java
index 00eac87..3129732 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java
@@ -15,9 +15,11 @@
*/
package edu.snu.nemo.runtime.master;
+import com.google.protobuf.ByteString;
import edu.snu.nemo.common.Pair;
import edu.snu.nemo.common.exception.*;
import edu.snu.nemo.common.ir.vertex.IRVertex;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.comm.ControlMessage;
import edu.snu.nemo.runtime.common.message.MessageContext;
import edu.snu.nemo.runtime.common.message.MessageEnvironment;
@@ -46,6 +48,7 @@ import com.fasterxml.jackson.core.TreeNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import javax.inject.Inject;
+import java.io.Serializable;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
@@ -328,6 +331,25 @@ public final class RuntimeMaster {
@Override
public void onMessageWithContext(final ControlMessage.Message message, final MessageContext messageContext) {
switch (message.getType()) {
+ case RequestBroadcastVariable:
+ final Serializable broadcastId =
+ SerializationUtils.deserialize(message.getRequestbroadcastVariableMsg().getBroadcastId().toByteArray());
+ final Object broadcastVariable = BroadcastManagerMaster.getBroadcastVariable(broadcastId);
+ if (broadcastVariable == null) {
+ throw new IllegalStateException(broadcastId.toString());
+ }
+ messageContext.reply(
+ ControlMessage.Message.newBuilder()
+ .setId(RuntimeIdManager.generateMessageId())
+ .setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID)
+ .setType(ControlMessage.MessageType.InMasterBroadcastVariable)
+ .setBroadcastVariableMsg(ControlMessage.InMasterBroadcastVariableMessage.newBuilder()
+ .setRequestId(message.getId())
+ // TODO #206: Efficient Broadcast Variable Serialization
+ .setVariable(ByteString.copyFrom(SerializationUtils.serialize((Serializable) broadcastVariable)))
+ .build())
+ .build());
+ break;
default:
throw new IllegalMessageException(
new Exception("This message should not be requested to Master :" + message.getType()));
diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SkewnessAwareSchedulingConstraintTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SkewnessAwareSchedulingConstraintTest.java
index 6af6a61..85a1601 100644
--- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SkewnessAwareSchedulingConstraintTest.java
+++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SkewnessAwareSchedulingConstraintTest.java
@@ -71,7 +71,7 @@ public final class SkewnessAwareSchedulingConstraintTest {
dummyIREdge.setProperty(DataFlowProperty.of(DataFlowProperty.Value.Pull));
dummyIREdge.setProperty(DataSkewMetricProperty.of(new DataSkewMetricFactory(taskIdxToKeyRange)));
final StageEdge dummyEdge = new StageEdge("Edge0", dummyIREdge.getExecutionProperties(),
- srcMockVertex, dstMockVertex, srcMockStage, dstMockStage, false);
+ srcMockVertex, dstMockVertex, srcMockStage, dstMockStage);
return dummyEdge;
}