You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this rendering.
Posted to commits@nemo.apache.org by je...@apache.org on 2018/08/19 07:03:56 UTC
[incubator-nemo] branch master updated: [NEMO-178] Zero-delay task cloning (#107)
This is an automated email from the ASF dual-hosted git repository.
jeongyoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-nemo.git
The following commit(s) were added to refs/heads/master by this push:
new 98dbe68 [NEMO-178] Zero-delay task cloning (#107)
98dbe68 is described below
commit 98dbe68f7c8287b3160d14226b0abd7f3bf3a091
Author: John Yang <jo...@gmail.com>
AuthorDate: Sun Aug 19 16:03:54 2018 +0900
[NEMO-178] Zero-delay task cloning (#107)
JIRA: NEMO-178: Zero-delay task cloning
Major changes:
* A state machine for each task attempt and for each of its output blocks
* When fetching data, randomly choose one of the available output blocks produced by the task's attempts
* A cloned task or a retried task is simply another task attempt
* ClonedSchedulingProperty: zero-delay, upfront task cloning
Minor changes to note:
* Removes unused code
* RuntimeIdGenerator -> RuntimeIdManager
* Removes the hyphen that followed the prefix in generated IDs (e.g., vertex-4 => vertex4)
Tests for the changes:
* WordCountITCase#testClonedScheduling
Other comments:
* ClonedSchedulingPass currently clones only source vertices, because cloning Beam sink vertices appears to cause failures due to duplicate output file names (I'll look into this in a later PR).
Resolves NEMO-178.
---
.../edu/snu/nemo/client/ClientEndpointTest.java | 9 +-
.../java/edu/snu/nemo/common/StateMachine.java | 5 +-
.../exception/IllegalStateTransitionException.java | 2 +-
.../java/edu/snu/nemo/common/ir/IdManager.java | 4 +-
.../ClonedSchedulingProperty.java | 46 ++++
.../java/edu/snu/nemo/common/StateMachineTest.java | 3 +-
.../nemo/compiler/backend/nemo/NemoBackend.java | 4 +-
.../annotating/ClonedSchedulingPass.java | 43 ++++
.../annotating/DefaultScheduleGroupPass.java | 2 +-
.../compiler/backend/nemo/DAGConverterTest.java | 81 +------
.../snu/nemo/examples/beam/PerKeyMedianITCase.java | 2 +-
.../snu/nemo/examples/beam/WordCountITCase.java | 10 +
.../ClonedSchedulingPolicyParallelismFive.java | 51 +++++
...ntimeIdGenerator.java => RuntimeIdManager.java} | 91 +++++---
.../runtime/common/optimizer/RunTimeOptimizer.java | 4 +-
.../runtime/common/plan/PhysicalPlanGenerator.java | 4 +-
.../edu/snu/nemo/runtime/common/plan/Stage.java | 13 --
.../edu/snu/nemo/runtime/common/plan/Task.java | 13 +-
.../snu/nemo/runtime/common/state/BlockState.java | 13 +-
.../snu/nemo/runtime/common/state/TaskState.java | 1 -
runtime/common/src/main/proto/ControlMessage.proto | 2 +-
.../main/java/edu/snu/nemo/driver/NemoDriver.java | 4 +-
runtime/executor/pom.xml | 10 +
.../edu/snu/nemo/runtime/executor/Executor.java | 4 +-
.../nemo/runtime/executor/MetricManagerWorker.java | 6 +-
.../nemo/runtime/executor/TaskStateManager.java | 4 +-
.../runtime/executor/data/BlockManagerWorker.java | 32 +--
.../executor/data/stores/AbstractBlockStore.java | 4 +-
.../executor/datatransfer/DataTransferFactory.java | 9 +-
.../runtime/executor/datatransfer/InputReader.java | 31 ++-
.../executor/datatransfer/OutputWriter.java | 20 +-
.../nemo/runtime/executor/task/TaskExecutor.java | 23 +-
.../edu/snu/nemo/runtime/executor/TestUtil.java | 27 ++-
.../nemo/runtime/executor/data/BlockStoreTest.java | 36 +--
.../executor/datatransfer/DataTransferTest.java | 47 ++--
.../runtime/executor/task/TaskExecutorTest.java | 62 +++---
.../nemo/runtime/master/BlockManagerMaster.java | 176 ++++++---------
.../edu/snu/nemo/runtime/master/BlockMetadata.java | 29 ++-
.../nemo/runtime/master/MetricManagerMaster.java | 4 +-
.../snu/nemo/runtime/master/PlanStateManager.java | 248 ++++++++++++++-------
.../edu/snu/nemo/runtime/master/RuntimeMaster.java | 8 +-
.../master/resource/ExecutorRepresenter.java | 4 +-
.../master/resource/ResourceSpecification.java | 4 +-
.../runtime/master/scheduler/BatchScheduler.java | 189 +++++++---------
.../scheduler/NodeShareSchedulingConstraint.java | 4 +-
.../SkewnessAwareSchedulingConstraint.java | 4 +-
.../SourceLocationAwareSchedulingConstraint.java | 7 +-
.../runtime/master/scheduler/TaskDispatcher.java | 22 +-
.../runtime/master/BlockManagerMasterTest.java | 82 ++++---
.../nemo/runtime/master/PlanStateManagerTest.java | 18 +-
.../master/scheduler/BatchSchedulerTest.java | 5 +-
.../master/scheduler/SchedulerTestUtil.java | 2 +-
.../SkewnessAwareSchedulingConstraintTest.java | 7 +-
.../runtime/master/scheduler/TaskRetryTest.java | 21 +-
54 files changed, 827 insertions(+), 729 deletions(-)
diff --git a/client/src/test/java/edu/snu/nemo/client/ClientEndpointTest.java b/client/src/test/java/edu/snu/nemo/client/ClientEndpointTest.java
index 744a992..835313e 100644
--- a/client/src/test/java/edu/snu/nemo/client/ClientEndpointTest.java
+++ b/client/src/test/java/edu/snu/nemo/client/ClientEndpointTest.java
@@ -18,12 +18,10 @@ package edu.snu.nemo.client;
import edu.snu.nemo.runtime.common.plan.PhysicalPlan;
import edu.snu.nemo.runtime.common.state.PlanState;
import edu.snu.nemo.runtime.common.state.TaskState;
-import edu.snu.nemo.runtime.master.MetricMessageHandler;
import edu.snu.nemo.runtime.master.PlanStateManager;
import edu.snu.nemo.runtime.common.plan.TestPlanGenerator;
import org.junit.Test;
import org.junit.runner.RunWith;
-import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import java.util.List;
@@ -39,10 +37,8 @@ import static org.mockito.Mockito.when;
* Test {@link ClientEndpoint}.
*/
@RunWith(PowerMockRunner.class)
-@PrepareForTest(MetricMessageHandler.class)
public class ClientEndpointTest {
private static final int MAX_SCHEDULE_ATTEMPT = 2;
- private final MetricMessageHandler metricMessageHandler = mock(MetricMessageHandler.class);
@Test(timeout = 3000)
public void testState() throws Exception {
@@ -58,8 +54,7 @@ public class ClientEndpointTest {
// Create a PlanStateManager of a dag and create a DriverEndpoint with it.
final PhysicalPlan physicalPlan =
TestPlanGenerator.generatePhysicalPlan(TestPlanGenerator.PlanType.TwoVerticesJoined, false);
- final PlanStateManager planStateManager =
- new PlanStateManager(physicalPlan, metricMessageHandler, MAX_SCHEDULE_ATTEMPT);
+ final PlanStateManager planStateManager = new PlanStateManager(physicalPlan, MAX_SCHEDULE_ATTEMPT);
final DriverEndpoint driverEndpoint = new DriverEndpoint(planStateManager, clientEndpoint);
@@ -71,7 +66,7 @@ public class ClientEndpointTest {
// Check finish.
final List<String> tasks = physicalPlan.getStageDAG().getTopologicalSort().stream()
- .flatMap(stage -> stage.getTaskIds().stream())
+ .flatMap(stage -> planStateManager.getTaskAttemptsToSchedule(stage.getId()).stream())
.collect(Collectors.toList());
tasks.forEach(taskId -> planStateManager.onTaskStateChanged(taskId, TaskState.State.EXECUTING));
tasks.forEach(taskId -> planStateManager.onTaskStateChanged(taskId, TaskState.State.COMPLETE));
diff --git a/common/src/main/java/edu/snu/nemo/common/StateMachine.java b/common/src/main/java/edu/snu/nemo/common/StateMachine.java
index f98c622..6e10407 100644
--- a/common/src/main/java/edu/snu/nemo/common/StateMachine.java
+++ b/common/src/main/java/edu/snu/nemo/common/StateMachine.java
@@ -68,7 +68,7 @@ public final class StateMachine {
* @throws RuntimeException if the state is unknown state, or the transition
* from the current state to the specified state is illegal
*/
- public synchronized void setState(final Enum state) {
+ public synchronized void setState(final Enum state) throws IllegalStateTransitionException {
if (!stateMap.containsKey(state)) {
throw new RuntimeException("Unknown state " + state);
}
@@ -99,7 +99,8 @@ public final class StateMachine {
* @throws RuntimeException if the state is unknown state, or the transition
* from the current state to the specified state is illegal
*/
- public synchronized boolean compareAndSetState(final Enum expectedCurrentState, final Enum state) {
+ public synchronized boolean compareAndSetState(final Enum expectedCurrentState,
+ final Enum state) throws IllegalStateTransitionException {
final boolean compared = currentState.stateEnum.equals(expectedCurrentState);
if (compared) {
setState(state);
diff --git a/common/src/main/java/edu/snu/nemo/common/exception/IllegalStateTransitionException.java b/common/src/main/java/edu/snu/nemo/common/exception/IllegalStateTransitionException.java
index f1a4b9f..b30684f 100644
--- a/common/src/main/java/edu/snu/nemo/common/exception/IllegalStateTransitionException.java
+++ b/common/src/main/java/edu/snu/nemo/common/exception/IllegalStateTransitionException.java
@@ -19,7 +19,7 @@ package edu.snu.nemo.common.exception;
* IllegalStateTransitionException.
* Thrown when the execution state transition is illegal.
*/
-public final class IllegalStateTransitionException extends RuntimeException {
+public final class IllegalStateTransitionException extends Exception {
/**
* IllegalStateTransitionException.
* @param cause cause
diff --git a/common/src/main/java/edu/snu/nemo/common/ir/IdManager.java b/common/src/main/java/edu/snu/nemo/common/ir/IdManager.java
index 934d048..3a183ae 100644
--- a/common/src/main/java/edu/snu/nemo/common/ir/IdManager.java
+++ b/common/src/main/java/edu/snu/nemo/common/ir/IdManager.java
@@ -35,13 +35,13 @@ public final class IdManager {
* @return a new operator ID.
*/
public static String newVertexId() {
- return "vertex" + (isDriver ? "-d" : "-") + vertexId.getAndIncrement();
+ return "vertex" + (isDriver ? "(d)" : "") + vertexId.getAndIncrement();
}
/**
* @return a new edge ID.
*/
public static String newEdgeId() {
- return "edge" + (isDriver ? "-d" : "-") + edgeId.getAndIncrement();
+ return "edge" + (isDriver ? "(d)" : "") + edgeId.getAndIncrement();
}
/**
diff --git a/common/src/main/java/edu/snu/nemo/common/ir/vertex/executionproperty/ClonedSchedulingProperty.java b/common/src/main/java/edu/snu/nemo/common/ir/vertex/executionproperty/ClonedSchedulingProperty.java
new file mode 100644
index 0000000..cd0f312
--- /dev/null
+++ b/common/src/main/java/edu/snu/nemo/common/ir/vertex/executionproperty/ClonedSchedulingProperty.java
@@ -0,0 +1,46 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.common.ir.vertex.executionproperty;
+
+import edu.snu.nemo.common.ir.executionproperty.VertexExecutionProperty;
+
+/**
+ * Specifies cloned execution of a vertex.
+ *
+ * A major limitation of the current implementation:
+ * *ALL* of the clones are always scheduled immediately.
+ */
+public final class ClonedSchedulingProperty extends VertexExecutionProperty<Integer> {
+ /**
+ * Constructor.
+ * @param value value of the execution property.
+ */
+ private ClonedSchedulingProperty(final Integer value) {
+ super(value);
+ }
+
+ /**
+ * Static method exposing the constructor.
+ * @param value value of the new execution property.
+ * @return the newly created execution property.
+ */
+ public static ClonedSchedulingProperty of(final Integer value) {
+ if (value <= 0) {
+ throw new IllegalStateException(String.valueOf(value));
+ }
+ return new ClonedSchedulingProperty(value);
+ }
+}
diff --git a/common/src/test/java/edu/snu/nemo/common/StateMachineTest.java b/common/src/test/java/edu/snu/nemo/common/StateMachineTest.java
index 691efbd..7f49af5 100644
--- a/common/src/test/java/edu/snu/nemo/common/StateMachineTest.java
+++ b/common/src/test/java/edu/snu/nemo/common/StateMachineTest.java
@@ -16,6 +16,7 @@
package edu.snu.nemo.common;
import edu.snu.nemo.common.StateMachine;
+import edu.snu.nemo.common.exception.IllegalStateTransitionException;
import org.junit.Before;
import org.junit.Test;
@@ -34,7 +35,7 @@ public final class StateMachineTest {
}
@Test
- public void testSimpleStateTransitions() {
+ public void testSimpleStateTransitions() throws IllegalStateTransitionException {
stateMachineBuilder.addState(CookingState.SHOPPING, "Shopping for ingredients");
stateMachineBuilder.addState(CookingState.PREPARING, "Washing vegetables, chopping meat...");
stateMachineBuilder.addState(CookingState.SEASONING, "Adding salt and pepper");
diff --git a/compiler/backend/src/main/java/edu/snu/nemo/compiler/backend/nemo/NemoBackend.java b/compiler/backend/src/main/java/edu/snu/nemo/compiler/backend/nemo/NemoBackend.java
index da7adb6..4789a5c 100644
--- a/compiler/backend/src/main/java/edu/snu/nemo/compiler/backend/nemo/NemoBackend.java
+++ b/compiler/backend/src/main/java/edu/snu/nemo/compiler/backend/nemo/NemoBackend.java
@@ -19,7 +19,7 @@ import edu.snu.nemo.common.dag.DAG;
import edu.snu.nemo.compiler.backend.Backend;
import edu.snu.nemo.common.ir.edge.IREdge;
import edu.snu.nemo.common.ir.vertex.IRVertex;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.plan.PhysicalPlan;
import edu.snu.nemo.runtime.common.plan.PhysicalPlanGenerator;
import edu.snu.nemo.runtime.common.plan.Stage;
@@ -61,6 +61,6 @@ public final class NemoBackend implements Backend<PhysicalPlan> {
public PhysicalPlan compile(final DAG<IRVertex, IREdge> irDAG,
final PhysicalPlanGenerator physicalPlanGenerator) {
final DAG<Stage, StageEdge> stageDAG = physicalPlanGenerator.apply(irDAG);
- return new PhysicalPlan(RuntimeIdGenerator.generatePhysicalPlanId(), irDAG, stageDAG);
+ return new PhysicalPlan(RuntimeIdManager.generatePhysicalPlanId(), irDAG, stageDAG);
}
}
diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/ClonedSchedulingPass.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/ClonedSchedulingPass.java
new file mode 100644
index 0000000..f808002
--- /dev/null
+++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/ClonedSchedulingPass.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.compiler.optimizer.pass.compiletime.annotating;
+
+import edu.snu.nemo.common.dag.DAG;
+import edu.snu.nemo.common.ir.edge.IREdge;
+import edu.snu.nemo.common.ir.vertex.IRVertex;
+import edu.snu.nemo.common.ir.vertex.executionproperty.ClonedSchedulingProperty;
+
+import java.util.Collections;
+
+/**
+ * Set the ClonedScheduling property of source vertices.
+ */
+public final class ClonedSchedulingPass extends AnnotatingPass {
+ /**
+ * Default constructor.
+ */
+ public ClonedSchedulingPass() {
+ super(ClonedSchedulingProperty.class, Collections.emptySet());
+ }
+
+ @Override
+ public DAG<IRVertex, IREdge> apply(final DAG<IRVertex, IREdge> dag) {
+ dag.getVertices().stream()
+ .filter(vertex -> dag.getIncomingEdgesOf(vertex.getId()).isEmpty())
+ .forEach(vertex -> vertex.setProperty(ClonedSchedulingProperty.of(2)));
+ return dag;
+ }
+}
diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPass.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPass.java
index 1aa4dcf..4dbf455 100644
--- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPass.java
+++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPass.java
@@ -254,7 +254,7 @@ public final class DefaultScheduleGroupPass extends AnnotatingPass {
* Constructor.
*/
ScheduleGroup() {
- super(String.format("ScheduleGroup-%d", nextScheduleGroupId++));
+ super(String.format("ScheduleGroup%d", nextScheduleGroupId++));
}
}
diff --git a/compiler/test/src/test/java/edu/snu/nemo/compiler/backend/nemo/DAGConverterTest.java b/compiler/test/src/test/java/edu/snu/nemo/compiler/backend/nemo/DAGConverterTest.java
index 52cc006..207592a 100644
--- a/compiler/test/src/test/java/edu/snu/nemo/compiler/backend/nemo/DAGConverterTest.java
+++ b/compiler/test/src/test/java/edu/snu/nemo/compiler/backend/nemo/DAGConverterTest.java
@@ -99,8 +99,8 @@ public final class DAGConverterTest {
assertEquals(physicalDAG.getOutgoingEdgesOf(physicalStage1).size(), 1);
assertEquals(physicalDAG.getOutgoingEdgesOf(physicalStage2).size(), 0);
- assertEquals(physicalStage1.getTaskIds().size(), 3);
- assertEquals(physicalStage2.getTaskIds().size(), 2);
+ assertEquals(3, physicalStage1.getParallelism());
+ assertEquals(2, physicalStage2.getParallelism());
}
@Test
@@ -134,10 +134,6 @@ public final class DAGConverterTest {
v6.setProperty(ParallelismProperty.of(2));
v6.setProperty(ResourcePriorityProperty.of(ResourcePriorityProperty.RESERVED));
-// final IRVertex v7 = new OperatorVertex(t);
-// v7.setProperty(Parallelism.of(2));
-// v7.setProperty(ResourcePriorityProperty.of(ResourcePriorityProperty.COMPUTE));
-
final IRVertex v8 = new OperatorVertex(dt);
v8.setProperty(ParallelismProperty.of(2));
v8.setProperty(ResourcePriorityProperty.of(ResourcePriorityProperty.COMPUTE));
@@ -150,8 +146,6 @@ public final class DAGConverterTest {
irDAGBuilder.addVertex(v6);
irDAGBuilder.addVertex(v8);
-// irDAGBuilder.addVertex(v7);
-
final IREdge e1 = new IREdge(CommunicationPatternProperty.Value.OneToOne, v1, v2);
e1.setProperty(DataStoreProperty.of(DataStoreProperty.Value.MemoryStore));
e1.setProperty(DataFlowProperty.of(DataFlowProperty.Value.Pull));
@@ -175,76 +169,5 @@ public final class DAGConverterTest {
final IREdge e6 = new IREdge(CommunicationPatternProperty.Value.OneToOne, v4, v8);
e6.setProperty(DataStoreProperty.of(DataStoreProperty.Value.LocalFileStore));
e6.setProperty(DataFlowProperty.of(DataFlowProperty.Value.Pull));
-
-// final IREdge e7 = new IREdge(OneToOne, v7, v5);
-// e7.setProperty(DataStoreProperty.of(MemoryStore));
-// e7.setProperty(Attribute.Key.PullOrPush, DataFlowProperty.Value.Push));
-//
-// final IREdge e8 = new IREdge(OneToOne, v5, v8);
-// e8.setProperty(DataStoreProperty.of(MemoryStore));
-// e8.setProperty(Attribute.Key.PullOrPush, DataFlowProperty.Value.Pull));
-
- // Stage 1 = {v1, v2, v3}
- irDAGBuilder.connectVertices(e1);
- irDAGBuilder.connectVertices(e2);
-
- // Stage 2 = {v4}
- irDAGBuilder.connectVertices(e3);
-
- // Stage 3 = {v7}
- // Commented out since SimpleRuntime does not yet support multi-input.
-// physicalDAGBuilder.createNewStage();
-// physicalDAGBuilder.addVertex(v7);
-
- // Stage 4 = {v5, v8}
- irDAGBuilder.connectVertices(e4);
- irDAGBuilder.connectVertices(e6);
-
- // Commented out since SimpleRuntime does not yet support multi-input.
-// irDAGBuilder.connectVertices(e7);
-// irDAGBuilder.connectVertices(e8);
-
- // Stage 5 = {v6}
- irDAGBuilder.connectVertices(e5);
-
- final DAG<IRVertex, IREdge> irDAG = new TestPolicy().runCompileTimeOptimization(irDAGBuilder.build(),
- DAG.EMPTY_DAG_DIRECTORY);
- final DAG<Stage, StageEdge> logicalDAG = physicalPlanGenerator.stagePartitionIrDAG(irDAG);
-
- // Test Logical DAG
- final List<Stage> sortedLogicalDAG = logicalDAG.getTopologicalSort();
- final Stage stage1 = sortedLogicalDAG.get(0);
- final Stage stage2 = sortedLogicalDAG.get(1);
- final Stage stage3 = sortedLogicalDAG.get(2);
- final Stage stage4 = sortedLogicalDAG.get(3);
- final Stage stage5 = sortedLogicalDAG.get(3);
-
- // The following asserts depend on how stage partitioning is defined; test must be rewritten accordingly.
-// assertEquals(logicalDAG.getVertices().size(), 5);
-// assertEquals(logicalDAG.getIncomingEdgesOf(stage1).size(), 0);
-// assertEquals(logicalDAG.getIncomingEdgesOf(stage2).size(), 1);
-// assertEquals(logicalDAG.getIncomingEdgesOf(stage3).size(), 1);
-// assertEquals(logicalDAG.getIncomingEdgesOf(stage4).size(), 1);
-// assertEquals(logicalDAG.getOutgoingEdgesOf(stage1).size(), 2);
-// assertEquals(logicalDAG.getOutgoingEdgesOf(stage2).size(), 0);
-// assertEquals(logicalDAG.getOutgoingEdgesOf(stage3).size(), 1);
-// assertEquals(logicalDAG.getOutgoingEdgesOf(stage4).size(), 0);
-
- // Test Physical DAG
-
-// final DAG<Stage, StageEdge> physicalDAG = logicalDAG.convert(new PhysicalDAGGenerator());
-// final List<Stage> sortedPhysicalDAG = physicalDAG.getTopologicalSort();
-// final Stage physicalStage1 = sortedPhysicalDAG.get(0);
-// final Stage physicalStage2 = sortedPhysicalDAG.get(1);
-// assertEquals(physicalDAG.getVertices().size(), 2);
-// assertEquals(physicalDAG.getIncomingEdgesOf(physicalStage1).size(), 0);
-// assertEquals(physicalDAG.getIncomingEdgesOf(physicalStage2).size(), 1);
-// assertEquals(physicalDAG.getOutgoingEdgesOf(physicalStage1).size(), 1);
-// assertEquals(physicalDAG.getOutgoingEdgesOf(physicalStage2).size(), 0);
-//
-// final List<Task> taskList1 = physicalStage1.getTaskList();
-// final List<Task> taskList2 = physicalStage2.getTaskList();
-// assertEquals(taskList1.size(), 3);
-// assertEquals(taskList2.size(), 2);
}
}
diff --git a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/PerKeyMedianITCase.java b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/PerKeyMedianITCase.java
index adf2811..497939b 100644
--- a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/PerKeyMedianITCase.java
+++ b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/PerKeyMedianITCase.java
@@ -32,7 +32,7 @@ import org.powermock.modules.junit4.PowerMockRunner;
@RunWith(PowerMockRunner.class)
@PrepareForTest(JobLauncher.class)
public final class PerKeyMedianITCase {
- private static final int TIMEOUT = 120000;
+ private static final int TIMEOUT = 60 * 1000;
private static ArgBuilder builder;
private static final String fileBasePath = System.getProperty("user.dir") + "/../resources/";
diff --git a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/WordCountITCase.java b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/WordCountITCase.java
index 7d291a9..1211370 100644
--- a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/WordCountITCase.java
+++ b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/WordCountITCase.java
@@ -105,4 +105,14 @@ public final class WordCountITCase {
.addOptimizationPolicy(TransientResourcePolicyParallelismFive.class.getCanonicalName())
.build());
}
+
+ @Test (timeout = TIMEOUT)
+ public void testClonedScheduling() throws Exception {
+ JobLauncher.main(builder
+ .addResourceJson(executorResourceFileName)
+ .addJobId(WordCountITCase.class.getSimpleName() + "_clonedscheduling")
+ .addMaxTaskAttempt(Integer.MAX_VALUE)
+ .addOptimizationPolicy(ClonedSchedulingPolicyParallelismFive.class.getCanonicalName())
+ .build());
+ }
}
diff --git a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/ClonedSchedulingPolicyParallelismFive.java b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/ClonedSchedulingPolicyParallelismFive.java
new file mode 100644
index 0000000..937e951
--- /dev/null
+++ b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/ClonedSchedulingPolicyParallelismFive.java
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.examples.beam.policy;
+
+import edu.snu.nemo.common.dag.DAG;
+import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper;
+import edu.snu.nemo.common.ir.edge.IREdge;
+import edu.snu.nemo.common.ir.vertex.IRVertex;
+import edu.snu.nemo.compiler.optimizer.pass.compiletime.CompileTimePass;
+import edu.snu.nemo.compiler.optimizer.pass.compiletime.annotating.ClonedSchedulingPass;
+import edu.snu.nemo.compiler.optimizer.policy.DefaultPolicy;
+import edu.snu.nemo.compiler.optimizer.policy.Policy;
+import edu.snu.nemo.compiler.optimizer.policy.PolicyImpl;
+import org.apache.reef.tang.Injector;
+import java.util.List;
+
+/**
+ * A default policy with cloning for tests.
+ */
+public final class ClonedSchedulingPolicyParallelismFive implements Policy {
+ private final Policy policy;
+ public ClonedSchedulingPolicyParallelismFive() {
+ final List<CompileTimePass> overwritingPasses = DefaultPolicy.BUILDER.getCompileTimePasses();
+ overwritingPasses.add(new ClonedSchedulingPass()); // CLONING!
+ this.policy = new PolicyImpl(
+ PolicyTestUtil.overwriteParallelism(5, overwritingPasses),
+ DefaultPolicy.BUILDER.getRuntimePasses());
+ }
+ @Override
+ public DAG<IRVertex, IREdge> runCompileTimeOptimization(final DAG<IRVertex, IREdge> dag, final String dagDirectory)
+ throws Exception {
+ return this.policy.runCompileTimeOptimization(dag, dagDirectory);
+ }
+ @Override
+ public void registerRunTimeOptimizations(final Injector injector, final PubSubEventHandlerWrapper pubSubWrapper) {
+ this.policy.registerRunTimeOptimizations(injector, pubSubWrapper);
+ }
+}
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/RuntimeIdGenerator.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/RuntimeIdManager.java
similarity index 58%
rename from runtime/common/src/main/java/edu/snu/nemo/runtime/common/RuntimeIdGenerator.java
rename to runtime/common/src/main/java/edu/snu/nemo/runtime/common/RuntimeIdManager.java
index 049ccbe..15d34c1 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/RuntimeIdGenerator.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/RuntimeIdManager.java
@@ -21,19 +21,17 @@ import java.util.concurrent.atomic.AtomicLong;
/**
* ID Generator.
*/
-public final class RuntimeIdGenerator {
+public final class RuntimeIdManager {
private static AtomicInteger physicalPlanIdGenerator = new AtomicInteger(0);
private static AtomicInteger executorIdGenerator = new AtomicInteger(0);
private static AtomicLong messageIdGenerator = new AtomicLong(1L);
private static AtomicLong resourceSpecIdGenerator = new AtomicLong(0);
- private static final String BLOCK_PREFIX = "Block-";
- private static final String BLOCK_ID_SPLITTER = "_";
- private static final String TASK_INFIX = "-Task-";
+ private static final String SPLITTER = "-";
/**
* Private constructor which will not be used.
*/
- private RuntimeIdGenerator() {
+ private RuntimeIdManager() {
}
@@ -46,7 +44,7 @@ public final class RuntimeIdGenerator {
* TODO #100: Refactor string-based RuntimeIdGenerator for IR-based DynOpt
*/
public static String generatePhysicalPlanId() {
- return "Plan-" + physicalPlanIdGenerator.get();
+ return "Plan" + physicalPlanIdGenerator.getAndIncrement();
}
/**
@@ -56,7 +54,7 @@ public final class RuntimeIdGenerator {
* @return the generated ID
*/
public static String generateStageEdgeId(final String irEdgeId) {
- return "SEdge-" + irEdgeId;
+ return "SEdge" + irEdgeId;
}
/**
@@ -65,18 +63,22 @@ public final class RuntimeIdGenerator {
* @return the generated ID
*/
public static String generateStageId(final Integer stageId) {
- return "Stage-" + stageId;
+ return "Stage" + stageId;
}
/**
* Generates the ID for a task.
*
- * @param index the index of this task.
* @param stageId the ID of the stage.
+ * @param index the index of this task.
+ * @param attempt the attempt of this task.
* @return the generated ID
*/
- public static String generateTaskId(final int index, final String stageId) {
- return stageId + TASK_INFIX + index;
+ public static String generateTaskId(final String stageId, final int index, final int attempt) {
+ if (index < 0 || attempt < 0) {
+ throw new IllegalStateException(index + ", " + attempt);
+ }
+ return stageId + SPLITTER + index + SPLITTER + attempt;
}
/**
@@ -85,19 +87,37 @@ public final class RuntimeIdGenerator {
* @return the generated ID
*/
public static String generateExecutorId() {
- return "Executor-" + executorIdGenerator.getAndIncrement();
+ return "Executor" + executorIdGenerator.getAndIncrement();
}
/**
- * Generates the ID for a block, whose data is the output of a task.
+ * Generates the ID for a block, whose data is the output of a task attempt.
*
* @param runtimeEdgeId of the block
- * @param producerTaskIndex of the block
+ * @param producerTaskId of the block
* @return the generated ID
*/
public static String generateBlockId(final String runtimeEdgeId,
- final int producerTaskIndex) {
- return BLOCK_PREFIX + runtimeEdgeId + BLOCK_ID_SPLITTER + producerTaskIndex;
+ final String producerTaskId) {
+ return runtimeEdgeId + SPLITTER + getIndexFromTaskId(producerTaskId)
+ + SPLITTER + getAttemptFromTaskId(producerTaskId);
+ }
+
+ /**
+ * The block ID wildcard indicates to use 'ANY' of the available blocks produced by different task attempts.
+ * (Notice that a task clone or a task retry leads to a new task attempt)
+ *
+ * Wildcard block id looks like SEdge4-1-* (task index = 1), where the '*' matches with any task attempts.
+ * For this example, the ids of the producer task attempts will look like [Stage1-1-0, Stage1-1-1, Stage1-1-2, ...],
+ * with the (1) task stage id corresponding to the outgoing edge, (2) task index = 1, and (3) all task attempts.
+ *
+ * @param runtimeEdgeId of the block
+ * @param producerTaskIndex of the block
+ * @return the generated WILDCARD ID
+ */
+ public static String generateBlockIdWildcard(final String runtimeEdgeId,
+ final int producerTaskIndex) {
+ return runtimeEdgeId + SPLITTER + producerTaskIndex + SPLITTER + "*";
}
/**
@@ -115,7 +135,7 @@ public final class RuntimeIdGenerator {
* @return the generated ID
*/
public static String generateResourceSpecId() {
- return "ResourceSpec-" + resourceSpecIdGenerator.getAndIncrement();
+ return "ResourceSpec" + resourceSpecIdGenerator.getAndIncrement();
}
//////////////////////////////////////////////////////////////// Parse IDs
@@ -127,7 +147,7 @@ public final class RuntimeIdGenerator {
* @return the runtime edge ID.
*/
public static String getRuntimeEdgeIdFromBlockId(final String blockId) {
- return parseBlockId(blockId)[0];
+ return split(blockId)[0];
}
/**
@@ -136,20 +156,18 @@ public final class RuntimeIdGenerator {
* @param blockId the block ID to extract.
* @return the task index.
*/
- public static String getTaskIndexFromBlockId(final String blockId) {
- return parseBlockId(blockId)[1];
+ public static int getTaskIndexFromBlockId(final String blockId) {
+ return Integer.valueOf(split(blockId)[1]);
}
/**
- * Parses a block id.
- * The result array will contain runtime edge id and task index in order.
+ * Extracts wild card from a block ID.
*
- * @param blockId to parse.
- * @return the array of parsed information.
+ * @param blockId the block ID to extract.
+ * @return the wild card.
*/
- private static String[] parseBlockId(final String blockId) {
- final String woPrefix = blockId.split(BLOCK_PREFIX)[1];
- return woPrefix.split(BLOCK_ID_SPLITTER);
+ public static String getWildCardFromBlockId(final String blockId) {
+ return generateBlockIdWildcard(getRuntimeEdgeIdFromBlockId(blockId), getTaskIndexFromBlockId(blockId));
}
/**
@@ -159,7 +177,7 @@ public final class RuntimeIdGenerator {
* @return the stage ID.
*/
public static String getStageIdFromTaskId(final String taskId) {
- return parseTaskId(taskId)[0];
+ return split(taskId)[0];
}
/**
@@ -169,17 +187,20 @@ public final class RuntimeIdGenerator {
* @return the index.
*/
public static int getIndexFromTaskId(final String taskId) {
- return Integer.valueOf(parseTaskId(taskId)[1]);
+ return Integer.valueOf(split(taskId)[1]);
}
/**
- * Parses a task id.
- * The result array will contain the stage id and the index of the task in order.
+ * Extracts the attempt from a task ID.
*
- * @param taskId to parse.
- * @return the array of parsed information.
+ * @param taskId the task ID to extract.
+ * @return the attempt.
*/
- private static String[] parseTaskId(final String taskId) {
- return taskId.split(TASK_INFIX);
+ public static int getAttemptFromTaskId(final String taskId) {
+ return Integer.valueOf(split(taskId)[2]);
+ }
+
+ private static String[] split(final String id) {
+ return id.split(SPLITTER);
}
}
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/RunTimeOptimizer.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/RunTimeOptimizer.java
index 30ad958..d9ca024 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/RunTimeOptimizer.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/RunTimeOptimizer.java
@@ -19,7 +19,7 @@ import edu.snu.nemo.common.Pair;
import edu.snu.nemo.common.dag.DAG;
import edu.snu.nemo.common.ir.edge.IREdge;
import edu.snu.nemo.common.ir.vertex.IRVertex;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.optimizer.pass.runtime.DataSkewRuntimePass;
import edu.snu.nemo.runtime.common.plan.PhysicalPlan;
import edu.snu.nemo.runtime.common.plan.PhysicalPlanGenerator;
@@ -60,7 +60,7 @@ public final class RunTimeOptimizer {
.apply(originalPlan.getIrDAG(), Pair.of(targetEdge, (Map<Integer, Long>) dynOptData));
final DAG<Stage, StageEdge> stageDAG = physicalPlanGenerator.apply(newIrDAG);
final PhysicalPlan physicalPlan =
- new PhysicalPlan(RuntimeIdGenerator.generatePhysicalPlanId(), newIrDAG, stageDAG);
+ new PhysicalPlan(RuntimeIdManager.generatePhysicalPlanId(), newIrDAG, stageDAG);
return physicalPlan;
} catch (final InjectionException e) {
throw new RuntimeException(e);
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
index 16f4bc4..57c89a1 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
@@ -31,7 +31,7 @@ import edu.snu.nemo.common.dag.DAGBuilder;
import edu.snu.nemo.common.ir.edge.IREdge;
import edu.snu.nemo.common.exception.IllegalVertexOperationException;
import edu.snu.nemo.common.exception.PhysicalPlanGenerationException;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.reef.tang.annotations.Parameter;
import org.slf4j.Logger;
@@ -142,7 +142,7 @@ public final class PhysicalPlanGenerator implements Function<DAG<IRVertex, IREdg
for (final int stageId : vertexSetForEachStage.keySet()) {
final Set<IRVertex> stageVertices = vertexSetForEachStage.get(stageId);
- final String stageIdentifier = RuntimeIdGenerator.generateStageId(stageId);
+ final String stageIdentifier = RuntimeIdManager.generateStageId(stageId);
final ExecutionPropertyMap<VertexExecutionProperty> stageProperties = new ExecutionPropertyMap<>(stageIdentifier);
stagePartitioner.getStageProperties(stageVertices.iterator().next()).forEach(stageProperties::put);
final int stageParallelism = stageProperties.get(ParallelismProperty.class)
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Stage.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Stage.java
index c2abcc4..9882f88 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Stage.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Stage.java
@@ -23,11 +23,9 @@ import edu.snu.nemo.common.ir.executionproperty.VertexExecutionProperty;
import edu.snu.nemo.common.ir.vertex.IRVertex;
import edu.snu.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
import edu.snu.nemo.common.ir.vertex.executionproperty.ScheduleGroupProperty;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
import org.apache.commons.lang3.SerializationUtils;
import java.io.Serializable;
-import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -75,17 +73,6 @@ public final class Stage extends Vertex {
}
/**
- * @return the list of the task IDs in this stage.
- */
- public List<String> getTaskIds() {
- final List<String> taskIds = new ArrayList<>();
- for (int taskIdx = 0; taskIdx < getParallelism(); taskIdx++) {
- taskIds.add(RuntimeIdGenerator.generateTaskId(taskIdx, getId()));
- }
- return taskIds;
- }
-
- /**
* @return the parallelism
*/
public int getParallelism() {
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Task.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Task.java
index 3e95830..58fd314 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Task.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Task.java
@@ -18,19 +18,19 @@ package edu.snu.nemo.runtime.common.plan;
import edu.snu.nemo.common.ir.Readable;
import edu.snu.nemo.common.ir.executionproperty.ExecutionPropertyMap;
import edu.snu.nemo.common.ir.executionproperty.VertexExecutionProperty;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import java.io.Serializable;
import java.util.*;
/**
- * A Task is a self-contained executable that can be executed on a machine.
+ * A Task (attempt) is a self-contained executable that can be executed on a machine.
*/
public final class Task implements Serializable {
private final String planId;
private final String taskId;
private final List<StageEdge> taskIncomingEdges;
private final List<StageEdge> taskOutgoingEdges;
- private final int attemptIdx;
private final ExecutionPropertyMap<VertexExecutionProperty> executionProperties;
private final byte[] serializedIRDag;
private final Map<String, Readable> irVertexIdToReadable;
@@ -39,8 +39,7 @@ public final class Task implements Serializable {
* Constructor.
*
* @param planId the id of the physical plan.
- * @param taskId the ID of the task.
- * @param attemptIdx the attempt index.
+ * @param taskId the ID of this task attempt.
* @param executionProperties {@link VertexExecutionProperty} map for the corresponding stage
* @param serializedIRDag the serialized DAG of the task.
* @param taskIncomingEdges the incoming edges of the task.
@@ -49,7 +48,6 @@ public final class Task implements Serializable {
*/
public Task(final String planId,
final String taskId,
- final int attemptIdx,
final ExecutionPropertyMap<VertexExecutionProperty> executionProperties,
final byte[] serializedIRDag,
final List<StageEdge> taskIncomingEdges,
@@ -57,7 +55,6 @@ public final class Task implements Serializable {
final Map<String, Readable> irVertexIdToReadable) {
this.planId = planId;
this.taskId = taskId;
- this.attemptIdx = attemptIdx;
this.executionProperties = executionProperties;
this.serializedIRDag = serializedIRDag;
this.taskIncomingEdges = taskIncomingEdges;
@@ -104,7 +101,7 @@ public final class Task implements Serializable {
* @return the attempt index.
*/
public int getAttemptIdx() {
- return attemptIdx;
+ return RuntimeIdManager.getAttemptFromTaskId(taskId);
}
/**
@@ -141,7 +138,7 @@ public final class Task implements Serializable {
sb.append(" / taskId: ");
sb.append(taskId);
sb.append(" / attempt: ");
- sb.append(attemptIdx);
+ sb.append(getAttemptIdx());
sb.append(" / incoming: ");
sb.append(taskIncomingEdges);
sb.append(" / outgoing: ");
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/BlockState.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/BlockState.java
index 65ada21..b2329ee 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/BlockState.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/BlockState.java
@@ -35,18 +35,21 @@ public final class BlockState {
stateMachineBuilder.addState(State.IN_PROGRESS, "The block is in the progress of being created.");
stateMachineBuilder.addState(State.AVAILABLE, "The block is available.");
- // Add transitions
- stateMachineBuilder.addTransition(State.NOT_AVAILABLE, State.IN_PROGRESS,
- "The task that produces the block is scheduled.");
+ // From IN_PROGRESS
stateMachineBuilder.addTransition(State.IN_PROGRESS, State.AVAILABLE, "The block is successfully created");
-
stateMachineBuilder.addTransition(State.IN_PROGRESS, State.NOT_AVAILABLE,
"The block is lost before being created");
+
+ // From AVAILABLE
stateMachineBuilder.addTransition(State.AVAILABLE, State.NOT_AVAILABLE, "The block is lost");
+
+ // From NOT_AVAILABLE
+ stateMachineBuilder.addTransition(State.NOT_AVAILABLE, State.IN_PROGRESS,
+ "The task that produces the block is scheduled.");
stateMachineBuilder.addTransition(State.NOT_AVAILABLE, State.NOT_AVAILABLE,
"A block can be reported lost from multiple sources");
- stateMachineBuilder.setInitialState(State.NOT_AVAILABLE);
+ stateMachineBuilder.setInitialState(State.IN_PROGRESS);
return stateMachineBuilder.build();
}
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/TaskState.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/TaskState.java
index be83bf2..c40e89d 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/TaskState.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/TaskState.java
@@ -56,7 +56,6 @@ public final class TaskState {
stateMachineBuilder.addTransition(State.COMPLETE, State.SHOULD_RETRY, "Completed before, but should be retried");
// From SHOULD_RETRY
- stateMachineBuilder.addTransition(State.SHOULD_RETRY, State.READY, "Ready to be retried");
stateMachineBuilder.addTransition(State.SHOULD_RETRY, State.SHOULD_RETRY,
"SHOULD_RETRY can be caused by multiple reasons");
diff --git a/runtime/common/src/main/proto/ControlMessage.proto b/runtime/common/src/main/proto/ControlMessage.proto
index 911990e..9e29e43 100644
--- a/runtime/common/src/main/proto/ControlMessage.proto
+++ b/runtime/common/src/main/proto/ControlMessage.proto
@@ -126,7 +126,7 @@ message PartitionSizeEntry {
message RequestBlockLocationMsg {
required string executorId = 1;
- required string blockId = 2;
+ required string blockIdWildcard = 2;
}
message ExecutorFailedMsg {
diff --git a/runtime/driver/src/main/java/edu/snu/nemo/driver/NemoDriver.java b/runtime/driver/src/main/java/edu/snu/nemo/driver/NemoDriver.java
index 14bffb5..9677cac 100644
--- a/runtime/driver/src/main/java/edu/snu/nemo/driver/NemoDriver.java
+++ b/runtime/driver/src/main/java/edu/snu/nemo/driver/NemoDriver.java
@@ -18,7 +18,7 @@ package edu.snu.nemo.driver;
import edu.snu.nemo.common.ir.IdManager;
import edu.snu.nemo.compiler.optimizer.pass.compiletime.annotating.ResourceSitePass;
import edu.snu.nemo.conf.JobConf;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.comm.ControlMessage;
import edu.snu.nemo.runtime.common.message.MessageParameters;
import edu.snu.nemo.runtime.master.ClientRPC;
@@ -146,7 +146,7 @@ public final class NemoDriver {
public final class AllocatedEvaluatorHandler implements EventHandler<AllocatedEvaluator> {
@Override
public void onNext(final AllocatedEvaluator allocatedEvaluator) {
- final String executorId = RuntimeIdGenerator.generateExecutorId();
+ final String executorId = RuntimeIdManager.generateExecutorId();
runtimeMaster.onContainerAllocated(executorId, allocatedEvaluator,
getExecutorConfiguration(executorId));
}
diff --git a/runtime/executor/pom.xml b/runtime/executor/pom.xml
index 0997f8f..d2737e9 100644
--- a/runtime/executor/pom.xml
+++ b/runtime/executor/pom.xml
@@ -69,5 +69,15 @@ limitations under the License.
<version>0.1-SNAPSHOT</version>
<scope>test</scope>
</dependency>
+ <dependency>
+ <!--
+ This is needed to view the logs when running unit tests.
+ See https://dzone.com/articles/how-configure-slf4j-different for details.
+ -->
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-simple</artifactId>
+ <version>1.6.2</version>
+ <scope>test</scope>
+ </dependency>
</dependencies>
</project>
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/Executor.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/Executor.java
index 52e153e..b4bdcd7 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/Executor.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/Executor.java
@@ -25,7 +25,7 @@ import edu.snu.nemo.common.ir.vertex.IRVertex;
import edu.snu.nemo.conf.JobConf;
import edu.snu.nemo.common.exception.IllegalMessageException;
import edu.snu.nemo.common.exception.UnknownFailureCauseException;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.comm.ControlMessage;
import edu.snu.nemo.runtime.common.message.MessageContext;
import edu.snu.nemo.runtime.common.message.MessageEnvironment;
@@ -135,7 +135,7 @@ public final class Executor {
} catch (final Exception e) {
persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send(
ControlMessage.Message.newBuilder()
- .setId(RuntimeIdGenerator.generateMessageId())
+ .setId(RuntimeIdManager.generateMessageId())
.setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID)
.setType(ControlMessage.MessageType.ExecutorFailed)
.setExecutorFailedMsg(ControlMessage.ExecutorFailedMsg.newBuilder()
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/MetricManagerWorker.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/MetricManagerWorker.java
index d8f3b41..8e51cda 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/MetricManagerWorker.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/MetricManagerWorker.java
@@ -16,7 +16,7 @@
package edu.snu.nemo.runtime.executor;
import com.google.protobuf.ByteString;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.comm.ControlMessage;
import edu.snu.nemo.runtime.common.message.MessageEnvironment;
import edu.snu.nemo.common.exception.UnknownFailureCauseException;
@@ -57,7 +57,7 @@ public final class MetricManagerWorker implements MetricMessageSender {
flushMetricMessageQueueToMaster();
persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send(
ControlMessage.Message.newBuilder()
- .setId(RuntimeIdGenerator.generateMessageId())
+ .setId(RuntimeIdManager.generateMessageId())
.setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID)
.setType(ControlMessage.MessageType.MetricFlushed)
.build());
@@ -79,7 +79,7 @@ public final class MetricManagerWorker implements MetricMessageSender {
persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send(
ControlMessage.Message.newBuilder()
- .setId(RuntimeIdGenerator.generateMessageId())
+ .setId(RuntimeIdManager.generateMessageId())
.setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID)
.setType(ControlMessage.MessageType.MetricMessageReceived)
.setMetricMsg(metricMsgBuilder.build())
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/TaskStateManager.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/TaskStateManager.java
index a707b03..f82fd11 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/TaskStateManager.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/TaskStateManager.java
@@ -17,7 +17,7 @@ package edu.snu.nemo.runtime.executor;
import edu.snu.nemo.common.exception.UnknownExecutionStateException;
import edu.snu.nemo.common.exception.UnknownFailureCauseException;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.comm.ControlMessage;
import edu.snu.nemo.runtime.common.message.MessageEnvironment;
import edu.snu.nemo.runtime.common.message.PersistentConnectionToMasterMap;
@@ -126,7 +126,7 @@ public final class TaskStateManager {
// Send taskStateChangedMsg to master!
persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send(
ControlMessage.Message.newBuilder()
- .setId(RuntimeIdGenerator.generateMessageId())
+ .setId(RuntimeIdManager.generateMessageId())
.setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID)
.setType(ControlMessage.MessageType.TaskStateChanged)
.setTaskStateChangedMsg(msgBuilder.build())
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/BlockManagerWorker.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/BlockManagerWorker.java
index 4cf0b12..692d01f 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/BlockManagerWorker.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/BlockManagerWorker.java
@@ -24,10 +24,10 @@ import edu.snu.nemo.common.exception.UnsupportedExecutionPropertyException;
import edu.snu.nemo.common.ir.edge.executionproperty.DataStoreProperty;
import edu.snu.nemo.common.ir.edge.executionproperty.DataPersistenceProperty;
import edu.snu.nemo.conf.JobConf;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.comm.ControlMessage;
import edu.snu.nemo.runtime.common.comm.ControlMessage.ByteTransferContextDescriptor;
import edu.snu.nemo.common.KeyRange;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
import edu.snu.nemo.runtime.common.message.MessageEnvironment;
import edu.snu.nemo.runtime.common.message.PersistentConnectionToMasterMap;
import edu.snu.nemo.runtime.executor.bytetransfer.ByteInputContext;
@@ -139,20 +139,20 @@ public final class BlockManagerWorker {
* or to the lower data plane.
* This can be invoked multiple times per blockId (maybe due to failures).
*
- * @param blockId of the block.
- * @param runtimeEdgeId id of the runtime edge that corresponds to the block.
- * @param blockStore for the data storage.
- * @param keyRange the key range descriptor
+ * @param blockIdWildcard of the block.
+ * @param runtimeEdgeId id of the runtime edge that corresponds to the block.
+ * @param blockStore for the data storage.
+ * @param keyRange the key range descriptor
* @return the {@link CompletableFuture} of the block.
*/
public CompletableFuture<DataUtil.IteratorWithNumBytes> readBlock(
- final String blockId,
+ final String blockIdWildcard,
final String runtimeEdgeId,
final DataStoreProperty.Value blockStore,
final KeyRange keyRange) {
// Let's see if a remote worker has it
final CompletableFuture<ControlMessage.Message> blockLocationFuture =
- pendingBlockLocationRequest.computeIfAbsent(blockId, blockIdToRequest -> {
+ pendingBlockLocationRequest.computeIfAbsent(blockIdWildcard, blockIdToRequest -> {
// Ask Master for the location.
// (IMPORTANT): This 'request' effectively blocks the TaskExecutor thread if the block is IN_PROGRESS.
// We use this property to make the receiver task of a 'push' edge to wait in an Executor for its input data
@@ -160,19 +160,19 @@ public final class BlockManagerWorker {
final CompletableFuture<ControlMessage.Message> responseFromMasterFuture = persistentConnectionToMasterMap
.getMessageSender(MessageEnvironment.BLOCK_MANAGER_MASTER_MESSAGE_LISTENER_ID).request(
ControlMessage.Message.newBuilder()
- .setId(RuntimeIdGenerator.generateMessageId())
+ .setId(RuntimeIdManager.generateMessageId())
.setListenerId(MessageEnvironment.BLOCK_MANAGER_MASTER_MESSAGE_LISTENER_ID)
.setType(ControlMessage.MessageType.RequestBlockLocation)
.setRequestBlockLocationMsg(
ControlMessage.RequestBlockLocationMsg.newBuilder()
.setExecutorId(executorId)
- .setBlockId(blockId)
+ .setBlockIdWildcard(blockIdWildcard)
.build())
.build());
return responseFromMasterFuture;
});
blockLocationFuture.whenComplete((message, throwable) -> {
- pendingBlockLocationRequest.remove(blockId);
+ pendingBlockLocationRequest.remove(blockIdWildcard);
});
// Using thenCompose so that fetching block data starts after getting response from master.
@@ -185,10 +185,12 @@ public final class BlockManagerWorker {
responseFromMaster.getBlockLocationInfoMsg();
if (!blockLocationInfoMsg.hasOwnerExecutorId()) {
throw new BlockFetchException(new Throwable(
- "Block " + blockId + " location unknown: "
+ "Block " + blockIdWildcard + " location unknown: "
+ "The block state is " + blockLocationInfoMsg.getState()));
}
+
// This is the executor id that we wanted to know
+ final String blockId = blockLocationInfoMsg.getBlockId();
final String targetExecutorId = blockLocationInfoMsg.getOwnerExecutorId();
if (targetExecutorId.equals(executorId) || targetExecutorId.equals(REMOTE_FILE_STORE)) {
// Block resides in the evaluator
@@ -232,7 +234,6 @@ public final class BlockManagerWorker {
* @param blockStore the store to save the block.
* @param reportPartitionSizes whether report the size of partitions to master or not.
* @param partitionSizeMap the map of partition keys and sizes to report.
- * @param srcIRVertexId the IR vertex ID of the source task.
* @param expectedReadTotal the expected number of read for this block.
* @param persistence how to handle the used block.
*/
@@ -240,7 +241,6 @@ public final class BlockManagerWorker {
final DataStoreProperty.Value blockStore,
final boolean reportPartitionSizes,
final Map<Integer, Long> partitionSizeMap,
- final String srcIRVertexId,
final int expectedReadTotal,
final DataPersistenceProperty.Value persistence) {
final String blockId = block.getId();
@@ -273,7 +273,7 @@ public final class BlockManagerWorker {
persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.BLOCK_MANAGER_MASTER_MESSAGE_LISTENER_ID)
.send(ControlMessage.Message.newBuilder()
- .setId(RuntimeIdGenerator.generateMessageId())
+ .setId(RuntimeIdManager.generateMessageId())
.setListenerId(MessageEnvironment.BLOCK_MANAGER_MASTER_MESSAGE_LISTENER_ID)
.setType(ControlMessage.MessageType.BlockStateChanged)
.setBlockStateChangedMsg(blockStateChangedMsgBuilder.build())
@@ -292,7 +292,7 @@ public final class BlockManagerWorker {
// TODO #4: Refactor metric aggregation for (general) run-rime optimization.
persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID)
.send(ControlMessage.Message.newBuilder()
- .setId(RuntimeIdGenerator.generateMessageId())
+ .setId(RuntimeIdManager.generateMessageId())
.setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID)
.setType(ControlMessage.MessageType.DataSizeMetric)
.setDataSizeMetricMsg(ControlMessage.DataSizeMetricMsg.newBuilder()
@@ -329,7 +329,7 @@ public final class BlockManagerWorker {
persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.BLOCK_MANAGER_MASTER_MESSAGE_LISTENER_ID)
.send(ControlMessage.Message.newBuilder()
- .setId(RuntimeIdGenerator.generateMessageId())
+ .setId(RuntimeIdManager.generateMessageId())
.setListenerId(MessageEnvironment.BLOCK_MANAGER_MASTER_MESSAGE_LISTENER_ID)
.setType(ControlMessage.MessageType.BlockStateChanged)
.setBlockStateChangedMsg(blockStateChangedMsgBuilder)
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/stores/AbstractBlockStore.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/stores/AbstractBlockStore.java
index 761dbed..32d4615 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/stores/AbstractBlockStore.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/stores/AbstractBlockStore.java
@@ -15,7 +15,7 @@
*/
package edu.snu.nemo.runtime.executor.data.stores;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.executor.data.SerializerManager;
import edu.snu.nemo.runtime.executor.data.streamchainer.Serializer;
@@ -41,7 +41,7 @@ public abstract class AbstractBlockStore implements BlockStore {
* @return the coder.
*/
protected final Serializer getSerializerFromWorker(final String blockId) {
- final String runtimeEdgeId = RuntimeIdGenerator.getRuntimeEdgeIdFromBlockId(blockId);
+ final String runtimeEdgeId = RuntimeIdManager.getRuntimeEdgeIdFromBlockId(blockId);
return serializerManager.getSerializer(runtimeEdgeId);
}
}
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/DataTransferFactory.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/DataTransferFactory.java
index 36db3ec..c96df7b 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/DataTransferFactory.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/DataTransferFactory.java
@@ -41,18 +41,15 @@ public final class DataTransferFactory {
/**
* Creates an {@link OutputWriter} between two stages.
*
- * @param srcIRVertex the {@link IRVertex} that outputs the data to be written.
- * @param srcTaskIdx the index of the source task.
+ * @param srcTaskId the id of the source task.
* @param dstIRVertex the {@link IRVertex} that will take the output data as its input.
* @param runtimeEdge that connects the srcTask to the tasks belonging to dstIRVertex.
* @return the {@link OutputWriter} created.
*/
- public OutputWriter createWriter(final IRVertex srcIRVertex,
- final int srcTaskIdx,
+ public OutputWriter createWriter(final String srcTaskId,
final IRVertex dstIRVertex,
final RuntimeEdge<?> runtimeEdge) {
- return new OutputWriter(hashRangeMultiplier, srcTaskIdx,
- srcIRVertex.getId(), dstIRVertex, runtimeEdge, blockManagerWorker);
+ return new OutputWriter(hashRangeMultiplier, srcTaskId, dstIRVertex, runtimeEdge, blockManagerWorker);
}
/**
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/InputReader.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/InputReader.java
index 9870259..eadd8bf 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/InputReader.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/InputReader.java
@@ -24,7 +24,7 @@ import edu.snu.nemo.common.ir.edge.executionproperty.DuplicateEdgeGroupProperty;
import edu.snu.nemo.common.ir.edge.executionproperty.DuplicateEdgeGroupPropertyValue;
import edu.snu.nemo.common.ir.vertex.IRVertex;
import edu.snu.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.common.KeyRange;
import edu.snu.nemo.runtime.common.plan.RuntimeEdge;
import edu.snu.nemo.runtime.common.plan.StageEdge;
@@ -88,10 +88,10 @@ public final class InputReader extends DataTransfer {
}
private CompletableFuture<DataUtil.IteratorWithNumBytes> readOneToOne() {
- final String blockId = getBlockId(dstTaskIndex);
+ final String blockIdWildcard = generateWildCardBlockId(dstTaskIndex);
final Optional<DataStoreProperty.Value> dataStoreProperty
= runtimeEdge.getPropertyValue(DataStoreProperty.class);
- return blockManagerWorker.readBlock(blockId, getId(), dataStoreProperty.get(), HashRange.all());
+ return blockManagerWorker.readBlock(blockIdWildcard, getId(), dataStoreProperty.get(), HashRange.all());
}
private List<CompletableFuture<DataUtil.IteratorWithNumBytes>> readBroadcast() {
@@ -101,8 +101,8 @@ public final class InputReader extends DataTransfer {
final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> futures = new ArrayList<>();
for (int srcTaskIdx = 0; srcTaskIdx < numSrcTasks; srcTaskIdx++) {
- final String blockId = getBlockId(srcTaskIdx);
- futures.add(blockManagerWorker.readBlock(blockId, getId(), dataStoreProperty.get(), HashRange.all()));
+ final String blockIdWildcard = generateWildCardBlockId(srcTaskIdx);
+ futures.add(blockManagerWorker.readBlock(blockIdWildcard, getId(), dataStoreProperty.get(), HashRange.all()));
}
return futures;
@@ -128,32 +128,27 @@ public final class InputReader extends DataTransfer {
final int numSrcTasks = this.getSourceParallelism();
final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> futures = new ArrayList<>();
for (int srcTaskIdx = 0; srcTaskIdx < numSrcTasks; srcTaskIdx++) {
- final String blockId = getBlockId(srcTaskIdx);
+ final String blockIdWildcard = generateWildCardBlockId(srcTaskIdx);
futures.add(
- blockManagerWorker.readBlock(blockId, getId(), dataStoreProperty.get(), hashRangeToRead));
+ blockManagerWorker.readBlock(blockIdWildcard, getId(), dataStoreProperty.get(), hashRangeToRead));
}
return futures;
}
- public RuntimeEdge getRuntimeEdge() {
- return runtimeEdge;
- }
-
/**
- * Get block id.
- *
- * @param taskIdx task index of the block
- * @return the block id
+ * See {@link RuntimeIdManager#generateBlockIdWildcard(String, int)} for information on block wildcards.
+ * @param producerTaskIndex to use.
+ * @return wildcard block id that corresponds to "ANY" task attempt of the task index.
*/
- private String getBlockId(final int taskIdx) {
+ private String generateWildCardBlockId(final int producerTaskIndex) {
final Optional<DuplicateEdgeGroupPropertyValue> duplicateDataProperty =
runtimeEdge.getPropertyValue(DuplicateEdgeGroupProperty.class);
if (!duplicateDataProperty.isPresent() || duplicateDataProperty.get().getGroupSize() <= 1) {
- return RuntimeIdGenerator.generateBlockId(getId(), taskIdx);
+ return RuntimeIdManager.generateBlockIdWildcard(getId(), producerTaskIndex);
}
final String duplicateEdgeId = duplicateDataProperty.get().getRepresentativeEdgeId();
- return RuntimeIdGenerator.generateBlockId(duplicateEdgeId, taskIdx);
+ return RuntimeIdManager.generateBlockIdWildcard(duplicateEdgeId, producerTaskIndex);
}
public IRVertex getSrcIrVertex() {
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputWriter.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputWriter.java
index 71d810a..162f491 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputWriter.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputWriter.java
@@ -20,7 +20,7 @@ import edu.snu.nemo.common.exception.*;
import edu.snu.nemo.common.ir.edge.executionproperty.*;
import edu.snu.nemo.common.ir.vertex.IRVertex;
import edu.snu.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.plan.RuntimeEdge;
import edu.snu.nemo.runtime.executor.data.BlockManagerWorker;
import edu.snu.nemo.runtime.executor.data.block.Block;
@@ -32,9 +32,7 @@ import java.util.*;
* Represents the output data transfer from a task.
*/
public final class OutputWriter extends DataTransfer implements AutoCloseable {
- private final String blockId;
private final RuntimeEdge<?> runtimeEdge;
- private final String srcVertexId;
private final IRVertex dstIrVertex;
private final DataStoreProperty.Value blockStoreValue;
private final BlockManagerWorker blockManagerWorker;
@@ -47,22 +45,18 @@ public final class OutputWriter extends DataTransfer implements AutoCloseable {
* Constructor.
*
* @param hashRangeMultiplier the {@link edu.snu.nemo.conf.JobConf.HashRangeMultiplier}.
- * @param srcTaskIdx the index of the source task.
- * @param srcRuntimeVertexId the ID of the source vertex.
+ * @param srcTaskId the id of the source task.
* @param dstIrVertex the destination IR vertex.
* @param runtimeEdge the {@link RuntimeEdge}.
* @param blockManagerWorker the {@link BlockManagerWorker}.
*/
OutputWriter(final int hashRangeMultiplier,
- final int srcTaskIdx,
- final String srcRuntimeVertexId,
+ final String srcTaskId,
final IRVertex dstIrVertex,
final RuntimeEdge<?> runtimeEdge,
final BlockManagerWorker blockManagerWorker) {
super(runtimeEdge.getId());
- this.blockId = RuntimeIdGenerator.generateBlockId(getId(), srcTaskIdx);
this.runtimeEdge = runtimeEdge;
- this.srcVertexId = srcRuntimeVertexId;
this.dstIrVertex = dstIrVertex;
this.blockManagerWorker = blockManagerWorker;
this.blockStoreValue = runtimeEdge.getPropertyValue(DataStoreProperty.class).
@@ -95,8 +89,8 @@ public final class OutputWriter extends DataTransfer implements AutoCloseable {
throw new UnsupportedPartitionerException(
new Throwable("Partitioner " + partitionerPropertyValue + " is not supported."));
}
- blockToWrite = blockManagerWorker.createBlock(blockId, blockStoreValue);
-
+ blockToWrite = blockManagerWorker.createBlock(
+ RuntimeIdManager.generateBlockId(getId(), srcTaskId), blockStoreValue);
final Optional<DuplicateEdgeGroupPropertyValue> duplicateDataProperty =
runtimeEdge.getPropertyValue(DuplicateEdgeGroupProperty.class);
nonDummyBlock = !duplicateDataProperty.isPresent()
@@ -142,11 +136,11 @@ public final class OutputWriter extends DataTransfer implements AutoCloseable {
}
this.writtenBytes = blockSizeTotal;
blockManagerWorker.writeBlock(blockToWrite, blockStoreValue, isDataSizeMetricCollectionEdge,
- partitionSizeMap.get(), srcVertexId, getExpectedRead(), persistence);
+ partitionSizeMap.get(), getExpectedRead(), persistence);
} else {
this.writtenBytes = -1; // no written bytes info.
blockManagerWorker.writeBlock(blockToWrite, blockStoreValue, isDataSizeMetricCollectionEdge,
- Collections.emptyMap(), srcVertexId, getExpectedRead(), persistence);
+ Collections.emptyMap(), getExpectedRead(), persistence);
}
}
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
index 7d1901d..f2b6552 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
@@ -23,7 +23,7 @@ import edu.snu.nemo.common.ir.Readable;
import edu.snu.nemo.common.ir.edge.executionproperty.AdditionalOutputTagProperty;
import edu.snu.nemo.common.ir.vertex.*;
import edu.snu.nemo.common.ir.vertex.transform.Transform;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.comm.ControlMessage;
import edu.snu.nemo.runtime.common.message.MessageEnvironment;
import edu.snu.nemo.runtime.common.message.PersistentConnectionToMasterMap;
@@ -133,7 +133,7 @@ public final class TaskExecutor {
private Pair<List<DataFetcher>, List<VertexHarness>> prepare(final Task task,
final DAG<IRVertex, RuntimeEdge<IRVertex>> irVertexDag,
final DataTransferFactory dataTransferFactory) {
- final int taskIndex = RuntimeIdGenerator.getIndexFromTaskId(task.getTaskId());
+ final int taskIndex = RuntimeIdManager.getIndexFromTaskId(task.getTaskId());
// Traverse in a reverse-topological order to ensure that each visited vertex's children vertices exist.
final List<IRVertex> reverseTopologicallySorted = Lists.reverse(irVertexDag.getTopologicalSort());
@@ -164,10 +164,10 @@ public final class TaskExecutor {
// Handle writes
// Main output children task writes
final List<OutputWriter> mainChildrenTaskWriters = getMainChildrenTaskWriters(
- taskIndex, irVertex, task.getTaskOutgoingEdges(), dataTransferFactory, additionalOutputMap);
+ irVertex, task.getTaskOutgoingEdges(), dataTransferFactory, additionalOutputMap);
// Additional output children task writes
final Map<String, OutputWriter> additionalChildrenTaskWriters = getAdditionalChildrenTaskWriters(
- taskIndex, irVertex, task.getTaskOutgoingEdges(), dataTransferFactory, additionalOutputMap);
+ irVertex, task.getTaskOutgoingEdges(), dataTransferFactory, additionalOutputMap);
// Find all main vertices and additional vertices
final List<String> additionalOutputVertices = new ArrayList<>(additionalOutputMap.values());
final Set<String> mainChildren =
@@ -494,15 +494,13 @@ public final class TaskExecutor {
/**
* Return inter-task OutputWriters, for single output or output associated with main tag.
- * @param taskIndex current task index
* @param irVertex source irVertex
* @param outEdgesToChildrenTasks outgoing edges to child tasks
* @param dataTransferFactory dataTransferFactory
* @param taggedOutputs tag to vertex id map
* @return OutputWriters for main children tasks
*/
- private List<OutputWriter> getMainChildrenTaskWriters(final int taskIndex,
- final IRVertex irVertex,
+ private List<OutputWriter> getMainChildrenTaskWriters(final IRVertex irVertex,
final List<StageEdge> outEdgesToChildrenTasks,
final DataTransferFactory dataTransferFactory,
final Map<String, String> taggedOutputs) {
@@ -511,21 +509,19 @@ public final class TaskExecutor {
.filter(outEdge -> outEdge.getSrcIRVertex().getId().equals(irVertex.getId()))
.filter(outEdge -> !taggedOutputs.containsValue(outEdge.getDstIRVertex().getId()))
.map(outEdgeForThisVertex -> dataTransferFactory
- .createWriter(irVertex, taskIndex, outEdgeForThisVertex.getDstIRVertex(), outEdgeForThisVertex))
+ .createWriter(taskId, outEdgeForThisVertex.getDstIRVertex(), outEdgeForThisVertex))
.collect(Collectors.toList());
}
/**
* Return inter-task OutputWriters associated with additional output tags.
- * @param taskIndex current task index
* @param irVertex source irVertex
* @param outEdgesToChildrenTasks outgoing edges to child tasks
* @param dataTransferFactory dataTransferFactory
* @param taggedOutputs tag to vertex id map
* @return additional children vertex id to OutputWriters map.
*/
- private Map<String, OutputWriter> getAdditionalChildrenTaskWriters(final int taskIndex,
- final IRVertex irVertex,
+ private Map<String, OutputWriter> getAdditionalChildrenTaskWriters(final IRVertex irVertex,
final List<StageEdge> outEdgesToChildrenTasks,
final DataTransferFactory dataTransferFactory,
final Map<String, String> taggedOutputs) {
@@ -537,8 +533,7 @@ public final class TaskExecutor {
.filter(outEdge -> taggedOutputs.containsValue(outEdge.getDstIRVertex().getId()))
.forEach(outEdgeForThisVertex -> {
additionalChildrenTaskWriters.put(outEdgeForThisVertex.getDstIRVertex().getId(),
- dataTransferFactory.createWriter(irVertex, taskIndex, outEdgeForThisVertex.getDstIRVertex(),
- outEdgeForThisVertex));
+ dataTransferFactory.createWriter(taskId, outEdgeForThisVertex.getDstIRVertex(), outEdgeForThisVertex));
});
return additionalChildrenTaskWriters;
@@ -578,7 +573,7 @@ public final class TaskExecutor {
vertexHarness.getContext().getSerializedData().ifPresent(data ->
persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send(
ControlMessage.Message.newBuilder()
- .setId(RuntimeIdGenerator.generateMessageId())
+ .setId(RuntimeIdManager.generateMessageId())
.setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID)
.setType(ControlMessage.MessageType.ExecutorDataCollected)
.setDataCollected(ControlMessage.DataCollectMessage.newBuilder().setData(data).build())
diff --git a/common/src/main/java/edu/snu/nemo/common/exception/IllegalStateTransitionException.java b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/TestUtil.java
similarity index 51%
copy from common/src/main/java/edu/snu/nemo/common/exception/IllegalStateTransitionException.java
copy to runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/TestUtil.java
index f1a4b9f..28e455c 100644
--- a/common/src/main/java/edu/snu/nemo/common/exception/IllegalStateTransitionException.java
+++ b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/TestUtil.java
@@ -13,18 +13,21 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package edu.snu.nemo.common.exception;
+package edu.snu.nemo.runtime.executor;
-/**
- * IllegalStateTransitionException.
- * Thrown when the execution state transition is illegal.
- */
-public final class IllegalStateTransitionException extends RuntimeException {
- /**
- * IllegalStateTransitionException.
- * @param cause cause
- */
- public IllegalStateTransitionException(final Throwable cause) {
- super(cause);
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
+import edu.snu.nemo.runtime.common.plan.Stage;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public final class TestUtil {
+ public static List<String> generateTaskIds(final Stage stage) {
+ final List<String> result = new ArrayList<>(stage.getParallelism());
+ final int first_attempt = 0;
+ for (int taskIndex = 0; taskIndex < stage.getParallelism(); taskIndex++) {
+ result.add(RuntimeIdManager.generateTaskId(stage.getId(), taskIndex, first_attempt));
+ }
+ return result;
}
}
diff --git a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/data/BlockStoreTest.java b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/data/BlockStoreTest.java
index 61adee2..7c1cf93 100644
--- a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/data/BlockStoreTest.java
+++ b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/data/BlockStoreTest.java
@@ -19,13 +19,14 @@ import edu.snu.nemo.common.Pair;
import edu.snu.nemo.common.coder.*;
import edu.snu.nemo.common.ir.edge.executionproperty.CompressionProperty;
import edu.snu.nemo.conf.JobConf;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.common.HashRange;
import edu.snu.nemo.common.KeyRange;
import edu.snu.nemo.runtime.common.message.MessageEnvironment;
import edu.snu.nemo.runtime.common.message.local.LocalMessageDispatcher;
import edu.snu.nemo.runtime.common.message.local.LocalMessageEnvironment;
import edu.snu.nemo.runtime.common.state.BlockState;
+import edu.snu.nemo.runtime.executor.TestUtil;
import edu.snu.nemo.runtime.executor.data.block.Block;
import edu.snu.nemo.runtime.executor.data.partition.NonSerializedPartition;
import edu.snu.nemo.runtime.executor.data.streamchainer.DecompressionStreamChainer;
@@ -51,6 +52,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.*;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
@@ -99,6 +101,10 @@ public final class BlockStoreTest {
private List<KeyRange> readKeyRangeList;
private List<List<Iterable>> expectedDataInRange;
+ private String getTaskId(final int index) {
+ return RuntimeIdManager.generateTaskId("STAGE", index, 0);
+ }
+
/**
* Generates the ids and the data which will be used for the block store tests.
*/
@@ -121,14 +127,12 @@ public final class BlockStoreTest {
IntStream.range(0, NUM_READ_VERTICES).forEach(number -> readTaskIdList.add("Read_IR_vertex"));
// Generates the ids and the data of the blocks to be used.
- final String shuffleEdge = RuntimeIdGenerator.generateStageEdgeId("shuffle_edge");
+ final String shuffleEdge = RuntimeIdManager.generateStageEdgeId("shuffle_edge");
IntStream.range(0, NUM_WRITE_VERTICES).forEach(writeTaskIdx -> {
// Create a block for each writer task.
- final String blockId = RuntimeIdGenerator.generateBlockId(shuffleEdge, writeTaskIdx);
+ final String taskId = getTaskId(writeTaskIdx);
+ final String blockId = RuntimeIdManager.generateBlockId(shuffleEdge, taskId);
blockIdList.add(blockId);
- blockManagerMaster.initializeState(blockId, "Unused");
- blockManagerMaster.onBlockStateChanged(
- blockId, BlockState.State.IN_PROGRESS, null);
// Create blocks for this block.
final List<NonSerializedPartition<Integer>> partitionsForBlock = new ArrayList<>(NUM_READ_VERTICES);
@@ -141,15 +145,11 @@ public final class BlockStoreTest {
});
// Following part is for the concurrent read test.
- final String writeTaskId = "conc_write_IR_vertex";
final List<String> concReadTaskIdList = new ArrayList<>(NUM_CONC_READ_TASKS);
- final String concEdge = RuntimeIdGenerator.generateStageEdgeId("conc_read_edge");
+ final String concEdge = RuntimeIdManager.generateStageEdgeId("conc_read_edge");
// Generates the ids and the data to be used.
- concBlockId = RuntimeIdGenerator.generateBlockId(concEdge, NUM_WRITE_VERTICES + NUM_READ_VERTICES + 1);
- blockManagerMaster.initializeState(concBlockId, "unused");
- blockManagerMaster.onBlockStateChanged(
- concBlockId, BlockState.State.IN_PROGRESS, null);
+ concBlockId = RuntimeIdManager.generateBlockId(concEdge, getTaskId(NUM_WRITE_VERTICES + NUM_READ_VERTICES + 1));
IntStream.range(0, NUM_CONC_READ_TASKS).forEach(number -> concReadTaskIdList.add("conc_read_IR_vertex"));
concBlockPartition = new NonSerializedPartition(0, getRangedNumList(0, CONC_READ_DATA_SIZE), -1, -1);
@@ -165,16 +165,13 @@ public final class BlockStoreTest {
// Generates the ids of the tasks to be used.
IntStream.range(0, NUM_WRITE_HASH_TASKS).forEach(number -> writeHashTaskIdList.add("hash_write_IR_vertex"));
IntStream.range(0, NUM_READ_HASH_TASKS).forEach(number -> readHashTaskIdList.add("hash_read_IR_vertex"));
- final String hashEdge = RuntimeIdGenerator.generateStageEdgeId("hash_edge");
+ final String hashEdge = RuntimeIdManager.generateStageEdgeId("hash_edge");
// Generates the ids and the data of the blocks to be used.
IntStream.range(0, NUM_WRITE_HASH_TASKS).forEach(writeTaskIdx -> {
- final String blockId = RuntimeIdGenerator.generateBlockId(
- hashEdge, NUM_WRITE_VERTICES + NUM_READ_VERTICES + 1 + writeTaskIdx);
+ final String taskId = getTaskId(NUM_WRITE_VERTICES + NUM_READ_VERTICES + 1 + writeTaskIdx);
+ final String blockId = RuntimeIdManager.generateBlockId(hashEdge, taskId);
hashedBlockIdList.add(blockId);
- blockManagerMaster.initializeState(blockId, "Unused");
- blockManagerMaster.onBlockStateChanged(
- blockId, BlockState.State.IN_PROGRESS, null);
final List<NonSerializedPartition<Integer>> hashedBlock = new ArrayList<>(HASH_RANGE);
// Generates the data having each hash value.
IntStream.range(0, HASH_RANGE).forEach(hashValue ->
@@ -317,6 +314,7 @@ public final class BlockStoreTest {
}
block.commit();
writerSideStore.writeBlock(block);
+ blockManagerMaster.onProducerTaskScheduled(getTaskId(writeTaskIdx), Collections.singleton(blockId));
blockManagerMaster.onBlockStateChanged(blockId, BlockState.State.AVAILABLE,
"Writer side of the shuffle edge");
return true;
@@ -411,6 +409,7 @@ public final class BlockStoreTest {
data.forEach(element -> block.write(concBlockPartition.getKey(), element));
block.commit();
writerSideStore.writeBlock(block);
+ blockManagerMaster.onProducerTaskScheduled(getTaskId(0), Collections.singleton(block.getId()));
blockManagerMaster.onBlockStateChanged(
concBlockId, BlockState.State.AVAILABLE, "Writer side of the concurrent read edge");
return true;
@@ -500,6 +499,7 @@ public final class BlockStoreTest {
}
block.commit();
writerSideStore.writeBlock(block);
+ blockManagerMaster.onProducerTaskScheduled(getTaskId(writeTaskIdx), Collections.singleton(blockId));
blockManagerMaster.onBlockStateChanged(blockId, BlockState.State.AVAILABLE,
"Writer side of the shuffle in hash range edge");
return true;
diff --git a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/datatransfer/DataTransferTest.java b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/datatransfer/DataTransferTest.java
index 229f30d..8c1ccd3 100644
--- a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/datatransfer/DataTransferTest.java
+++ b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/datatransfer/DataTransferTest.java
@@ -33,7 +33,7 @@ import edu.snu.nemo.common.Pair;
import edu.snu.nemo.common.dag.DAG;
import edu.snu.nemo.common.dag.DAGBuilder;
import edu.snu.nemo.common.ir.executionproperty.ExecutionPropertyMap;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.message.MessageEnvironment;
import edu.snu.nemo.runtime.common.message.MessageParameters;
import edu.snu.nemo.runtime.common.message.PersistentConnectionToMasterMap;
@@ -43,6 +43,7 @@ import edu.snu.nemo.runtime.common.plan.RuntimeEdge;
import edu.snu.nemo.runtime.common.plan.Stage;
import edu.snu.nemo.runtime.common.plan.StageEdge;
import edu.snu.nemo.runtime.executor.Executor;
+import edu.snu.nemo.runtime.executor.TestUtil;
import edu.snu.nemo.runtime.executor.data.BlockManagerWorker;
import edu.snu.nemo.runtime.executor.data.SerializerManager;
import edu.snu.nemo.runtime.master.*;
@@ -70,7 +71,9 @@ import java.io.IOException;
import java.util.*;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
import java.util.stream.IntStream;
+import java.util.stream.Stream;
import static edu.snu.nemo.common.dag.DAG.EMPTY_DAG_DIRECTORY;
import static edu.snu.nemo.runtime.common.RuntimeTestUtil.getRangedNumList;
@@ -210,7 +213,7 @@ public final class DataTransferTest {
}
@Test
- public void testWriteAndRead() throws Exception {
+ public void testWriteAndRead() {
// test OneToOne same worker
writeAndRead(worker1, worker1, CommunicationPatternProperty.Value.OneToOne, MEMORY_STORE);
@@ -317,24 +320,22 @@ public final class DataTransferTest {
final IRVertex srcMockVertex = mock(IRVertex.class);
final IRVertex dstMockVertex = mock(IRVertex.class);
- final Stage srcStage = setupStages("srcStage-" + testIndex);
- final Stage dstStage = setupStages("dstStage-" + testIndex);
+ final Stage srcStage = setupStages("srcStage" + testIndex);
+ final Stage dstStage = setupStages("dstStage" + testIndex);
dummyEdge = new StageEdge(edgeId, edgeProperties, srcMockVertex, dstMockVertex,
srcStage, dstStage, false);
// Initialize states in Master
- srcStage.getTaskIds().forEach(srcTaskId -> {
- final String blockId = RuntimeIdGenerator.generateBlockId(
- edgeId, RuntimeIdGenerator.getIndexFromTaskId(srcTaskId));
- master.initializeState(blockId, srcTaskId);
- master.onProducerTaskScheduled(srcTaskId);
+ TestUtil.generateTaskIds(srcStage).forEach(srcTaskId -> {
+ final String blockId = RuntimeIdManager.generateBlockId(edgeId, srcTaskId);
+ master.onProducerTaskScheduled(srcTaskId, Collections.singleton(blockId));
});
// Write
final List<List> dataWrittenList = new ArrayList<>();
- IntStream.range(0, PARALLELISM_TEN).forEach(srcTaskIndex -> {
+ TestUtil.generateTaskIds(srcStage).forEach(srcTaskId -> {
final List dataWritten = getRangedNumList(0, PARALLELISM_TEN);
- final OutputWriter writer = transferFactory.createWriter(srcVertex, srcTaskIndex, dstVertex, dummyEdge);
+ final OutputWriter writer = transferFactory.createWriter(srcTaskId, dstVertex, dummyEdge);
dataWritten.iterator().forEachRemaining(writer::write);
writer.close();
dataWrittenList.add(dataWritten);
@@ -411,35 +412,30 @@ public final class DataTransferTest {
final IRVertex srcMockVertex = mock(IRVertex.class);
final IRVertex dstMockVertex = mock(IRVertex.class);
- final Stage srcStage = setupStages("srcStage-" + testIndex);
- final Stage dstStage = setupStages("dstStage-" + testIndex);
+ final Stage srcStage = setupStages("srcStage" + testIndex);
+ final Stage dstStage = setupStages("dstStage" + testIndex);
dummyEdge = new StageEdge(edgeId, edgeProperties, srcMockVertex, dstMockVertex,
srcStage, dstStage, false);
final IRVertex dstMockVertex2 = mock(IRVertex.class);
- final Stage dstStage2 = setupStages("dstStage-" + testIndex2);
dummyEdge2 = new StageEdge(edgeId2, edgeProperties, srcMockVertex, dstMockVertex2,
srcStage, dstStage, false);
// Initialize states in Master
- srcStage.getTaskIds().forEach(srcTaskId -> {
- final String blockId = RuntimeIdGenerator.generateBlockId(
- edgeId, RuntimeIdGenerator.getIndexFromTaskId(srcTaskId));
- master.initializeState(blockId, srcTaskId);
- final String blockId2 = RuntimeIdGenerator.generateBlockId(
- edgeId2, RuntimeIdGenerator.getIndexFromTaskId(srcTaskId));
- master.initializeState(blockId2, srcTaskId);
- master.onProducerTaskScheduled(srcTaskId);
+ TestUtil.generateTaskIds(srcStage).forEach(srcTaskId -> {
+ final String blockId = RuntimeIdManager.generateBlockId(edgeId, srcTaskId);
+ final String blockId2 = RuntimeIdManager.generateBlockId(edgeId2, srcTaskId);
+ master.onProducerTaskScheduled(srcTaskId, Stream.of(blockId, blockId2).collect(Collectors.toSet()));
});
// Write
final List<List> dataWrittenList = new ArrayList<>();
- IntStream.range(0, PARALLELISM_TEN).forEach(srcTaskIndex -> {
+ TestUtil.generateTaskIds(srcStage).forEach(srcTaskId -> {
final List dataWritten = getRangedNumList(0, PARALLELISM_TEN);
- final OutputWriter writer = transferFactory.createWriter(srcVertex, srcTaskIndex, dstVertex, dummyEdge);
+ final OutputWriter writer = transferFactory.createWriter(srcTaskId, dstVertex, dummyEdge);
dataWritten.iterator().forEachRemaining(writer::write);
writer.close();
dataWrittenList.add(dataWritten);
- final OutputWriter writer2 = transferFactory.createWriter(srcVertex, srcTaskIndex, dstVertex, dummyEdge2);
+ final OutputWriter writer2 = transferFactory.createWriter(srcTaskId, dstVertex, dummyEdge2);
dataWritten.iterator().forEachRemaining(writer2::write);
writer2.close();
});
@@ -539,4 +535,5 @@ public final class DataTransferTest {
stageExecutionProperty.put(ScheduleGroupProperty.of(0));
return new Stage(stageId, emptyDag, stageExecutionProperty, Collections.emptyList());
}
+
}
diff --git a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java
index e810989..a1251d5 100644
--- a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java
+++ b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java
@@ -30,7 +30,7 @@ import edu.snu.nemo.common.ir.vertex.OperatorVertex;
import edu.snu.nemo.common.ir.vertex.transform.Transform;
import edu.snu.nemo.common.ir.executionproperty.ExecutionPropertyMap;
import edu.snu.nemo.common.ir.vertex.IRVertex;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.message.PersistentConnectionToMasterMap;
import edu.snu.nemo.runtime.common.plan.Stage;
import edu.snu.nemo.runtime.common.plan.Task;
@@ -70,12 +70,15 @@ import static org.mockito.Mockito.*;
@PrepareForTest({InputReader.class, OutputWriter.class, DataTransferFactory.class,
TaskStateManager.class, StageEdge.class, PersistentConnectionToMasterMap.class, Stage.class, IREdge.class})
public final class TaskExecutorTest {
+ private static final AtomicInteger RUNTIME_EDGE_ID = new AtomicInteger(0);
private static final int DATA_SIZE = 100;
private static final ExecutionPropertyMap<VertexExecutionProperty> TASK_EXECUTION_PROPERTY_MAP
= new ExecutionPropertyMap<>("TASK_EXECUTION_PROPERTY_MAP");
private static final int SOURCE_PARALLELISM = 5;
+ private static final int FIRST_ATTEMPT = 0;
+
private List<Integer> elements;
- private Map<String, List> vertexIdToOutputData;
+ private Map<String, List> runtimeEdgeToOutputData;
private DataTransferFactory dataTransferFactory;
private TaskStateManager taskStateManager;
private MetricMessageSender metricMessageSender;
@@ -83,8 +86,8 @@ public final class TaskExecutorTest {
private AtomicInteger stageId;
private String generateTaskId() {
- return RuntimeIdGenerator.generateTaskId(0,
- RuntimeIdGenerator.generateStageId(stageId.getAndIncrement()));
+ return RuntimeIdManager.generateTaskId(
+ RuntimeIdManager.generateStageId(stageId.getAndIncrement()), 0, FIRST_ATTEMPT);
}
@Before
@@ -96,10 +99,10 @@ public final class TaskExecutorTest {
taskStateManager = mock(TaskStateManager.class);
// Mock a DataTransferFactory.
- vertexIdToOutputData = new HashMap<>();
+ runtimeEdgeToOutputData = new HashMap<>();
dataTransferFactory = mock(DataTransferFactory.class);
when(dataTransferFactory.createReader(anyInt(), any(), any())).then(new ParentTaskReaderAnswer());
- when(dataTransferFactory.createWriter(any(), anyInt(), any(), any())).then(new ChildTaskWriterAnswer());
+ when(dataTransferFactory.createWriter(any(), any(), any())).then(new ChildTaskWriterAnswer());
// Mock a MetricMessageSender.
metricMessageSender = mock(MetricMessageSender.class);
@@ -140,15 +143,15 @@ public final class TaskExecutorTest {
.addVertex(sourceIRVertex)
.buildWithoutSourceSinkCheck();
+ final StageEdge taskOutEdge = mockStageEdgeFrom(sourceIRVertex);
final Task task =
new Task(
"testSourceVertexDataFetching",
generateTaskId(),
- 0,
TASK_EXECUTION_PROPERTY_MAP,
new byte[0],
Collections.emptyList(),
- Collections.singletonList(mockStageEdgeFrom(sourceIRVertex)),
+ Collections.singletonList(taskOutEdge),
vertexIdToReadable);
// Execute the task.
@@ -157,7 +160,7 @@ public final class TaskExecutorTest {
taskExecutor.execute();
// Check the output.
- assertTrue(checkEqualElements(elements, vertexIdToOutputData.get(sourceIRVertex.getId())));
+ assertTrue(checkEqualElements(elements, runtimeEdgeToOutputData.get(taskOutEdge.getId())));
}
/**
@@ -171,14 +174,14 @@ public final class TaskExecutorTest {
.addVertex(vertex)
.buildWithoutSourceSinkCheck();
+ final StageEdge taskOutEdge = mockStageEdgeFrom(vertex);
final Task task = new Task(
"testSourceVertexDataFetching",
generateTaskId(),
- 0,
TASK_EXECUTION_PROPERTY_MAP,
new byte[0],
Collections.singletonList(mockStageEdgeTo(vertex)),
- Collections.singletonList(mockStageEdgeFrom(vertex)),
+ Collections.singletonList(taskOutEdge),
Collections.emptyMap());
// Execute the task.
@@ -187,7 +190,7 @@ public final class TaskExecutorTest {
taskExecutor.execute();
// Check the output.
- assertTrue(checkEqualElements(elements, vertexIdToOutputData.get(vertex.getId())));
+ assertTrue(checkEqualElements(elements, runtimeEdgeToOutputData.get(taskOutEdge.getId())));
}
/**
@@ -209,14 +212,14 @@ public final class TaskExecutorTest {
.connectVertices(createEdge(operatorIRVertex1, operatorIRVertex2, false))
.buildWithoutSourceSinkCheck();
+ final StageEdge taskOutEdge = mockStageEdgeFrom(operatorIRVertex2);
final Task task = new Task(
"testSourceVertexDataFetching",
generateTaskId(),
- 0,
TASK_EXECUTION_PROPERTY_MAP,
new byte[0],
Collections.singletonList(mockStageEdgeTo(operatorIRVertex1)),
- Collections.singletonList(mockStageEdgeFrom(operatorIRVertex2)),
+ Collections.singletonList(taskOutEdge),
Collections.emptyMap());
// Execute the task.
@@ -225,7 +228,7 @@ public final class TaskExecutorTest {
taskExecutor.execute();
// Check the output.
- assertTrue(checkEqualElements(elements, vertexIdToOutputData.get(operatorIRVertex2.getId())));
+ assertTrue(checkEqualElements(elements, runtimeEdgeToOutputData.get(taskOutEdge.getId())));
}
@Test(timeout=5000)
@@ -241,14 +244,14 @@ public final class TaskExecutorTest {
.connectVertices(createEdge(operatorIRVertex1, operatorIRVertex2, true))
.buildWithoutSourceSinkCheck();
+ final StageEdge taskOutEdge = mockStageEdgeFrom(operatorIRVertex2);
final Task task = new Task(
"testSourceVertexDataFetching",
generateTaskId(),
- 0,
TASK_EXECUTION_PROPERTY_MAP,
new byte[0],
Arrays.asList(mockStageEdgeTo(operatorIRVertex1), mockStageEdgeTo(operatorIRVertex2)),
- Collections.singletonList(mockStageEdgeFrom(operatorIRVertex2)),
+ Collections.singletonList(taskOutEdge),
Collections.emptyMap());
// Execute the task.
@@ -257,7 +260,7 @@ public final class TaskExecutorTest {
taskExecutor.execute();
// Check the output.
- final List<Pair<List<Integer>, Integer>> pairs = vertexIdToOutputData.get(operatorIRVertex2.getId());
+ final List<Pair<List<Integer>, Integer>> pairs = runtimeEdgeToOutputData.get(taskOutEdge.getId());
final List<Integer> values = pairs.stream().map(Pair::right).collect(Collectors.toList());
assertTrue(checkEqualElements(elements, values));
assertTrue(pairs.stream().map(Pair::left).allMatch(sideInput -> checkEqualElements(sideInput, values)));
@@ -296,16 +299,17 @@ public final class TaskExecutorTest {
.connectVertices(edge3)
.buildWithoutSourceSinkCheck();
+ final StageEdge outEdge1 = mockStageEdgeFrom(mainVertex);
+ final StageEdge outEdge2 = mockStageEdgeFrom(bonusVertex1);
+ final StageEdge outEdge3 = mockStageEdgeFrom(bonusVertex2);
+
final Task task = new Task(
"testAdditionalOutputs",
generateTaskId(),
- 0,
TASK_EXECUTION_PROPERTY_MAP,
new byte[0],
Collections.singletonList(mockStageEdgeTo(routerVertex)),
- Arrays.asList(mockStageEdgeFrom(mainVertex),
- mockStageEdgeFrom(bonusVertex1),
- mockStageEdgeFrom(bonusVertex2)),
+ Arrays.asList(outEdge1, outEdge2, outEdge3),
Collections.emptyMap());
// Execute the task.
@@ -314,9 +318,9 @@ public final class TaskExecutorTest {
taskExecutor.execute();
// Check the output.
- final List<Integer> mainOutputs = vertexIdToOutputData.get(mainVertex.getId());
- final List<Integer> bonusOutputs1 = vertexIdToOutputData.get(bonusVertex1.getId());
- final List<Integer> bonusOutputs2 = vertexIdToOutputData.get(bonusVertex1.getId());
+ final List<Integer> mainOutputs = runtimeEdgeToOutputData.get(outEdge1.getId());
+ final List<Integer> bonusOutputs1 = runtimeEdgeToOutputData.get(outEdge2.getId());
+ final List<Integer> bonusOutputs2 = runtimeEdgeToOutputData.get(outEdge3.getId());
List<Integer> even = elements.stream().filter(i -> i % 2 == 0).collect(Collectors.toList());
List<Integer> odd = elements.stream().filter(i -> i % 2 != 0).collect(Collectors.toList());
assertTrue(checkEqualElements(even, mainOutputs));
@@ -345,7 +349,7 @@ public final class TaskExecutorTest {
}
private StageEdge mockStageEdgeFrom(final IRVertex irVertex) {
- return new StageEdge("runtime incoming edge id",
+ return new StageEdge("SEdge" + RUNTIME_EDGE_ID.getAndIncrement(),
ExecutionPropertyMap.of(mock(IREdge.class), CommunicationPatternProperty.Value.OneToOne),
irVertex,
new OperatorVertex(new RelayTransform()),
@@ -394,15 +398,15 @@ public final class TaskExecutorTest {
@Override
public OutputWriter answer(final InvocationOnMock invocationOnMock) throws Throwable {
final Object[] args = invocationOnMock.getArguments();
- final IRVertex vertex = (IRVertex) args[0];
+ final RuntimeEdge runtimeEdge = (RuntimeEdge) args[2];
final OutputWriter outputWriter = mock(OutputWriter.class);
doAnswer(new Answer() {
@Override
public Object answer(final InvocationOnMock invocationOnMock) throws Throwable {
final Object[] args = invocationOnMock.getArguments();
final Object dataToWrite = args[0];
- vertexIdToOutputData.computeIfAbsent(vertex.getId(), emptyTaskId -> new ArrayList<>());
- vertexIdToOutputData.get(vertex.getId()).add(dataToWrite);
+ runtimeEdgeToOutputData.computeIfAbsent(runtimeEdge.getId(), emptyTaskId -> new ArrayList<>());
+ runtimeEdgeToOutputData.get(runtimeEdge.getId()).add(dataToWrite);
return null;
}
}).when(outputWriter).write(any());
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/BlockManagerMaster.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/BlockManagerMaster.java
index 5461039..3f2ffda 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/BlockManagerMaster.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/BlockManagerMaster.java
@@ -15,20 +15,14 @@
*/
package edu.snu.nemo.runtime.master;
-import edu.snu.nemo.common.dag.DAG;
import edu.snu.nemo.common.exception.IllegalMessageException;
import edu.snu.nemo.common.exception.UnknownExecutionStateException;
-import edu.snu.nemo.common.ir.vertex.IRVertex;
import edu.snu.nemo.runtime.common.comm.ControlMessage;
import edu.snu.nemo.runtime.common.exception.AbsentBlockException;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.message.MessageContext;
import edu.snu.nemo.runtime.common.message.MessageEnvironment;
import edu.snu.nemo.runtime.common.message.MessageListener;
-import edu.snu.nemo.runtime.common.plan.PhysicalPlan;
-import edu.snu.nemo.runtime.common.plan.RuntimeEdge;
-import edu.snu.nemo.runtime.common.plan.Stage;
-import edu.snu.nemo.runtime.common.plan.StageEdge;
import edu.snu.nemo.runtime.common.state.BlockState;
import com.google.common.annotations.VisibleForTesting;
@@ -43,15 +37,12 @@ import java.util.concurrent.Future;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
-import java.util.stream.IntStream;
+import java.util.stream.Collectors;
import org.apache.reef.annotations.audience.DriverSide;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import static edu.snu.nemo.runtime.common.state.BlockState.State.IN_PROGRESS;
-import static edu.snu.nemo.runtime.common.state.BlockState.State.NOT_AVAILABLE;
-
/**
* Master-side block manager.
*/
@@ -59,8 +50,14 @@ import static edu.snu.nemo.runtime.common.state.BlockState.State.NOT_AVAILABLE;
@DriverSide
public final class BlockManagerMaster {
private static final Logger LOG = LoggerFactory.getLogger(BlockManagerMaster.class.getName());
- private final Map<String, BlockMetadata> blockIdToMetadata;
- private final Map<String, Set<String>> producerTaskIdToBlockIds;
+
+ private final Map<String, Set<String>> producerTaskIdToBlockIds; // a task can have multiple out-edges
+
+ /**
+ * See {@link RuntimeIdManager#generateBlockIdWildcard(String, int)} for information on block wildcards.
+ */
+ private final Map<String, Set<BlockMetadata>> blockIdWildcardToMetadataSet; // a metadata = a task attempt output
+
// A lock that can be acquired exclusively or not.
// Because the BlockMetadata itself is sufficiently synchronized,
// operation that runs in a single block can just acquire a (sharable) read lock.
@@ -68,6 +65,8 @@ public final class BlockManagerMaster {
// modifies global variables in this class have to acquire an (exclusive) write lock.
private final ReadWriteLock lock;
+ private final Random random = new Random();
+
/**
* Constructor.
*
@@ -77,56 +76,30 @@ public final class BlockManagerMaster {
private BlockManagerMaster(final MessageEnvironment masterMessageEnvironment) {
masterMessageEnvironment.setupListener(MessageEnvironment.BLOCK_MANAGER_MASTER_MESSAGE_LISTENER_ID,
new PartitionManagerMasterControlMessageReceiver());
- this.blockIdToMetadata = new HashMap<>();
+ this.blockIdWildcardToMetadataSet = new HashMap<>();
this.producerTaskIdToBlockIds = new HashMap<>();
this.lock = new ReentrantReadWriteLock();
}
- public void initialize(final PhysicalPlan physicalPlan) {
- final DAG<Stage, StageEdge> stageDAG = physicalPlan.getStageDAG();
- stageDAG.topologicalDo(stage -> {
- final List<String> taskIdsForStage = stage.getTaskIds();
- final List<StageEdge> stageOutgoingEdges = stageDAG.getOutgoingEdgesOf(stage);
-
- // Initialize states for blocks of inter-stage edges
- stageOutgoingEdges.forEach(stageEdge -> {
- final int srcParallelism = taskIdsForStage.size();
- IntStream.range(0, srcParallelism).forEach(srcTaskIdx -> {
- final String blockId = RuntimeIdGenerator.generateBlockId(stageEdge.getId(), srcTaskIdx);
- initializeState(blockId, taskIdsForStage.get(srcTaskIdx));
- });
- });
-
- // Initialize states for blocks of stage internal edges
- taskIdsForStage.forEach(taskId -> {
- final DAG<IRVertex, RuntimeEdge<IRVertex>> taskInternalDag = stage.getIRDAG();
- taskInternalDag.getVertices().forEach(task -> {
- final List<RuntimeEdge<IRVertex>> internalOutgoingEdges = taskInternalDag.getOutgoingEdgesOf(task);
- internalOutgoingEdges.forEach(taskRuntimeEdge -> {
- final int srcTaskIdx = RuntimeIdGenerator.getIndexFromTaskId(taskId);
- final String blockId = RuntimeIdGenerator.generateBlockId(taskRuntimeEdge.getId(), srcTaskIdx);
- initializeState(blockId, taskId);
- });
- });
- });
- });
- }
-
/**
- * Initializes the states of a block which will be produced by producer task(s).
+ * Initializes the states of a block which will be produced by a producer task.
*
* @param blockId the id of the block to initialize.
- * @param producerTaskId the id of the producer task.
+ * @param producerTaskId the id of the producer task.
*/
@VisibleForTesting
- public void initializeState(final String blockId,
- final String producerTaskId) {
+ private void initializeState(final String blockId, final String producerTaskId) {
final Lock writeLock = lock.writeLock();
writeLock.lock();
try {
- blockIdToMetadata.put(blockId, new BlockMetadata(blockId));
+ // task - to - blockIds
producerTaskIdToBlockIds.putIfAbsent(producerTaskId, new HashSet<>());
producerTaskIdToBlockIds.get(producerTaskId).add(blockId);
+
+ // wildcard - to - metadata
+ final String wildCard = RuntimeIdManager.getWildCardFromBlockId(blockId);
+ blockIdWildcardToMetadataSet.putIfAbsent(wildCard, new HashSet<>());
+ blockIdWildcardToMetadataSet.get(wildCard).add(new BlockMetadata(blockId));
} finally {
writeLock.unlock();
}
@@ -161,25 +134,28 @@ public final class BlockManagerMaster {
/**
* Returns a handler of block location requests.
*
- * @param blockId id of the specified block.
+ * @param blockIdOrWildcard id of the specified block.
* @return the handler of block location requests, which completes exceptionally when the block
* is not {@code IN_PROGRESS} or {@code AVAILABLE}.
*/
- public BlockLocationRequestHandler getBlockLocationHandler(final String blockId) {
+ public BlockRequestHandler getBlockLocationHandler(final String blockIdOrWildcard) {
final Lock readLock = lock.readLock();
readLock.lock();
try {
- final BlockState.State state = getBlockState(blockId);
- switch (state) {
- case IN_PROGRESS:
- case AVAILABLE:
- return blockIdToMetadata.get(blockId).getLocationHandler();
- case NOT_AVAILABLE:
- final BlockLocationRequestHandler handler = new BlockLocationRequestHandler(blockId);
- handler.completeExceptionally(new AbsentBlockException(blockId, state));
- return handler;
- default:
- throw new UnsupportedOperationException(state.toString());
+ final Set<BlockMetadata> metadataSet =
+ getBlockWildcardStateSet(RuntimeIdManager.getWildCardFromBlockId(blockIdOrWildcard));
+ final List<BlockMetadata> candidates = metadataSet.stream()
+ .filter(metadata -> metadata.getBlockState().equals(BlockState.State.IN_PROGRESS)
+ || metadata.getBlockState().equals(BlockState.State.AVAILABLE))
+ .collect(Collectors.toList());
+ if (!candidates.isEmpty()) {
+ // Randomly pick one of the candidate handlers.
+ return candidates.get(random.nextInt(candidates.size())).getLocationHandler();
+ } else {
+ // No candidate exists
+ final BlockRequestHandler handler = new BlockRequestHandler(blockIdOrWildcard);
+ handler.completeExceptionally(new AbsentBlockException(blockIdOrWildcard, BlockState.State.NOT_AVAILABLE));
+ return handler;
}
} finally {
readLock.unlock();
@@ -192,8 +168,7 @@ public final class BlockManagerMaster {
* @param blockId the id of the block.
* @return the ids of the producer tasks.
*/
- @VisibleForTesting
- public Set<String> getProducerTaskIds(final String blockId) {
+ private Set<String> getProducerTaskIds(final String blockId) {
final Lock readLock = lock.readLock();
readLock.lock();
try {
@@ -210,32 +185,16 @@ public final class BlockManagerMaster {
}
}
- public Set<String> getIdsOfBlocksProducedBy(final String taskId) {
- final Lock readLock = lock.readLock();
- readLock.lock();
- try {
- return producerTaskIdToBlockIds.get(taskId);
- } finally {
- readLock.unlock();
- }
- }
-
/**
* To be called when a potential producer task is scheduled.
- * @param scheduledTaskId the ID of the scheduled task.
+ * @param taskId the ID of the scheduled task.
+ * @param blockIds this task will produce
*/
- public void onProducerTaskScheduled(final String scheduledTaskId) {
+ public void onProducerTaskScheduled(final String taskId, final Set<String> blockIds) {
final Lock writeLock = lock.writeLock();
writeLock.lock();
try {
- if (producerTaskIdToBlockIds.containsKey(scheduledTaskId)) {
- producerTaskIdToBlockIds.get(scheduledTaskId).forEach(blockId -> {
- if (blockIdToMetadata.get(blockId).getBlockState()
- .getStateMachine().getCurrentState().equals(NOT_AVAILABLE)) {
- onBlockStateChanged(blockId, IN_PROGRESS, null);
- }
- });
- } // else this task does not produce any block
+ blockIds.forEach(blockId -> initializeState(blockId, taskId));
} finally {
writeLock.unlock();
}
@@ -252,13 +211,10 @@ public final class BlockManagerMaster {
writeLock.lock();
try {
if (producerTaskIdToBlockIds.containsKey(failedTaskId)) {
- producerTaskIdToBlockIds.get(failedTaskId).forEach(blockId -> {
- final BlockState.State state = (BlockState.State)
- blockIdToMetadata.get(blockId).getBlockState().getStateMachine().getCurrentState();
- LOG.info("Block lost: {}", blockId);
- onBlockStateChanged(blockId, BlockState.State.NOT_AVAILABLE, null);
- });
- } // else this task does not produce any block
+ producerTaskIdToBlockIds.get(failedTaskId).forEach(blockId ->
+ onBlockStateChanged(blockId, BlockState.State.NOT_AVAILABLE, null)
+ );
+ } // else this task has not produced any block
} finally {
writeLock.unlock();
}
@@ -270,13 +226,12 @@ public final class BlockManagerMaster {
* @param executorId the id of the executor.
* @return the committed blocks by the executor.
*/
- @VisibleForTesting
- Set<String> getCommittedBlocksByWorker(final String executorId) {
+ private Set<String> getCommittedBlocksByWorker(final String executorId) {
final Lock readLock = lock.readLock();
readLock.lock();
try {
final Set<String> blockIds = new HashSet<>();
- blockIdToMetadata.values().forEach(blockMetadata -> {
+ blockIdWildcardToMetadataSet.values().stream().flatMap(Set::stream).forEach(blockMetadata -> {
final Future<String> location = blockMetadata.getLocationHandler().getLocationFuture();
if (location.isDone()) {
try {
@@ -297,15 +252,14 @@ public final class BlockManagerMaster {
}
/**
- * @param blockId the id of the block.
- * @return the {@link BlockState} of a block.
+ * @param blockIdWildcard to query.
+ * @return set of block metadata for the wildcard, empty if none exists.
*/
- @VisibleForTesting
- public BlockState.State getBlockState(final String blockId) {
+ private Set<BlockMetadata> getBlockWildcardStateSet(final String blockIdWildcard) {
final Lock readLock = lock.readLock();
readLock.lock();
try {
- return (BlockState.State) blockIdToMetadata.get(blockId).getBlockState().getStateMachine().getCurrentState();
+ return blockIdWildcardToMetadataSet.getOrDefault(blockIdWildcard, new HashSet<>(0));
} finally {
readLock.unlock();
}
@@ -326,12 +280,24 @@ public final class BlockManagerMaster {
final Lock readLock = lock.readLock();
readLock.lock();
try {
- blockIdToMetadata.get(blockId).onStateChanged(newState, location);
+ getBlockMetaData(blockId).onStateChanged(newState, location);
} finally {
readLock.unlock();
}
}
+ private BlockMetadata getBlockMetaData(final String blockId) {
+ final List<BlockMetadata> candidates =
+ blockIdWildcardToMetadataSet.get(RuntimeIdManager.getWildCardFromBlockId(blockId))
+ .stream()
+ .filter(meta -> meta.getBlockId().equals(blockId))
+ .collect(Collectors.toList());
+ if (candidates.size() != 1) {
+ throw new RuntimeException("BlockId " + blockId + ": " + candidates.toString()); // should match only 1
+ }
+ return candidates.get(0);
+ }
+
/**
* Deals with a request for the location of a block.
*
@@ -341,12 +307,12 @@ public final class BlockManagerMaster {
void onRequestBlockLocation(final ControlMessage.Message message,
final MessageContext messageContext) {
assert (message.getType() == ControlMessage.MessageType.RequestBlockLocation);
- final String blockId = message.getRequestBlockLocationMsg().getBlockId();
+ final String blockIdWildcard = message.getRequestBlockLocationMsg().getBlockIdWildcard();
final long requestId = message.getId();
final Lock readLock = lock.readLock();
readLock.lock();
try {
- final BlockLocationRequestHandler locationFuture = getBlockLocationHandler(blockId);
+ final BlockRequestHandler locationFuture = getBlockLocationHandler(blockIdWildcard);
locationFuture.registerRequest(requestId, messageContext);
} finally {
readLock.unlock();
@@ -397,7 +363,7 @@ public final class BlockManagerMaster {
* The handler of block location requests.
*/
@VisibleForTesting
- public static final class BlockLocationRequestHandler {
+ public static final class BlockRequestHandler {
private final String blockId;
private final CompletableFuture<String> locationFuture;
@@ -406,7 +372,7 @@ public final class BlockManagerMaster {
*
* @param blockId the ID of the block.
*/
- BlockLocationRequestHandler(final String blockId) {
+ BlockRequestHandler(final String blockId) {
this.blockId = blockId;
this.locationFuture = new CompletableFuture<>();
}
@@ -454,7 +420,7 @@ public final class BlockManagerMaster {
}
messageContext.reply(
ControlMessage.Message.newBuilder()
- .setId(RuntimeIdGenerator.generateMessageId())
+ .setId(RuntimeIdManager.generateMessageId())
.setListenerId(MessageEnvironment.EXECUTOR_MESSAGE_LISTENER_ID)
.setType(ControlMessage.MessageType.BlockLocationInfo)
.setBlockLocationInfoMsg(infoMsgBuilder.build())
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/BlockMetadata.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/BlockMetadata.java
index 9e6d471..e1f2af7 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/BlockMetadata.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/BlockMetadata.java
@@ -16,6 +16,7 @@
package edu.snu.nemo.runtime.master;
import edu.snu.nemo.common.StateMachine;
+import edu.snu.nemo.common.exception.IllegalStateTransitionException;
import edu.snu.nemo.runtime.common.state.BlockState;
import edu.snu.nemo.runtime.common.exception.AbsentBlockException;
import org.slf4j.Logger;
@@ -32,7 +33,7 @@ final class BlockMetadata {
private static final Logger LOG = LoggerFactory.getLogger(BlockMetadata.class.getName());
private final String blockId;
private final BlockState blockState;
- private volatile BlockManagerMaster.BlockLocationRequestHandler locationHandler;
+ private volatile BlockManagerMaster.BlockRequestHandler locationHandler;
/**
* Constructs the metadata for a block.
@@ -43,7 +44,7 @@ final class BlockMetadata {
// Initialize block level metadata.
this.blockId = blockId;
this.blockState = new BlockState();
- this.locationHandler = new BlockManagerMaster.BlockLocationRequestHandler(blockId);
+ this.locationHandler = new BlockManagerMaster.BlockRequestHandler(blockId);
}
/**
@@ -65,7 +66,7 @@ final class BlockMetadata {
case NOT_AVAILABLE:
// Reset the block location and committer information.
locationHandler.completeExceptionally(new AbsentBlockException(blockId, newState));
- locationHandler = new BlockManagerMaster.BlockLocationRequestHandler(blockId);
+ locationHandler = new BlockManagerMaster.BlockRequestHandler(blockId);
break;
case AVAILABLE:
if (location == null) {
@@ -77,7 +78,11 @@ final class BlockMetadata {
throw new UnsupportedOperationException(newState.toString());
}
- stateMachine.setState(newState);
+ try {
+ stateMachine.setState(newState);
+ } catch (IllegalStateTransitionException e) {
+ throw new RuntimeException(blockId + " - Illegal block state transition ", e);
+ }
}
/**
@@ -90,14 +95,24 @@ final class BlockMetadata {
/**
* @return the state of this block.
*/
- BlockState getBlockState() {
- return blockState;
+ BlockState.State getBlockState() {
+ return (BlockState.State) blockState.getStateMachine().getCurrentState();
}
/**
* @return the handler of block location requests.
*/
- synchronized BlockManagerMaster.BlockLocationRequestHandler getLocationHandler() {
+ synchronized BlockManagerMaster.BlockRequestHandler getLocationHandler() {
return locationHandler;
}
+
+ @Override
+ public String toString() {
+ final StringBuilder sb = new StringBuilder();
+ sb.append(blockId);
+ sb.append("(");
+ sb.append(blockState);
+ sb.append(")");
+ return sb.toString();
+ }
}
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/MetricManagerMaster.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/MetricManagerMaster.java
index 4a266d5..00a32c4 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/MetricManagerMaster.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/MetricManagerMaster.java
@@ -17,7 +17,7 @@ package edu.snu.nemo.runtime.master;
import javax.inject.Inject;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.comm.ControlMessage;
import edu.snu.nemo.runtime.common.message.MessageEnvironment;
import edu.snu.nemo.runtime.master.scheduler.ExecutorRegistry;
@@ -46,7 +46,7 @@ public final class MetricManagerMaster implements MetricMessageHandler {
public synchronized void sendMetricFlushRequest() {
executorRegistry.viewExecutors(executors -> executors.forEach(executor -> {
final ControlMessage.Message message = ControlMessage.Message.newBuilder()
- .setId(RuntimeIdGenerator.generateMessageId())
+ .setId(RuntimeIdManager.generateMessageId())
.setListenerId(MessageEnvironment.EXECUTOR_MESSAGE_LISTENER_ID)
.setType(ControlMessage.MessageType.RequestMetricFlush)
.build();
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/PlanStateManager.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/PlanStateManager.java
index dd9f302..54264dc 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/PlanStateManager.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/PlanStateManager.java
@@ -17,10 +17,10 @@ package edu.snu.nemo.runtime.master;
import com.google.common.annotations.VisibleForTesting;
import edu.snu.nemo.common.exception.IllegalStateTransitionException;
-import edu.snu.nemo.common.exception.SchedulingException;
import edu.snu.nemo.common.exception.UnknownExecutionStateException;
import edu.snu.nemo.common.StateMachine;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.common.ir.vertex.executionproperty.ClonedSchedulingProperty;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.plan.PhysicalPlan;
import edu.snu.nemo.runtime.common.plan.Stage;
import edu.snu.nemo.runtime.common.state.PlanState;
@@ -34,6 +34,7 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
+import java.util.stream.Collectors;
import edu.snu.nemo.runtime.common.state.TaskState;
import edu.snu.nemo.runtime.common.metric.JobMetric;
@@ -65,14 +66,8 @@ public final class PlanStateManager {
* The data structures below track the execution states of this plan.
*/
private final PlanState planState;
- private final Map<String, StageState> idToStageStates;
- private final Map<String, TaskState> idToTaskStates;
-
- /**
- * Maintain the number of schedule attempts for each task.
- * The attempt numbers are updated only here, and are read-only in other places.
- */
- private final Map<String, Integer> taskIdToCurrentAttempt;
+ private final Map<String, StageState> stageIdToState;
+ private final Map<String, List<List<TaskState>>> stageIdToTaskAttemptStates; // sorted by task idx, and then attempt
/**
* Represents the plan to manage.
@@ -85,24 +80,16 @@ public final class PlanStateManager {
private final Lock finishLock;
private final Condition planFinishedCondition;
- /**
- * For metrics.
- */
- private final MetricMessageHandler metricMessageHandler;
-
private MetricStore metricStore;
public PlanStateManager(final PhysicalPlan physicalPlan,
- final MetricMessageHandler metricMessageHandler,
final int maxScheduleAttempt) {
this.planId = physicalPlan.getId();
this.physicalPlan = physicalPlan;
- this.metricMessageHandler = metricMessageHandler;
this.maxScheduleAttempt = maxScheduleAttempt;
this.planState = new PlanState();
- this.idToStageStates = new HashMap<>();
- this.idToTaskStates = new HashMap<>();
- this.taskIdToCurrentAttempt = new HashMap<>();
+ this.stageIdToState = new HashMap<>();
+ this.stageIdToTaskAttemptStates = new HashMap<>();
this.finishLock = new ReentrantLock();
this.planFinishedCondition = finishLock.newCondition();
this.metricStore = MetricStore.getStore();
@@ -117,18 +104,81 @@ public final class PlanStateManager {
*/
private void initializeComputationStates() {
onPlanStateChanged(PlanState.State.EXECUTING);
-
- // Initialize the states for the plan down to task-level.
physicalPlan.getStageDAG().topologicalDo(stage -> {
- idToStageStates.put(stage.getId(), new StageState());
- stage.getTaskIds().forEach(taskId -> {
- idToTaskStates.put(taskId, new TaskState());
- taskIdToCurrentAttempt.put(taskId, 1);
- });
+ stageIdToState.put(stage.getId(), new StageState());
+ stageIdToTaskAttemptStates.put(stage.getId(), new ArrayList<>(stage.getParallelism()));
+ for (int taskIndex = 0; taskIndex < stage.getParallelism(); taskIndex++) {
+ // for each task idx of this stage
+ stageIdToTaskAttemptStates.get(stage.getId()).add(new ArrayList<>());
+ // task states will be initialized lazily in getTaskAttemptsToSchedule()
+ }
});
}
/**
+ * Get task attempts that are "READY".
+ * @param stageId to run
+ * @return executable task attempts
+ */
+ public synchronized List<String> getTaskAttemptsToSchedule(final String stageId) {
+ if (getStageState(stageId).equals(StageState.State.COMPLETE)) {
+ // This stage is done
+ return new ArrayList<>(0);
+ }
+
+ // For each task index....
+ final List<String> taskAttemptsToSchedule = new ArrayList<>();
+ final Stage stage = physicalPlan.getStageDAG().getVertexById(stageId);
+ for (int taskIndex = 0; taskIndex < stage.getParallelism(); taskIndex++) {
+ final List<TaskState> attemptStatesForThisTaskIndex =
+ stageIdToTaskAttemptStates.get(stage.getId()).get(taskIndex);
+
+ // If one of the attempts is COMPLETE, do not schedule
+ if (attemptStatesForThisTaskIndex
+ .stream()
+ .noneMatch(state -> state.getStateMachine().getCurrentState().equals(TaskState.State.COMPLETE))) {
+
+ // (Step 1) Create new READY attempts, as many as
+ // # of clones - # of 'not-done' attempts)
+ final int numOfClones = stage.getPropertyValue(ClonedSchedulingProperty.class).orElse(1);
+ final long numOfNotDoneAttempts = attemptStatesForThisTaskIndex.stream().filter(this::isTaskNotDone).count();
+ for (int i = 0; i < numOfClones - numOfNotDoneAttempts; i++) {
+ attemptStatesForThisTaskIndex.add(new TaskState());
+ }
+
+ // (Step 2) Check max attempt
+ if (attemptStatesForThisTaskIndex.size() > maxScheduleAttempt) {
+ throw new RuntimeException(
+ attemptStatesForThisTaskIndex.size() + " exceeds max attempt " + maxScheduleAttempt);
+ }
+
+ // (Step 3) Return all READY attempts
+ for (int attempt = 0; attempt < attemptStatesForThisTaskIndex.size(); attempt++) {
+ if (attemptStatesForThisTaskIndex.get(attempt).getStateMachine().getCurrentState()
+ .equals(TaskState.State.READY)) {
+ taskAttemptsToSchedule.add(RuntimeIdManager.generateTaskId(stageId, taskIndex, attempt));
+ }
+ }
+
+ }
+ }
+
+ return taskAttemptsToSchedule;
+ }
+
+ private boolean isTaskNotDone(final TaskState taskState) {
+ final TaskState.State state = (TaskState.State) taskState.getStateMachine().getCurrentState();
+ return state.equals(TaskState.State.READY)
+ || state.equals(TaskState.State.EXECUTING)
+ || state.equals(TaskState.State.ON_HOLD);
+ }
+
+
+ public synchronized Set<String> getAllTaskAttemptsOfStage(final String stageId) {
+ return getTaskAttemptIdsToItsState(stageId).keySet();
+ }
+
+ /**
* Updates the state of a task.
* Task state changes can occur both in master and executor.
* State changes that occur in master are
@@ -142,58 +192,52 @@ public final class PlanStateManager {
*/
public synchronized void onTaskStateChanged(final String taskId, final TaskState.State newTaskState) {
// Change task state
- final StateMachine taskState = idToTaskStates.get(taskId).getStateMachine();
+ final StateMachine taskState = getTaskStateHelper(taskId).getStateMachine();
LOG.debug("Task State Transition: id {}, from {} to {}",
new Object[]{taskId, taskState.getCurrentState(), newTaskState});
-
metricStore.getOrCreateMetric(TaskMetric.class, taskId)
.addEvent((TaskState.State) taskState.getCurrentState(), newTaskState);
metricStore.triggerBroadcast(TaskMetric.class, taskId);
- taskState.setState(newTaskState);
-
- switch (newTaskState) {
- case ON_HOLD:
- case COMPLETE:
- case FAILED:
- case SHOULD_RETRY:
- case EXECUTING:
- break;
- case READY:
- final int currentAttempt = taskIdToCurrentAttempt.get(taskId) + 1;
- if (currentAttempt <= maxScheduleAttempt) {
- taskIdToCurrentAttempt.put(taskId, currentAttempt);
- } else {
- throw new SchedulingException(new Throwable("Exceeded max number of scheduling attempts for " + taskId));
- }
- break;
- default:
- throw new UnknownExecutionStateException(new Throwable("This task state is unknown"));
+ try {
+ taskState.setState(newTaskState);
+ } catch (IllegalStateTransitionException e) {
+ throw new RuntimeException(taskId + " - Illegal task state transition ", e);
}
- // Change stage state, if needed
- final String stageId = RuntimeIdGenerator.getStageIdFromTaskId(taskId);
- final List<String> tasksOfThisStage = physicalPlan.getStageDAG().getVertexById(stageId).getTaskIds();
- final long numOfCompletedOrOnHoldTasksInThisStage = tasksOfThisStage
- .stream()
- .map(this::getTaskState)
- .filter(state -> state.equals(TaskState.State.COMPLETE) || state.equals(TaskState.State.ON_HOLD))
+ // Log not-yet-completed tasks for us humans to track progress
+ final String stageId = RuntimeIdManager.getStageIdFromTaskId(taskId);
+ final List<List<TaskState>> taskStatesOfThisStage = stageIdToTaskAttemptStates.get(stageId);
+ final long numOfCompletedTaskIndicesInThisStage = taskStatesOfThisStage.stream()
+ .map(attempts -> attempts.stream()
+ .map(state -> state.getStateMachine().getCurrentState())
+ .allMatch(curState -> curState.equals(TaskState.State.COMPLETE)
+ || curState.equals(TaskState.State.SHOULD_RETRY)
+ || curState.equals(TaskState.State.ON_HOLD)))
+ .filter(bool -> bool.equals(true))
.count();
if (newTaskState.equals(TaskState.State.COMPLETE)) {
- // Log not-yet-completed tasks for us to track progress
- LOG.info("{} completed: {} Task(s) remaining in this stage",
- taskId, tasksOfThisStage.size() - numOfCompletedOrOnHoldTasksInThisStage);
+ LOG.info("{} completed: {} Task(s) out of {} are remaining in this stage",
+ taskId, taskStatesOfThisStage.size() - numOfCompletedTaskIndicesInThisStage, taskStatesOfThisStage.size());
}
+
+ // Change stage state, if needed
switch (newTaskState) {
// INCOMPLETE stage
case SHOULD_RETRY:
- onStageStateChanged(stageId, StageState.State.INCOMPLETE);
+ final boolean isAPeerAttemptCompleted = getPeerAttemptsforTheSameTaskIndex(taskId).stream()
+ .anyMatch(state -> state.equals(TaskState.State.COMPLETE));
+ if (!isAPeerAttemptCompleted) {
+ // None of the peers has completed, hence this stage is incomplete
+ onStageStateChanged(stageId, StageState.State.INCOMPLETE);
+ }
break;
// COMPLETE stage
case COMPLETE:
case ON_HOLD:
- if (numOfCompletedOrOnHoldTasksInThisStage == tasksOfThisStage.size()) {
+ if (numOfCompletedTaskIndicesInThisStage
+ == physicalPlan.getStageDAG().getVertexById(stageId).getParallelism()) {
onStageStateChanged(stageId, StageState.State.COMPLETE);
}
break;
@@ -208,6 +252,20 @@ public final class PlanStateManager {
}
}
+ private List<TaskState.State> getPeerAttemptsforTheSameTaskIndex(final String taskId) {
+ final String stageId = RuntimeIdManager.getStageIdFromTaskId(taskId);
+ final int taskIndex = RuntimeIdManager.getIndexFromTaskId(taskId);
+ final int attempt = RuntimeIdManager.getAttemptFromTaskId(taskId);
+
+ final List<TaskState> otherAttemptsforTheSameTaskIndex =
+ new ArrayList<>(stageIdToTaskAttemptStates.get(stageId).get(taskIndex));
+ otherAttemptsforTheSameTaskIndex.remove(attempt);
+
+ return otherAttemptsforTheSameTaskIndex.stream()
+ .map(state -> (TaskState.State) state.getStateMachine().getCurrentState())
+ .collect(Collectors.toList());
+ }
+
/**
* (PRIVATE METHOD)
* Updates the state of a stage.
@@ -216,7 +274,7 @@ public final class PlanStateManager {
*/
private void onStageStateChanged(final String stageId, final StageState.State newStageState) {
// Change stage state
- final StateMachine stageStateMachine = idToStageStates.get(stageId).getStateMachine();
+ final StateMachine stageStateMachine = stageIdToState.get(stageId).getStateMachine();
metricStore.getOrCreateMetric(StageMetric.class, stageId)
.addEvent(getStageState(stageId), newStageState);
@@ -224,10 +282,14 @@ public final class PlanStateManager {
LOG.debug("Stage State Transition: id {} from {} to {}",
new Object[]{stageId, stageStateMachine.getCurrentState(), newStageState});
- stageStateMachine.setState(newStageState);
+ try {
+ stageStateMachine.setState(newStageState);
+ } catch (IllegalStateTransitionException e) {
+ throw new RuntimeException(stageId + " - Illegal stage state transition ", e);
+ }
// Change plan state if needed
- final boolean allStagesCompleted = idToStageStates.values().stream().allMatch(state ->
+ final boolean allStagesCompleted = stageIdToState.values().stream().allMatch(state ->
state.getStateMachine().getCurrentState().equals(StageState.State.COMPLETE));
if (allStagesCompleted) {
onPlanStateChanged(PlanState.State.COMPLETE);
@@ -244,7 +306,12 @@ public final class PlanStateManager {
.addEvent((PlanState.State) planState.getStateMachine().getCurrentState(), newState);
metricStore.triggerBroadcast(JobMetric.class, planId);
- planState.getStateMachine().setState(newState);
+
+ try {
+ planState.getStateMachine().setState(newState);
+ } catch (IllegalStateTransitionException e) {
+ throw new RuntimeException(planId + " - Illegal plan state transition ", e);
+ }
if (newState == PlanState.State.EXECUTING) {
LOG.debug("Executing Plan ID {}...", this.planId);
@@ -260,7 +327,7 @@ public final class PlanStateManager {
finishLock.unlock();
}
} else {
- throw new IllegalStateTransitionException(new Exception("Illegal Plan State Transition"));
+ throw new RuntimeException("Illegal Plan State Transition");
}
}
@@ -321,24 +388,18 @@ public final class PlanStateManager {
}
public synchronized StageState.State getStageState(final String stageId) {
- return (StageState.State) idToStageStates.get(stageId).getStateMachine().getCurrentState();
+ return (StageState.State) stageIdToState.get(stageId).getStateMachine().getCurrentState();
}
public synchronized TaskState.State getTaskState(final String taskId) {
- return (TaskState.State) idToTaskStates.get(taskId).getStateMachine().getCurrentState();
+ return (TaskState.State) getTaskStateHelper(taskId).getStateMachine().getCurrentState();
}
- public synchronized int getTaskAttempt(final String taskId) {
- if (taskIdToCurrentAttempt.containsKey(taskId)) {
- return taskIdToCurrentAttempt.get(taskId);
- } else {
- throw new IllegalStateException("No mapping for this task's attemptIdx, an inconsistent state occurred.");
- }
- }
-
- @VisibleForTesting
- public synchronized Map<String, TaskState> getAllTaskStates() {
- return idToTaskStates;
+ private TaskState getTaskStateHelper(final String taskId) {
+ return stageIdToTaskAttemptStates
+ .get(RuntimeIdManager.getStageIdFromTaskId(taskId))
+ .get(RuntimeIdManager.getIndexFromTaskId(taskId))
+ .get(RuntimeIdManager.getAttemptFromTaskId(taskId));
}
/**
@@ -381,24 +442,45 @@ public final class PlanStateManager {
sb.append(", ");
}
isFirstStage = false;
- final StageState stageState = idToStageStates.get(stage.getId());
+ final StageState stageState = stageIdToState.get(stage.getId());
sb.append("{\"id\": \"").append(stage.getId()).append("\", ");
sb.append("\"state\": \"").append(stageState.toString()).append("\", ");
sb.append("\"tasks\": [");
boolean isFirstTask = true;
- for (final String taskId : stage.getTaskIds()) {
+ // One JSON entry per task attempt (not per task index), keyed by the attempt id.
+ // The entries come from a HashMap, so attempts are emitted in no particular order.
+ for (final Map.Entry<String, TaskState.State> entry : getTaskAttemptIdsToItsState(stage.getId()).entrySet()) {
if (!isFirstTask) {
sb.append(", ");
}
isFirstTask = false;
- final TaskState taskState = idToTaskStates.get(taskId);
- sb.append("{\"id\": \"").append(taskId).append("\", ");
- sb.append("\"state\": \"").append(taskState.toString()).append("\"}");
+ sb.append("{\"id\": \"").append(entry.getKey()).append("\", ");
+ sb.append("\"state\": \"").append(entry.getValue().toString()).append("\"}");
}
sb.append("]}");
}
sb.append("]}");
return sb.toString();
}
+
+ /**
+ * Snapshots the current state of every task attempt of every stage in the physical plan.
+ *
+ * @return a map from task attempt id to that attempt's current state.
+ */
+ @VisibleForTesting
+ public synchronized Map<String, TaskState.State> getAllTaskAttemptIdsToItsState() {
+ return physicalPlan.getStageDAG().getVertices().stream()
+ .flatMap(stage -> getTaskAttemptIdsToItsState(stage.getId()).entrySet().stream())
+ .collect(Collectors.toMap(entry -> entry.getKey(), entry -> entry.getValue()));
+ }
+
+ /**
+ * Collects the current state of every attempt of every task in one stage.
+ *
+ * @param stageId the stage whose task attempts are collected.
+ * @return a map from task attempt id (rebuilt from stage id, task index and attempt number) to its current state.
+ */
+ private Map<String, TaskState.State> getTaskAttemptIdsToItsState(final String stageId) {
+ final List<List<TaskState>> perIndexAttempts = stageIdToTaskAttemptStates.get(stageId);
+ final Map<String, TaskState.State> attemptIdToState = new HashMap<>();
+ int index = 0;
+ for (final List<TaskState> attemptsOfIndex : perIndexAttempts) {
+ int attemptNo = 0;
+ for (final TaskState attemptState : attemptsOfIndex) {
+ final String attemptId = RuntimeIdManager.generateTaskId(stageId, index, attemptNo);
+ attemptIdToState.put(attemptId, (TaskState.State) attemptState.getStateMachine().getCurrentState());
+ attemptNo++;
+ }
+ index++;
+ }
+ return attemptIdToState;
+ }
}
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java
index d1fa2df..bfc69d2 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java
@@ -78,10 +78,8 @@ public final class RuntimeMaster {
private final ExecutorService runtimeMasterThread;
private final Scheduler scheduler;
private final ContainerManager containerManager;
- private final BlockManagerMaster blockManagerMaster;
private final MetricMessageHandler metricMessageHandler;
private final MessageEnvironment masterMessageEnvironment;
- private final MetricStore metricStore;
private final ClientRPC clientRPC;
private final MetricManagerMaster metricManagerMaster;
// For converting json data. This is a thread safe.
@@ -96,7 +94,6 @@ public final class RuntimeMaster {
@Inject
private RuntimeMaster(final Scheduler scheduler,
final ContainerManager containerManager,
- final BlockManagerMaster blockManagerMaster,
final MetricMessageHandler metricMessageHandler,
final MessageEnvironment masterMessageEnvironment,
final ClientRPC clientRPC,
@@ -110,7 +107,6 @@ public final class RuntimeMaster {
Executors.newSingleThreadExecutor(runnable -> new Thread(runnable, "RuntimeMaster thread"));
this.scheduler = scheduler;
this.containerManager = containerManager;
- this.blockManagerMaster = blockManagerMaster;
this.metricMessageHandler = metricMessageHandler;
this.masterMessageEnvironment = masterMessageEnvironment;
this.masterMessageEnvironment
@@ -121,7 +117,6 @@ public final class RuntimeMaster {
this.irVertices = new HashSet<>();
this.resourceRequestCount = new AtomicInteger(0);
this.objectMapper = new ObjectMapper();
- this.metricStore = MetricStore.getStore();
this.metricServer = startRestMetricServer();
}
@@ -157,8 +152,7 @@ public final class RuntimeMaster {
final Callable<Pair<PlanStateManager, ScheduledExecutorService>> planExecutionCallable = () -> {
this.irVertices.addAll(plan.getIdToIRVertex().values());
try {
- blockManagerMaster.initialize(plan);
- final PlanStateManager planStateManager = new PlanStateManager(plan, metricMessageHandler, maxScheduleAttempt);
+ // The plan is no longer pushed into BlockManagerMaster up front; block ids appear to be registered
+ // per task attempt when the scheduler schedules producer tasks (TODO confirm).
+ final PlanStateManager planStateManager = new PlanStateManager(plan, maxScheduleAttempt);
scheduler.schedulePlan(plan, planStateManager);
final ScheduledExecutorService dagLoggingExecutor = scheduleDagLogging(planStateManager);
return Pair.of(planStateManager, dagLoggingExecutor);
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/resource/ExecutorRepresenter.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/resource/ExecutorRepresenter.java
index d5c2e89..e95b260 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/resource/ExecutorRepresenter.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/resource/ExecutorRepresenter.java
@@ -17,7 +17,7 @@ package edu.snu.nemo.runtime.master.resource;
import com.google.protobuf.ByteString;
import edu.snu.nemo.common.ir.vertex.executionproperty.ResourceSlotProperty;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.comm.ControlMessage;
import edu.snu.nemo.runtime.common.message.MessageEnvironment;
import edu.snu.nemo.runtime.common.message.MessageSender;
@@ -110,7 +110,7 @@ public final class ExecutorRepresenter {
+ // Serialize the task and send it to the executor's message listener as a ScheduleTask control message.
final byte[] serialized = SerializationUtils.serialize(task);
sendControlMessage(
ControlMessage.Message.newBuilder()
- .setId(RuntimeIdGenerator.generateMessageId())
+ .setId(RuntimeIdManager.generateMessageId())
.setListenerId(MessageEnvironment.EXECUTOR_MESSAGE_LISTENER_ID)
.setType(ControlMessage.MessageType.ScheduleTask)
.setScheduleTaskMsg(
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/resource/ResourceSpecification.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/resource/ResourceSpecification.java
index 173a477..d0b3958 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/resource/ResourceSpecification.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/resource/ResourceSpecification.java
@@ -15,7 +15,7 @@
*/
package edu.snu.nemo.runtime.master.resource;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
/**
* Represents the specifications of a resource.
@@ -38,7 +38,7 @@ public final class ResourceSpecification {
final int capacity,
final int memory,
final int poisonSec) {
+ // Each specification gets its own id from RuntimeIdManager.generateResourceSpecId().
- this.resourceSpecId = RuntimeIdGenerator.generateResourceSpecId();
+ this.resourceSpecId = RuntimeIdManager.generateResourceSpecId();
this.containerType = containerType;
this.capacity = capacity;
this.memory = memory;
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/BatchScheduler.java
index 77881c0..71e7d3f 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/BatchScheduler.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/BatchScheduler.java
@@ -22,10 +22,9 @@ import edu.snu.nemo.common.ir.Readable;
import edu.snu.nemo.common.ir.edge.IREdge;
import edu.snu.nemo.common.ir.edge.executionproperty.MetricCollectionProperty;
import edu.snu.nemo.common.ir.vertex.IRVertex;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.eventhandler.DynamicOptimizationEvent;
import edu.snu.nemo.runtime.common.plan.*;
-import edu.snu.nemo.runtime.common.state.BlockState;
import edu.snu.nemo.runtime.common.state.TaskState;
import edu.snu.nemo.runtime.master.DataSkewDynOptDataHandler;
import edu.snu.nemo.runtime.master.DynOptDataHandler;
@@ -43,7 +42,6 @@ import javax.annotation.concurrent.NotThreadSafe;
import javax.inject.Inject;
import java.util.*;
import java.util.stream.Collectors;
-import java.util.stream.Stream;
import org.slf4j.Logger;
@@ -106,8 +104,6 @@ public final class BatchScheduler implements Scheduler {
*/
@Override
public void schedulePlan(final PhysicalPlan submittedPhysicalPlan, final PlanStateManager submittedPlanStateManager) {
- LOG.info("Scheduled plan");
-
+ // Keep the submitted plan and its state manager; the scheduling methods below read them.
this.physicalPlan = submittedPhysicalPlan;
this.planStateManager = submittedPlanStateManager;
@@ -147,73 +143,59 @@ public final class BatchScheduler implements Scheduler {
final TaskState.State newState,
@Nullable final String vertexPutOnHold,
final TaskState.RecoverableTaskFailureCause failureCause) {
- final int currentTaskAttemptIndex = planStateManager.getTaskAttempt(taskId);
-
- if (taskAttemptIndex == currentTaskAttemptIndex) {
- // Do change state, as this notification is for the current task attempt.
- planStateManager.onTaskStateChanged(taskId, newState);
- switch (newState) {
- case COMPLETE:
- onTaskExecutionComplete(executorId, taskId);
- break;
- case SHOULD_RETRY:
- // SHOULD_RETRY from an executor means that the task ran into a recoverable failure
- onTaskExecutionFailedRecoverable(executorId, taskId, failureCause);
- break;
- case ON_HOLD:
- onTaskExecutionOnHold(executorId, taskId);
- break;
- case FAILED:
- throw new UnrecoverableFailureException(new Exception(new StringBuffer().append("The plan failed on Task #")
- .append(taskId).append(" in Executor ").append(executorId).toString()));
- case READY:
- case EXECUTING:
- throw new IllegalStateTransitionException(
- new Exception("The states READY/EXECUTING cannot occur at this point"));
- default:
- throw new UnknownExecutionStateException(new Exception("This TaskState is unknown: " + newState));
- }
+ // Each task attempt has its own state, looked up by the attempt id, so the report is applied to that attempt.
+ // NOTE(review): the old attempt-index check ignored late reports (e.g. COMPLETE after the executor was removed
+ // and the attempt marked SHOULD_RETRY); confirm such transitions are legal on the attempt's own state machine.
+ planStateManager.onTaskStateChanged(taskId, newState);
+ switch (newState) {
+ case COMPLETE:
+ onTaskExecutionComplete(executorId, taskId);
+ break;
+ case SHOULD_RETRY:
+ // SHOULD_RETRY from an executor means that the task ran into a recoverable failure
+ onTaskExecutionFailedRecoverable(executorId, taskId, failureCause);
+ break;
+ case ON_HOLD:
+ onTaskExecutionOnHold(executorId, taskId);
+ break;
+ case FAILED:
+ throw new UnrecoverableFailureException(new Exception(new StringBuffer().append("The plan failed on Task #")
+ .append(taskId).append(" in Executor ").append(executorId).toString()));
+ case READY:
+ case EXECUTING:
+ throw new RuntimeException("The states READY/EXECUTING cannot occur at this point");
+ default:
+ throw new UnknownExecutionStateException(new Exception("This TaskState is unknown: " + newState));
+ }
- // Invoke doSchedule()
- switch (newState) {
- case COMPLETE:
- case ON_HOLD:
- // If the stage has completed
- final String stageIdForTaskUponCompletion = RuntimeIdGenerator.getStageIdFromTaskId(taskId);
- if (planStateManager.getStageState(stageIdForTaskUponCompletion).equals(StageState.State.COMPLETE)) {
- if (!planStateManager.isPlanDone()) {
- doSchedule();
- }
+ // Invoke doSchedule()
+ switch (newState) {
+ case COMPLETE:
+ case ON_HOLD:
+ // If the stage has completed
+ final String stageIdForTaskUponCompletion = RuntimeIdManager.getStageIdFromTaskId(taskId);
+ if (planStateManager.getStageState(stageIdForTaskUponCompletion).equals(StageState.State.COMPLETE)) {
+ if (!planStateManager.isPlanDone()) {
+ doSchedule();
}
- break;
- case SHOULD_RETRY:
- // Do retry
- doSchedule();
- break;
- default:
- break;
- }
+ }
+ break;
+ case SHOULD_RETRY:
+ // Do retry
+ doSchedule();
+ break;
+ default:
+ break;
+ }
- // Invoke taskDispatcher.onExecutorSlotAvailable()
- switch (newState) {
- // These three states mean that a slot is made available.
- case COMPLETE:
- case ON_HOLD:
- case SHOULD_RETRY:
- taskDispatcher.onExecutorSlotAvailable();
- break;
- default:
- break;
- }
- } else if (taskAttemptIndex < currentTaskAttemptIndex) {
- // Do not change state, as this report is from a previous task attempt.
- // For example, the master can receive a notification that an executor has been removed,
- // and then a notification that the task that was running in the removed executor has been completed.
- // In this case, if we do not consider the attempt number, the state changes from SHOULD_RETRY to COMPLETED,
- // which is illegal.
- LOG.info("{} state change to {} arrived late, we will ignore this.", new Object[]{taskId, newState});
- } else {
- throw new SchedulingException(new Throwable("AttemptIdx for a task cannot be greater than its current index"));
+ // Invoke taskDispatcher.onExecutorSlotAvailable()
+ switch (newState) {
+ // These three states mean that a slot is made available.
+ case COMPLETE:
+ case ON_HOLD:
+ case SHOULD_RETRY:
+ taskDispatcher.onExecutorSlotAvailable();
+ break;
+ default:
+ break;
}
}
@@ -274,7 +256,7 @@ public final class BatchScheduler implements Scheduler {
+ // Log the distinct ids of the stages that the tasks about to be scheduled belong to.
LOG.info("Scheduling some tasks in {}, which are in the same ScheduleGroup", tasksToSchedule.stream()
.map(Task::getTaskId)
- .map(RuntimeIdGenerator::getStageIdFromTaskId)
+ .map(RuntimeIdManager::getStageIdFromTaskId)
.collect(Collectors.toSet()));
// Set the pointer to the schedulable tasks.
@@ -306,41 +288,22 @@ public final class BatchScheduler implements Scheduler {
final List<StageEdge> stageOutgoingEdges =
physicalPlan.getStageDAG().getOutgoingEdgesOf(stageToSchedule.getId());
- final List<String> taskIdsToSchedule = new LinkedList<>();
- for (final String taskId : stageToSchedule.getTaskIds()) {
- final TaskState.State taskState = planStateManager.getTaskState(taskId);
-
- switch (taskState) {
- // Don't schedule these.
- case COMPLETE:
- case EXECUTING:
- case ON_HOLD:
- break;
-
- // These are schedulable.
- case SHOULD_RETRY:
- planStateManager.onTaskStateChanged(taskId, TaskState.State.READY);
- case READY:
- taskIdsToSchedule.add(taskId);
- break;
-
- // This shouldn't happen.
- default:
- throw new SchedulingException(new Throwable("Detected a FAILED Task"));
- }
- }
-
// Create and return tasks.
final List<Map<String, Readable>> vertexIdToReadables = stageToSchedule.getVertexIdToReadables();
+
+ final List<String> taskIdsToSchedule = planStateManager.getTaskAttemptsToSchedule(stageToSchedule.getId());
final List<Task> tasks = new ArrayList<>(taskIdsToSchedule.size());
taskIdsToSchedule.forEach(taskId -> {
- blockManagerMaster.onProducerTaskScheduled(taskId); // Notify the block manager early for push edges.
- final int taskIdx = RuntimeIdGenerator.getIndexFromTaskId(taskId);
- final int attemptIdx = planStateManager.getTaskAttempt(taskId);
+ // Ids of the blocks this task attempt will produce: one per outgoing stage edge.
+ // Every attempt here comes from stageToSchedule, so reuse its already-fetched outgoing edges
+ // instead of querying the stage DAG again for each task.
+ final Set<String> blockIds = stageOutgoingEdges.stream()
+ .map(stageEdge -> RuntimeIdManager.generateBlockId(stageEdge.getId(), taskId))
+ .collect(Collectors.toSet());
+ blockManagerMaster.onProducerTaskScheduled(taskId, blockIds);
+ final int taskIdx = RuntimeIdManager.getIndexFromTaskId(taskId);
tasks.add(new Task(
physicalPlan.getId(),
taskId,
- attemptIdx,
stageToSchedule.getExecutionProperties(),
stageToSchedule.getSerializedIRDAG(),
stageIncomingEdges,
@@ -367,10 +330,12 @@ public final class BatchScheduler implements Scheduler {
});
}
- public IREdge getEdgeToOptimize(final String taskId) {
+ private IREdge getEdgeToOptimize(final String taskId) {
// Get a stage including the given task
+ // The stage id is parsed once from the task attempt id rather than once per candidate stage.
+ final String stageIdOfTaskOnHold = RuntimeIdManager.getStageIdFromTaskId(taskId);
final Stage stagePutOnHold = physicalPlan.getStageDAG().getVertices().stream()
- .filter(stage -> stage.getTaskIds().contains(taskId)).findFirst().get();
+ .filter(stage -> stage.getId().equals(stageIdOfTaskOnHold))
+ .findFirst()
+ .orElseThrow(() -> new IllegalStateException("No stage found for task " + taskId));
// Get outgoing edges of that stage with MetricCollectionProperty
List<StageEdge> stageEdges = physicalPlan.getStageDAG().getOutgoingEdgesOf(stagePutOnHold);
@@ -400,7 +365,7 @@ public final class BatchScheduler implements Scheduler {
executor.onTaskExecutionComplete(taskId);
return Pair.of(executor, state);
});
- final String stageIdForTaskUponCompletion = RuntimeIdGenerator.getStageIdFromTaskId(taskId);
+ // The stage id is parsed from the task attempt id.
+ final String stageIdForTaskUponCompletion = RuntimeIdManager.getStageIdFromTaskId(taskId);
final boolean stageComplete =
planStateManager.getStageState(stageIdForTaskUponCompletion).equals(StageState.State.COMPLETE);
@@ -464,32 +429,32 @@ public final class BatchScheduler implements Scheduler {
return Collections.emptySet();
}
- final Set<String> selectedParentTasks = children.stream()
+ final Set<String> parentsWithLostBlocks = children.stream()
.flatMap(child -> getParentTasks(child).stream())
- .filter(parent -> blockManagerMaster.getIdsOfBlocksProducedBy(parent).stream()
- .map(blockManagerMaster::getBlockState)
- .anyMatch(blockState -> blockState.equals(BlockState.State.NOT_AVAILABLE)) // If a block is missing
- )
+ // NOTE(review): `parent` is a task attempt id, but getBlockLocationHandler() is given block ids elsewhere
+ // (SourceLocationAwareSchedulingConstraint, BlockManagerMasterTest). BlockManagerMasterTest also expects a
+ // missing block's location future to fail with AbsentBlockException; if that is an exceptional completion
+ // rather than a cancellation, isCancelled() will not detect it. Confirm lost blocks are actually found here.
+ .filter(parent -> blockManagerMaster.getBlockLocationHandler(parent).getLocationFuture().isCancelled())
.collect(Collectors.toSet());
// Recursive call
- return Sets.union(selectedParentTasks, recursivelyGetParentTasksForLostBlocks(selectedParentTasks));
+ return Sets.union(parentsWithLostBlocks, recursivelyGetParentTasksForLostBlocks(parentsWithLostBlocks));
}
private Set<String> getParentTasks(final String childTaskId) {
- final String stageIdOfChildTask = RuntimeIdGenerator.getStageIdFromTaskId(childTaskId);
+ final String stageIdOfChildTask = RuntimeIdManager.getStageIdFromTaskId(childTaskId);
return physicalPlan.getStageDAG().getIncomingEdgesOf(stageIdOfChildTask)
.stream()
.flatMap(inStageEdge -> {
- final List<String> tasksOfParentStage = inStageEdge.getSrc().getTaskIds();
+ final String parentStageId = inStageEdge.getSrc().getId();
+ // Attempt ids of the parent stage, as tracked by the PlanStateManager.
+ final Set<String> tasksOfParentStage = planStateManager.getAllTaskAttemptsOfStage(parentStageId);
+
switch (inStageEdge.getDataCommunicationPattern()) {
case Shuffle:
case BroadCast:
- // All of the parent stage's tasks are parents
+ // All of the parent stage's tasks (every attempt of every index)
return tasksOfParentStage.stream();
case OneToOne:
- // Only one of the parent stage's tasks is a parent
- return Stream.of(tasksOfParentStage.get(RuntimeIdGenerator.getIndexFromTaskId(childTaskId)));
+ // Same-index tasks of the parent stage (every attempt of the child's index)
+ return tasksOfParentStage.stream().filter(task ->
+ RuntimeIdManager.getIndexFromTaskId(task) == RuntimeIdManager.getIndexFromTaskId(childTaskId));
default:
throw new IllegalStateException(inStageEdge.toString());
}
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/NodeShareSchedulingConstraint.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/NodeShareSchedulingConstraint.java
index a2c0935..6ec8eb4 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/NodeShareSchedulingConstraint.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/NodeShareSchedulingConstraint.java
@@ -17,7 +17,7 @@ package edu.snu.nemo.runtime.master.scheduler;
import edu.snu.nemo.common.ir.executionproperty.AssociatedProperty;
import edu.snu.nemo.common.ir.vertex.executionproperty.ResourceSiteProperty;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.plan.Task;
import edu.snu.nemo.runtime.master.resource.ExecutorRepresenter;
@@ -57,7 +57,7 @@ public final class NodeShareSchedulingConstraint implements SchedulingConstraint
}
try {
return executor.getNodeName().equals(
- getNodeName(propertyValue, RuntimeIdGenerator.getIndexFromTaskId(task.getTaskId())));
+ getNodeName(propertyValue, RuntimeIdManager.getIndexFromTaskId(task.getTaskId())));
} catch (final IllegalStateException e) {
- throw new RuntimeException(String.format("Cannot schedule %s", task.getTaskId(), e));
+ // Keep the IllegalStateException as the cause; passed as a surplus format argument it would be discarded.
+ throw new RuntimeException(String.format("Cannot schedule %s", task.getTaskId()), e);
}
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SkewnessAwareSchedulingConstraint.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SkewnessAwareSchedulingConstraint.java
index 6e5ba4c..cf593a1 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SkewnessAwareSchedulingConstraint.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SkewnessAwareSchedulingConstraint.java
@@ -20,7 +20,7 @@ import edu.snu.nemo.common.ir.edge.executionproperty.CommunicationPatternPropert
import edu.snu.nemo.common.ir.edge.executionproperty.DataSkewMetricProperty;
import edu.snu.nemo.common.ir.executionproperty.AssociatedProperty;
import edu.snu.nemo.common.ir.vertex.executionproperty.ResourceSkewedDataProperty;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.common.HashRange;
import edu.snu.nemo.common.KeyRange;
import edu.snu.nemo.runtime.common.plan.StageEdge;
@@ -45,7 +45,7 @@ public final class SkewnessAwareSchedulingConstraint implements SchedulingConstr
}
public boolean hasSkewedData(final Task task) {
+ // Index of this task within its stage, parsed from the task (attempt) id.
- final int taskIdx = RuntimeIdGenerator.getIndexFromTaskId(task.getTaskId());
+ final int taskIdx = RuntimeIdManager.getIndexFromTaskId(task.getTaskId());
for (StageEdge inEdge : task.getTaskIncomingEdges()) {
if (CommunicationPatternProperty.Value.Shuffle
.equals(inEdge.getDataCommunicationPattern())) {
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SourceLocationAwareSchedulingConstraint.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SourceLocationAwareSchedulingConstraint.java
index d355aa7..f5e270a 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SourceLocationAwareSchedulingConstraint.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SourceLocationAwareSchedulingConstraint.java
@@ -19,7 +19,7 @@ import edu.snu.nemo.common.ir.Readable;
import edu.snu.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import edu.snu.nemo.common.ir.executionproperty.AssociatedProperty;
import edu.snu.nemo.common.ir.vertex.executionproperty.ResourceLocalityProperty;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.plan.StageEdge;
import edu.snu.nemo.runtime.common.plan.Task;
import edu.snu.nemo.runtime.master.BlockManagerMaster;
@@ -60,9 +60,8 @@ public final class SourceLocationAwareSchedulingConstraint implements Scheduling
physicalStageEdge.getPropertyValue(CommunicationPatternProperty.class)
.orElseThrow(() -> new RuntimeException("No comm pattern!")))) {
+ // NOTE(review): the block id below is built from `task`'s own id (presumably the consumer of
+ // physicalStageEdge — confirm). BatchScheduler builds block ids from the *producing* task attempt's id,
+ // so this may not name any block that was actually written (stage ids and attempt numbers differ).
final String blockIdToRead =
- RuntimeIdGenerator.generateBlockId(physicalStageEdge.getId(),
- RuntimeIdGenerator.getIndexFromTaskId(task.getTaskId()));
- final BlockManagerMaster.BlockLocationRequestHandler locationHandler =
+ RuntimeIdManager.generateBlockId(physicalStageEdge.getId(), task.getTaskId());
+ final BlockManagerMaster.BlockRequestHandler locationHandler =
blockManagerMaster.getBlockLocationHandler(blockIdToRead);
if (locationHandler.getLocationFuture().isDone()) { // if the location is known.
try {
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/TaskDispatcher.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/TaskDispatcher.java
index e6feaa5..0dbee50 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/TaskDispatcher.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/TaskDispatcher.java
@@ -46,7 +46,6 @@ import javax.inject.Inject;
@NotThreadSafe
final class TaskDispatcher {
private static final Logger LOG = LoggerFactory.getLogger(TaskDispatcher.class.getName());
- private final Map<String, PlanStateManager> planStateManagers;
private final PendingTaskCollectionPointer pendingTaskCollectionPointer;
private final ExecutorService schedulerThread;
private boolean isSchedulerRunning;
@@ -57,12 +56,13 @@ final class TaskDispatcher {
private final SchedulingConstraintRegistry schedulingConstraintRegistry;
private final SchedulingPolicy schedulingPolicy;
+ // State manager of the single plan being dispatched; assigned in run() before the scheduler thread is started.
+ // NOTE(review): this field is not volatile; if run() is called again while the scheduler thread is already
+ // running, the new value is published without synchronization — confirm run() is called once per dispatcher.
+ private PlanStateManager planStateManager;
+
@Inject
private TaskDispatcher(final SchedulingConstraintRegistry schedulingConstraintRegistry,
final SchedulingPolicy schedulingPolicy,
final PendingTaskCollectionPointer pendingTaskCollectionPointer,
final ExecutorRegistry executorRegistry) {
- this.planStateManagers = new HashMap<>();
this.pendingTaskCollectionPointer = pendingTaskCollectionPointer;
this.schedulerThread = Executors.newSingleThreadExecutor(runnable ->
new Thread(runnable, "TaskDispatcher thread"));
@@ -84,13 +84,12 @@ final class TaskDispatcher {
doScheduleTaskList();
schedulingIteration.await();
}
- planStateManagers.values().forEach(planStateManager -> {
- if (planStateManager.isPlanDone()) {
- LOG.info("{} is complete.", planStateManager.getPlanId());
- } else {
- LOG.info("{} is incomplete.", planStateManager.getPlanId());
- }
- });
+
+ // Report whether the dispatched plan finished before the dispatcher terminated.
+ if (planStateManager.isPlanDone()) {
+ LOG.info("{} is complete.", planStateManager.getPlanId());
+ } else {
+ LOG.info("{} is incomplete.", planStateManager.getPlanId());
+ }
LOG.info("TaskDispatcher Terminated!");
}
}
@@ -106,7 +105,6 @@ final class TaskDispatcher {
final Collection<Task> taskList = taskListOptional.get();
final List<Task> couldNotSchedule = new ArrayList<>();
for (final Task task : taskList) {
- final PlanStateManager planStateManager = planStateManagers.get(task.getPlanId());
if (!planStateManager.getTaskState(task.getTaskId()).equals(TaskState.State.READY)) {
// Guard against race conditions causing duplicate task launches
LOG.debug("Skipping {} as it is not READY", task.getTaskId());
@@ -163,8 +161,8 @@ final class TaskDispatcher {
/**
* Run the dispatcher thread.
+ *
+ * @param plan the state manager of the plan whose tasks are dispatched; replaces any previously set one.
*/
- void run(final PlanStateManager planStateManager) {
- planStateManagers.put(planStateManager.getPlanId(), planStateManager);
+ void run(final PlanStateManager plan) {
+ this.planStateManager = plan;
if (!isTerminated && !isSchedulerRunning) {
schedulerThread.execute(new SchedulerThread());
schedulerThread.shutdown();
diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/BlockManagerMasterTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/BlockManagerMasterTest.java
index e5b362e..04437cb 100644
--- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/BlockManagerMasterTest.java
+++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/BlockManagerMasterTest.java
@@ -15,17 +15,17 @@
*/
package edu.snu.nemo.runtime.master;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.exception.AbsentBlockException;
import edu.snu.nemo.runtime.common.message.MessageEnvironment;
import edu.snu.nemo.runtime.common.message.local.LocalMessageDispatcher;
import edu.snu.nemo.runtime.common.message.local.LocalMessageEnvironment;
import edu.snu.nemo.runtime.common.state.BlockState;
import org.apache.reef.tang.Injector;
-import org.apache.reef.tang.Tang;
import org.junit.Before;
import org.junit.Test;
+import java.util.Collections;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import static org.junit.Assert.assertEquals;
@@ -36,6 +36,8 @@ import static org.junit.Assert.assertTrue;
* Test for {@link BlockManagerMaster}.
*/
public final class BlockManagerMasterTest {
+ // Attempt numbers used to build task attempt ids with RuntimeIdManager.generateTaskId().
+ private static final int FIRST_ATTEMPT = 0;
+ private static final int SECOND_ATTEMPT = 1;
private BlockManagerMaster blockManagerMaster;
@Before
@@ -78,19 +80,18 @@ public final class BlockManagerMasterTest {
*/
@Test
public void testLostAfterCommit() throws Exception {
- final String edgeId = RuntimeIdGenerator.generateStageEdgeId("Edge-0");
+ final String edgeId = RuntimeIdManager.generateStageEdgeId("Edge0");
final int srcTaskIndex = 0;
- final String taskId = RuntimeIdGenerator.generateTaskId(srcTaskIndex, "Stage-test");
- final String executorId = RuntimeIdGenerator.generateExecutorId();
- final String blockId = RuntimeIdGenerator.generateBlockId(edgeId, srcTaskIndex);
+ final String taskId = RuntimeIdManager.generateTaskId("Stage0", srcTaskIndex, FIRST_ATTEMPT);
+ final String executorId = RuntimeIdManager.generateExecutorId();
+ final String blockId = RuntimeIdManager.generateBlockId(edgeId, taskId);
// Initially the block state is NOT_AVAILABLE.
- blockManagerMaster.initializeState(blockId, taskId);
checkBlockAbsentException(blockManagerMaster.getBlockLocationHandler(blockId).getLocationFuture(), blockId,
BlockState.State.NOT_AVAILABLE);
// The block is being IN_PROGRESS.
- blockManagerMaster.onProducerTaskScheduled(taskId);
+ blockManagerMaster.onProducerTaskScheduled(taskId, Collections.singleton(blockId));
final Future<String> future = blockManagerMaster.getBlockLocationHandler(blockId).getLocationFuture();
checkPendingFuture(future);
@@ -111,39 +112,48 @@ public final class BlockManagerMasterTest {
*/
@Test
public void testBeforeAfterCommit() throws Exception {
- final String edgeId = RuntimeIdGenerator.generateStageEdgeId("Edge-1");
+ final String edgeId = RuntimeIdManager.generateStageEdgeId("Edge1");
final int srcTaskIndex = 0;
- final String taskId = RuntimeIdGenerator.generateTaskId(srcTaskIndex, "Stage-Test");
- final String executorId = RuntimeIdGenerator.generateExecutorId();
- final String blockId = RuntimeIdGenerator.generateBlockId(edgeId, srcTaskIndex);
- // The block is being scheduled.
- blockManagerMaster.initializeState(blockId, taskId);
- blockManagerMaster.onProducerTaskScheduled(taskId);
- final Future<String> future0 = blockManagerMaster.getBlockLocationHandler(blockId).getLocationFuture();
- checkPendingFuture(future0);
+ // First attempt
+ {
+ final String firstAttemptTaskId = RuntimeIdManager.generateTaskId("Stage0", srcTaskIndex, FIRST_ATTEMPT);
+ final String firstAttemptBlockId = RuntimeIdManager.generateBlockId(edgeId, firstAttemptTaskId);
- // Producer task fails.
- blockManagerMaster.onProducerTaskFailed(taskId);
+ // The block is being scheduled.
+ blockManagerMaster.onProducerTaskScheduled(firstAttemptTaskId, Collections.singleton(firstAttemptBlockId));
+ final Future<String> future0 = blockManagerMaster.getBlockLocationHandler(firstAttemptBlockId).getLocationFuture();
+ checkPendingFuture(future0);
- // A future, previously pending on IN_PROGRESS state, is now completed exceptionally.
- checkBlockAbsentException(future0, blockId, BlockState.State.NOT_AVAILABLE);
- checkBlockAbsentException(blockManagerMaster.getBlockLocationHandler(blockId).getLocationFuture(), blockId,
- BlockState.State.NOT_AVAILABLE);
-
- // Re-scheduling the task.
- blockManagerMaster.onProducerTaskScheduled(taskId);
- final Future<String> future1 = blockManagerMaster.getBlockLocationHandler(blockId).getLocationFuture();
- checkPendingFuture(future1);
+ // Producer task fails.
+ blockManagerMaster.onProducerTaskFailed(firstAttemptTaskId);
- // Committed.
- blockManagerMaster.onBlockStateChanged(blockId, BlockState.State.AVAILABLE, executorId);
- checkBlockLocation(future1, executorId); // A future, previously pending on IN_PROGRESS state, is now resolved.
- checkBlockLocation(blockManagerMaster.getBlockLocationHandler(blockId).getLocationFuture(), executorId);
+ // A future, previously pending on IN_PROGRESS state, is now completed exceptionally.
+ checkBlockAbsentException(future0, firstAttemptBlockId, BlockState.State.NOT_AVAILABLE);
+ checkBlockAbsentException(blockManagerMaster.getBlockLocationHandler(firstAttemptBlockId).getLocationFuture(),
+ firstAttemptBlockId, BlockState.State.NOT_AVAILABLE);
+ }
- // Then removed.
- blockManagerMaster.onBlockStateChanged(blockId, BlockState.State.NOT_AVAILABLE, executorId);
- checkBlockAbsentException(blockManagerMaster.getBlockLocationHandler(blockId).getLocationFuture(), blockId,
- BlockState.State.NOT_AVAILABLE);
+ // Second attempt
+ {
+ final String secondAttemptTaskId = RuntimeIdManager.generateTaskId("Stage0", srcTaskIndex, SECOND_ATTEMPT);
+ final String secondAttemptBlockId = RuntimeIdManager.generateBlockId(edgeId, secondAttemptTaskId);
+ final String executorId = RuntimeIdManager.generateExecutorId();
+
+ // Re-scheduling the task.
+ blockManagerMaster.onProducerTaskScheduled(secondAttemptTaskId, Collections.singleton(secondAttemptBlockId));
+ final Future<String> future1 = blockManagerMaster.getBlockLocationHandler(secondAttemptBlockId).getLocationFuture();
+ checkPendingFuture(future1);
+
+ // Committed.
+ blockManagerMaster.onBlockStateChanged(secondAttemptBlockId, BlockState.State.AVAILABLE, executorId);
+ checkBlockLocation(future1, executorId); // A future, previously pending on IN_PROGRESS state, is now resolved.
+ checkBlockLocation(blockManagerMaster.getBlockLocationHandler(secondAttemptBlockId).getLocationFuture(), executorId);
+
+ // Then removed.
+ blockManagerMaster.onBlockStateChanged(secondAttemptBlockId, BlockState.State.NOT_AVAILABLE, executorId);
+ checkBlockAbsentException(blockManagerMaster.getBlockLocationHandler(secondAttemptBlockId).getLocationFuture(),
+ secondAttemptBlockId, BlockState.State.NOT_AVAILABLE);
+ }
}
}
diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/PlanStateManagerTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/PlanStateManagerTest.java
index 2003153..aeae40c 100644
--- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/PlanStateManagerTest.java
+++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/PlanStateManagerTest.java
@@ -16,7 +16,7 @@
package edu.snu.nemo.runtime.master;
import edu.snu.nemo.conf.JobConf;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.message.MessageEnvironment;
import edu.snu.nemo.runtime.common.message.local.LocalMessageDispatcher;
import edu.snu.nemo.runtime.common.message.local.LocalMessageEnvironment;
@@ -30,7 +30,6 @@ import org.apache.reef.tang.Injector;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
-import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import java.util.List;
@@ -45,16 +44,13 @@ import static org.mockito.Mockito.mock;
* Tests {@link PlanStateManager}.
*/
@RunWith(PowerMockRunner.class)
-@PrepareForTest(MetricMessageHandler.class)
public final class PlanStateManagerTest {
private static final int MAX_SCHEDULE_ATTEMPT = 2;
- private MetricMessageHandler metricMessageHandler;
+ /**
+ * Creates an injector via LocalMessageEnvironment.forkInjector(...) and binds an empty DAG directory on it.
+ * NOTE(review): {@code injector} is a local variable that is never passed anywhere, and the tests in this
+ * class construct {@link PlanStateManager} directly, so this setup appears to be dead code now that the
+ * MetricMessageHandler mock is gone — consider removing it (and any now-unused imports).
+ */
@Before
public void setUp() throws Exception {
final Injector injector = LocalMessageEnvironment.forkInjector(LocalMessageDispatcher.getInjector(),
MessageEnvironment.MASTER_COMMUNICATION_ID);
- metricMessageHandler = mock(MetricMessageHandler.class);
injector.bindVolatileParameter(JobConf.DAGDirectory.class, "");
}
@@ -66,8 +62,7 @@ public final class PlanStateManagerTest {
public void testPhysicalPlanStateChanges() throws Exception {
final PhysicalPlan physicalPlan =
TestPlanGenerator.generatePhysicalPlan(TestPlanGenerator.PlanType.TwoVerticesJoined, false);
- final PlanStateManager planStateManager =
- new PlanStateManager(physicalPlan, metricMessageHandler, MAX_SCHEDULE_ATTEMPT);
+ final PlanStateManager planStateManager = new PlanStateManager(physicalPlan, MAX_SCHEDULE_ATTEMPT);
assertEquals(planStateManager.getPlanId(), "TestPlan");
@@ -75,11 +70,11 @@ public final class PlanStateManagerTest {
for (int stageIdx = 0; stageIdx < stageList.size(); stageIdx++) {
final Stage stage = stageList.get(stageIdx);
- final List<String> taskIds = stage.getTaskIds();
+ final List<String> taskIds = planStateManager.getTaskAttemptsToSchedule(stage.getId());
taskIds.forEach(taskId -> {
planStateManager.onTaskStateChanged(taskId, TaskState.State.EXECUTING);
planStateManager.onTaskStateChanged(taskId, TaskState.State.COMPLETE);
- if (RuntimeIdGenerator.getIndexFromTaskId(taskId) == taskIds.size() - 1) {
+ if (RuntimeIdManager.getIndexFromTaskId(taskId) == taskIds.size() - 1) {
assertEquals(StageState.State.COMPLETE, planStateManager.getStageState(stage.getId()));
}
});
@@ -98,8 +93,7 @@ public final class PlanStateManagerTest {
public void testWaitUntilFinish() throws Exception {
final PhysicalPlan physicalPlan =
TestPlanGenerator.generatePhysicalPlan(TestPlanGenerator.PlanType.TwoVerticesJoined, false);
- final PlanStateManager planStateManager =
- new PlanStateManager(physicalPlan, metricMessageHandler, MAX_SCHEDULE_ATTEMPT);
+ final PlanStateManager planStateManager = new PlanStateManager(physicalPlan, MAX_SCHEDULE_ATTEMPT);
assertFalse(planStateManager.isPlanDone());
@@ -111,7 +105,7 @@ public final class PlanStateManagerTest {
// Complete the plan and check the result again.
// It has to return COMPLETE.
final List<String> tasks = physicalPlan.getStageDAG().getTopologicalSort().stream()
- .flatMap(stage -> stage.getTaskIds().stream())
+ .flatMap(stage -> planStateManager.getTaskAttemptsToSchedule(stage.getId()).stream())
.collect(Collectors.toList());
tasks.forEach(taskId -> planStateManager.onTaskStateChanged(taskId, TaskState.State.EXECUTING));
tasks.forEach(taskId -> planStateManager.onTaskStateChanged(taskId, TaskState.State.COMPLETE));
diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/BatchSchedulerTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/BatchSchedulerTest.java
index 9733c75..d98bb0d 100644
--- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/BatchSchedulerTest.java
+++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/BatchSchedulerTest.java
@@ -58,12 +58,11 @@ import static org.mockito.Mockito.mock;
*/
@RunWith(PowerMockRunner.class)
@PrepareForTest({ContainerManager.class, BlockManagerMaster.class,
- PubSubEventHandlerWrapper.class, UpdatePhysicalPlanEventHandler.class, MetricMessageHandler.class})
+ PubSubEventHandlerWrapper.class, UpdatePhysicalPlanEventHandler.class})
public final class BatchSchedulerTest {
private static final Logger LOG = LoggerFactory.getLogger(BatchSchedulerTest.class.getName());
private Scheduler scheduler;
private ExecutorRegistry executorRegistry;
- private final MetricMessageHandler metricMessageHandler = mock(MetricMessageHandler.class);
private final MessageSender<ControlMessage.Message> mockMsgSender = mock(MessageSender.class);
private static final int EXECUTOR_CAPACITY = 20;
@@ -134,7 +133,7 @@ public final class BatchSchedulerTest {
}
private void scheduleAndCheckPlanTermination(final PhysicalPlan plan) throws InjectionException {
- final PlanStateManager planStateManager = new PlanStateManager(plan, metricMessageHandler, 1);
+ final PlanStateManager planStateManager = new PlanStateManager(plan, 1);
scheduler.schedulePlan(plan, planStateManager);
// For each ScheduleGroup, test if the tasks of the next ScheduleGroup are scheduled
diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SchedulerTestUtil.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SchedulerTestUtil.java
index d31965b..8c4edf6 100644
--- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SchedulerTestUtil.java
+++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SchedulerTestUtil.java
@@ -46,7 +46,7 @@ final class SchedulerTestUtil {
// Stage has completed, so we break out of the loop.
break;
} else if (StageState.State.INCOMPLETE == stageState) {
- stage.getTaskIds().forEach(taskId -> {
+ planStateManager.getAllTaskAttemptsOfStage(stage.getId()).forEach(taskId -> {
final TaskState.State taskState = planStateManager.getTaskState(taskId);
if (TaskState.State.EXECUTING == taskState) {
sendTaskStateEventToScheduler(scheduler, executorRegistry, taskId,
diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SkewnessAwareSchedulingConstraintTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SkewnessAwareSchedulingConstraintTest.java
index e3439d6..6af6a61 100644
--- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SkewnessAwareSchedulingConstraintTest.java
+++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SkewnessAwareSchedulingConstraintTest.java
@@ -21,7 +21,7 @@ import edu.snu.nemo.common.ir.edge.executionproperty.CommunicationPatternPropert
import edu.snu.nemo.common.ir.edge.executionproperty.DataFlowProperty;
import edu.snu.nemo.common.ir.edge.executionproperty.DataSkewMetricProperty;
import edu.snu.nemo.common.ir.vertex.IRVertex;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.common.HashRange;
import edu.snu.nemo.common.KeyRange;
import edu.snu.nemo.runtime.common.plan.Stage;
@@ -46,6 +46,7 @@ import static org.mockito.Mockito.when;
@PrepareForTest({ExecutorRepresenter.class, Task.class, Stage.class, HashRange.class,
IRVertex.class, IREdge.class})
public final class SkewnessAwareSchedulingConstraintTest {
+ /** Attempt number used when generating the ids of mocked tasks. */
+ private static final int FIRST_ATTEMPT = 0;
private static StageEdge mockStageEdge(final int numSkewedHashRange,
final int numTotalHashRange) {
@@ -69,7 +70,7 @@ public final class SkewnessAwareSchedulingConstraintTest {
final IREdge dummyIREdge = new IREdge(CommunicationPatternProperty.Value.Shuffle, srcMockVertex, dstMockVertex);
dummyIREdge.setProperty(DataFlowProperty.of(DataFlowProperty.Value.Pull));
dummyIREdge.setProperty(DataSkewMetricProperty.of(new DataSkewMetricFactory(taskIdxToKeyRange)));
- final StageEdge dummyEdge = new StageEdge("Edge-0", dummyIREdge.getExecutionProperties(),
+ final StageEdge dummyEdge = new StageEdge("Edge0", dummyIREdge.getExecutionProperties(),
srcMockVertex, dstMockVertex, srcMockStage, dstMockStage, false);
return dummyEdge;
@@ -77,7 +78,7 @@ public final class SkewnessAwareSchedulingConstraintTest {
+ /**
+ * Creates a mock {@link Task} for use in scheduling-constraint tests.
+ * The task id is produced by {@code RuntimeIdManager.generateTaskId("Stage0", taskIdx, FIRST_ATTEMPT)}.
+ * @param taskIdx index of the task within "Stage0".
+ * @param inEdges edges returned by the mock's {@code getTaskIncomingEdges()}.
+ * @return the mocked task.
+ */
private static Task mockTask(final int taskIdx, final List<StageEdge> inEdges) {
final Task task = mock(Task.class);
- when(task.getTaskId()).thenReturn(RuntimeIdGenerator.generateTaskId(taskIdx, "Stage-0"));
+ when(task.getTaskId()).thenReturn(RuntimeIdManager.generateTaskId("Stage0", taskIdx, FIRST_ATTEMPT));
when(task.getTaskIncomingEdges()).thenReturn(inEdges);
return task;
}
diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/TaskRetryTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/TaskRetryTest.java
index 2746bf5..6ebc7bc 100644
--- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/TaskRetryTest.java
+++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/TaskRetryTest.java
@@ -17,14 +17,16 @@ package edu.snu.nemo.runtime.master.scheduler;
import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper;
import edu.snu.nemo.common.ir.vertex.executionproperty.ResourcePriorityProperty;
+import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.comm.ControlMessage;
+import edu.snu.nemo.runtime.common.message.MessageEnvironment;
import edu.snu.nemo.runtime.common.message.MessageSender;
+import edu.snu.nemo.runtime.common.message.local.LocalMessageEnvironment;
import edu.snu.nemo.runtime.common.plan.PhysicalPlan;
import edu.snu.nemo.runtime.common.state.PlanState;
import edu.snu.nemo.runtime.common.state.TaskState;
import edu.snu.nemo.runtime.master.BlockManagerMaster;
import edu.snu.nemo.runtime.master.PlanStateManager;
-import edu.snu.nemo.runtime.master.MetricMessageHandler;
import edu.snu.nemo.runtime.master.eventhandler.UpdatePhysicalPlanEventHandler;
import edu.snu.nemo.runtime.master.resource.ExecutorRepresenter;
import edu.snu.nemo.runtime.master.resource.ResourceSpecification;
@@ -58,7 +60,7 @@ import static org.mockito.Mockito.mock;
*/
@RunWith(PowerMockRunner.class)
@PrepareForTest({BlockManagerMaster.class, TaskDispatcher.class, SchedulingConstraintRegistry.class,
- PubSubEventHandlerWrapper.class, UpdatePhysicalPlanEventHandler.class, MetricMessageHandler.class})
+ PubSubEventHandlerWrapper.class, UpdatePhysicalPlanEventHandler.class})
public final class TaskRetryTest {
@Rule public TestName testName = new TestName();
@@ -88,7 +90,7 @@ public final class TaskRetryTest {
injector.bindVolatileInstance(PubSubEventHandlerWrapper.class, mock(PubSubEventHandlerWrapper.class));
injector.bindVolatileInstance(UpdatePhysicalPlanEventHandler.class, mock(UpdatePhysicalPlanEventHandler.class));
injector.bindVolatileInstance(SchedulingConstraintRegistry.class, mock(SchedulingConstraintRegistry.class));
- injector.bindVolatileInstance(BlockManagerMaster.class, mock(BlockManagerMaster.class));
+ injector.bindVolatileInstance(MessageEnvironment.class, mock(MessageEnvironment.class));
scheduler = injector.getInstance(Scheduler.class);
// Get PlanStateManager
@@ -182,7 +184,7 @@ public final class TaskRetryTest {
final int randomIndex = random.nextInt(executingTasks.size());
final String selectedTask = executingTasks.get(randomIndex);
SchedulerTestUtil.sendTaskStateEventToScheduler(scheduler, executorRegistry, selectedTask,
- TaskState.State.COMPLETE, planStateManager.getTaskAttempt(selectedTask));
+ TaskState.State.COMPLETE, RuntimeIdManager.getAttemptFromTaskId(selectedTask));
}
}
@@ -196,7 +198,7 @@ public final class TaskRetryTest {
final int randomIndex = random.nextInt(executingTasks.size());
final String selectedTask = executingTasks.get(randomIndex);
SchedulerTestUtil.sendTaskStateEventToScheduler(scheduler, executorRegistry, selectedTask,
- TaskState.State.SHOULD_RETRY, planStateManager.getTaskAttempt(selectedTask),
+ TaskState.State.SHOULD_RETRY, RuntimeIdManager.getAttemptFromTaskId(selectedTask),
TaskState.RecoverableTaskFailureCause.OUTPUT_WRITE_FAILURE);
}
}
@@ -204,16 +206,17 @@ public final class TaskRetryTest {
////////////////////////////////////////////////////////////////// Helper methods
+ /**
+ * Returns the keys of {@code planStateManager.getAllTaskAttemptIdsToItsState()} whose mapped state
+ * equals {@code state}.
+ * @param planStateManager the plan state manager to query.
+ * @param state the task state to filter by.
+ * @return the matching task attempt ids.
+ */
private List<String> getTasksInState(final PlanStateManager planStateManager, final TaskState.State state) {
- return planStateManager.getAllTaskStates().entrySet().stream()
- .filter(entry -> entry.getValue().getStateMachine().getCurrentState().equals(state))
+ return planStateManager.getAllTaskAttemptIdsToItsState()
+ .entrySet()
+ .stream()
+ .filter(entry -> entry.getValue().equals(state))
.map(Map.Entry::getKey)
.collect(Collectors.toList());
}
+ /**
+ * Generates a test physical plan of the given type, creates a {@link PlanStateManager} for it with
+ * {@code MAX_SCHEDULE_ATTEMPT}, and passes both to {@code scheduler.schedulePlan}.
+ * @param planType the type of test plan to generate.
+ * @return the plan state manager that was handed to the scheduler.
+ * @throws Exception propagated from plan generation or scheduling.
+ */
private PlanStateManager runPhysicalPlan(final TestPlanGenerator.PlanType planType) throws Exception {
- final MetricMessageHandler metricMessageHandler = mock(MetricMessageHandler.class);
final PhysicalPlan plan = TestPlanGenerator.generatePhysicalPlan(planType, false);
- final PlanStateManager planStateManager = new PlanStateManager(plan, metricMessageHandler, MAX_SCHEDULE_ATTEMPT);
+ final PlanStateManager planStateManager = new PlanStateManager(plan, MAX_SCHEDULE_ATTEMPT);
scheduler.schedulePlan(plan, planStateManager);
return planStateManager;
}