You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nemo.apache.org by ja...@apache.org on 2018/06/14 08:57:33 UTC
[incubator-nemo] branch master updated: [NEMO-97] Refactor
TaskExecutor and fix a sideinput bug (#31)
This is an automated email from the ASF dual-hosted git repository.
jangho pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-nemo.git
The following commit(s) were added to refs/heads/master by this push:
new 8df268f [NEMO-97] Refactor TaskExecutor and fix a sideinput bug (#31)
8df268f is described below
commit 8df268f8010f3dbf0d9304336b52d284db5ac453
Author: John Yang <jo...@gmail.com>
AuthorDate: Thu Jun 14 17:57:18 2018 +0900
[NEMO-97] Refactor TaskExecutor and fix a sideinput bug (#31)
JIRA: [NEMO-97: Refactor TaskExecutor and fix a sideinput bug](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-97)
**Major changes:**
- Enables intra-task sideinput processing
- General code refactoring for TaskExecutor
**Minor changes to note:**
- Introduces DataFetcher for reading task-external data
- Removes NCS logs which aren't very useful
**Tests for the changes:**
- Moves 'TaskExecutorTest' into the runtime package, and adds test cases including the one for checking intra-task sideinput processing
**Other comments:**
- N/A
resolves [NEMO-97](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-97)
---
.../main/java/edu/snu/nemo/common/ContextImpl.java | 6 +-
.../main/java/edu/snu/nemo/common/ir/Readable.java | 5 +-
.../nemo/common/ir/vertex/transform/Transform.java | 9 +-
.../beam/source/BeamBoundedSourceVertex.java | 3 +-
.../beam/transform/CreateViewTransform.java | 1 +
.../frontend/beam/transform/DoTransform.java | 10 +-
.../source/SparkDatasetBoundedSourceVertex.java | 12 +-
.../source/SparkTextFileBoundedSourceVertex.java | 3 +-
.../common/message/ncs/NcsMessageContext.java | 1 -
.../common/message/ncs/NcsMessageEnvironment.java | 5 -
.../common/message/ncs/NcsMessageSender.java | 4 -
.../edu/snu/nemo/runtime/common/plan/Task.java | 10 -
.../edu/snu/nemo/runtime/executor/Executor.java | 1 +
.../snu/nemo/runtime/executor/TaskExecutor.java | 757 ---------------------
.../executor/datatransfer/IRVertexDataHandler.java | 162 -----
.../runtime/executor/datatransfer/InputReader.java | 4 +-
.../executor/datatransfer/OutputCollectorImpl.java | 58 +-
.../nemo/runtime/executor/task/DataFetcher.java | 72 ++
.../executor/task/ParentTaskDataFetcher.java | 142 ++++
.../executor/task/SourceVertexDataFetcher.java | 58 ++
.../nemo/runtime/executor/task/TaskExecutor.java | 450 ++++++++++++
.../nemo/runtime/executor/task/VertexHarness.java | 108 +++
.../runtime/executor/task/TaskExecutorTest.java | 426 ++++++++++++
.../tests/runtime/executor/TaskExecutorTest.java | 280 --------
24 files changed, 1294 insertions(+), 1293 deletions(-)
diff --git a/common/src/main/java/edu/snu/nemo/common/ContextImpl.java b/common/src/main/java/edu/snu/nemo/common/ContextImpl.java
index 2e85afa..bb5aa21 100644
--- a/common/src/main/java/edu/snu/nemo/common/ContextImpl.java
+++ b/common/src/main/java/edu/snu/nemo/common/ContextImpl.java
@@ -23,18 +23,18 @@ import java.util.Map;
* Transform Context Implementation.
*/
public final class ContextImpl implements Transform.Context {
- private final Map<Transform, Object> sideInputs;
+ private final Map sideInputs;
/**
* Constructor of Context Implementation.
* @param sideInputs side inputs.
*/
- public ContextImpl(final Map<Transform, Object> sideInputs) {
+ public ContextImpl(final Map sideInputs) {
this.sideInputs = sideInputs;
}
@Override
- public Map<Transform, Object> getSideInputs() {
+ public Map getSideInputs() {
return this.sideInputs;
}
}
diff --git a/common/src/main/java/edu/snu/nemo/common/ir/Readable.java b/common/src/main/java/edu/snu/nemo/common/ir/Readable.java
index 9bca623..8f856a4 100644
--- a/common/src/main/java/edu/snu/nemo/common/ir/Readable.java
+++ b/common/src/main/java/edu/snu/nemo/common/ir/Readable.java
@@ -15,6 +15,7 @@
*/
package edu.snu.nemo.common.ir;
+import java.io.IOException;
import java.io.Serializable;
import java.util.List;
@@ -27,9 +28,9 @@ public interface Readable<O> extends Serializable {
* Method to read data from the source.
*
* @return an {@link Iterable} of the data read by the readable.
- * @throws Exception exception while reading data.
+ * @throws IOException exception while reading data.
*/
- Iterable<O> read() throws Exception;
+ Iterable<O> read() throws IOException;
/**
* Returns the list of locations where this readable resides.
diff --git a/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/Transform.java b/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/Transform.java
index 0f7b9d5..95fa539 100644
--- a/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/Transform.java
+++ b/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/Transform.java
@@ -46,12 +46,19 @@ public interface Transform<I, O> extends Serializable {
void close();
/**
+ * @return tag
+ */
+ default Object getTag() {
+ return null;
+ }
+
+ /**
* Context of the transform.
*/
interface Context {
/**
* @return sideInputs.
*/
- Map<Transform, Object> getSideInputs();
+ Map getSideInputs();
}
}
diff --git a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/source/BeamBoundedSourceVertex.java b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/source/BeamBoundedSourceVertex.java
index 1143a1a..e309d3f 100644
--- a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/source/BeamBoundedSourceVertex.java
+++ b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/source/BeamBoundedSourceVertex.java
@@ -17,6 +17,7 @@ package edu.snu.nemo.compiler.frontend.beam.source;
import edu.snu.nemo.common.ir.Readable;
+import java.io.IOException;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
@@ -96,7 +97,7 @@ public final class BeamBoundedSourceVertex<O> extends SourceVertex<O> {
}
@Override
- public Iterable<T> read() throws Exception {
+ public Iterable<T> read() throws IOException {
final ArrayList<T> elements = new ArrayList<>();
try (BoundedSource.BoundedReader<T> reader = boundedSource.createReader(null)) {
for (boolean available = reader.start(); available; available = reader.advance()) {
diff --git a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/CreateViewTransform.java b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/CreateViewTransform.java
index b342595..dbfa004 100644
--- a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/CreateViewTransform.java
+++ b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/CreateViewTransform.java
@@ -60,6 +60,7 @@ public final class CreateViewTransform<I, O> implements Transform<I, O> {
* get the Tag of the Transform.
* @return the PCollectionView of the transform.
*/
+ @Override
public PCollectionView getTag() {
return this.pCollectionView;
}
diff --git a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/DoTransform.java b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/DoTransform.java
index 883d401..b5d9690 100644
--- a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/DoTransform.java
+++ b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/DoTransform.java
@@ -33,7 +33,6 @@ import org.apache.beam.sdk.values.TupleTag;
import org.joda.time.Instant;
import java.io.IOException;
-import java.util.HashMap;
import java.util.Map;
/**
@@ -46,7 +45,6 @@ public final class DoTransform<I, O> implements Transform<I, O> {
private final DoFn doFn;
private final ObjectMapper mapper;
private final String serializedOptions;
- private Map<PCollectionView, Object> sideInputs;
private OutputCollector<O> outputCollector;
private StartBundleContext startBundleContext;
private FinishBundleContext finishBundleContext;
@@ -72,11 +70,9 @@ public final class DoTransform<I, O> implements Transform<I, O> {
@Override
public void prepare(final Context context, final OutputCollector<O> oc) {
this.outputCollector = oc;
- this.sideInputs = new HashMap<>();
- context.getSideInputs().forEach((k, v) -> this.sideInputs.put(((CreateViewTransform) k).getTag(), v));
this.startBundleContext = new StartBundleContext(doFn, serializedOptions);
this.finishBundleContext = new FinishBundleContext(doFn, outputCollector, serializedOptions);
- this.processContext = new ProcessContext(doFn, outputCollector, sideInputs, serializedOptions);
+ this.processContext = new ProcessContext(doFn, outputCollector, context.getSideInputs(), serializedOptions);
this.invoker = DoFnInvokers.invokerFor(doFn);
invoker.invokeSetup();
invoker.invokeStartBundle(startBundleContext);
@@ -195,7 +191,7 @@ public final class DoTransform<I, O> implements Transform<I, O> {
implements DoFnInvoker.ArgumentProvider<I, O> {
private I input;
private final OutputCollector<O> outputCollector;
- private final Map<PCollectionView, Object> sideInputs;
+ private final Map sideInputs;
private final ObjectMapper mapper;
private final PipelineOptions options;
@@ -209,7 +205,7 @@ public final class DoTransform<I, O> implements Transform<I, O> {
*/
ProcessContext(final DoFn<I, O> fn,
final OutputCollector<O> outputCollector,
- final Map<PCollectionView, Object> sideInputs,
+ final Map sideInputs,
final String serializedOptions) {
fn.super();
this.outputCollector = outputCollector;
diff --git a/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/source/SparkDatasetBoundedSourceVertex.java b/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/source/SparkDatasetBoundedSourceVertex.java
index 0746be5..3b05807 100644
--- a/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/source/SparkDatasetBoundedSourceVertex.java
+++ b/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/source/SparkDatasetBoundedSourceVertex.java
@@ -23,6 +23,8 @@ import org.apache.spark.*;
import org.apache.spark.rdd.RDD;
import scala.collection.JavaConverters;
+import javax.naming.OperationNotSupportedException;
+import java.io.IOException;
import java.util.*;
/**
@@ -105,12 +107,18 @@ public final class SparkDatasetBoundedSourceVertex<T> extends SourceVertex<T> {
}
@Override
- public Iterable<T> read() throws Exception {
+ public Iterable<T> read() throws IOException {
// for setting up the same environment in the executors.
final SparkSession spark = SparkSession.builder()
.config(sessionInitialConf)
.getOrCreate();
- final Dataset<T> dataset = SparkSession.initializeDataset(spark, commands);
+ final Dataset<T> dataset;
+
+ try {
+ dataset = SparkSession.initializeDataset(spark, commands);
+ } catch (final OperationNotSupportedException e) {
+ throw new IllegalStateException(e);
+ }
// Spark does lazy evaluation: it doesn't load the full dataset, but only the partition it is asked for.
final RDD<T> rdd = dataset.sparkRDD();
diff --git a/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/source/SparkTextFileBoundedSourceVertex.java b/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/source/SparkTextFileBoundedSourceVertex.java
index 5fab794..9b2fd38 100644
--- a/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/source/SparkTextFileBoundedSourceVertex.java
+++ b/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/source/SparkTextFileBoundedSourceVertex.java
@@ -21,6 +21,7 @@ import org.apache.spark.*;
import org.apache.spark.rdd.RDD;
import scala.collection.JavaConverters;
+import java.io.IOException;
import java.util.*;
/**
@@ -109,7 +110,7 @@ public final class SparkTextFileBoundedSourceVertex extends SourceVertex<String>
}
@Override
- public Iterable<String> read() throws Exception {
+ public Iterable<String> read() throws IOException {
// for setting up the same environment in the executors.
final SparkContext sparkContext = SparkContext.getOrCreate(sparkConf);
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageContext.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageContext.java
index 0a71d71..ea478b7 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageContext.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageContext.java
@@ -49,7 +49,6 @@ final class NcsMessageContext implements MessageContext {
@Override
@SuppressWarnings("squid:S2095")
public <U> void reply(final U replyMessage) {
- LOG.debug("[REPLY]: {}", replyMessage);
final Connection connection = connectionFactory.newConnection(idFactory.getNewInstance(senderId));
try {
connection.open();
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageEnvironment.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageEnvironment.java
index 5dfcd6c..296a19e 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageEnvironment.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageEnvironment.java
@@ -36,15 +36,11 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Future;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
/**
* Message environment for NCS.
*/
public final class NcsMessageEnvironment implements MessageEnvironment {
- private static final Logger LOG = LoggerFactory.getLogger(NcsMessageEnvironment.class.getName());
-
private static final String NCS_CONN_FACTORY_ID = "NCS_CONN_FACTORY_ID";
private final NetworkConnectionService networkConnectionService;
@@ -124,7 +120,6 @@ public final class NcsMessageEnvironment implements MessageEnvironment {
public void onNext(final Message<ControlMessage.Message> messages) {
final ControlMessage.Message controlMessage = extractSingleMessage(messages);
- LOG.debug("[RECEIVED]: msg={}", controlMessage);
final MessageType messageType = getMsgType(controlMessage);
switch (messageType) {
case Send:
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageSender.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageSender.java
index 5d1c61a..517fe54 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageSender.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageSender.java
@@ -42,15 +42,11 @@ final class NcsMessageSender implements MessageSender<ControlMessage.Message> {
@Override
public void send(final ControlMessage.Message message) {
- LOG.debug("[SEND]: msg.id={}, msg.listenerId={}",
- message.getId(), message.getListenerId());
connection.write(message);
}
@Override
public CompletableFuture<ControlMessage.Message> request(final ControlMessage.Message message) {
- LOG.debug("[REQUEST]: msg.id={}, msg.listenerId={}",
- message.getId(), message.getListenerId());
final CompletableFuture<ControlMessage.Message> future = replyFutureMap.beforeRequest(message.getId());
connection.write(message);
return future;
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Task.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Task.java
index 5cd8b16..c5a1c3d 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Task.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Task.java
@@ -16,7 +16,6 @@
package edu.snu.nemo.runtime.common.plan;
import edu.snu.nemo.common.ir.Readable;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
import java.io.Serializable;
import java.util.List;
@@ -28,7 +27,6 @@ import java.util.Map;
public final class Task implements Serializable {
private final String jobId;
private final String taskId;
- private final int taskIdx;
private final List<StageEdge> taskIncomingEdges;
private final List<StageEdge> taskOutgoingEdges;
private final int attemptIdx;
@@ -58,7 +56,6 @@ public final class Task implements Serializable {
final Map<String, Readable> irVertexIdToReadable) {
this.jobId = jobId;
this.taskId = taskId;
- this.taskIdx = RuntimeIdGenerator.getIndexFromTaskId(taskId);
this.attemptIdx = attemptIdx;
this.containerType = containerType;
this.serializedIRDag = serializedIRDag;
@@ -89,13 +86,6 @@ public final class Task implements Serializable {
}
/**
- * @return the idx of the task.
- */
- public int getTaskIdx() {
- return taskIdx;
- }
-
- /**
* @return the incoming edges of the task.
*/
public List<StageEdge> getTaskIncomingEdges() {
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/Executor.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/Executor.java
index 287ca01..8b2925d 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/Executor.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/Executor.java
@@ -31,6 +31,7 @@ import edu.snu.nemo.runtime.common.plan.RuntimeEdge;
import edu.snu.nemo.runtime.common.plan.Task;
import edu.snu.nemo.runtime.executor.data.SerializerManager;
import edu.snu.nemo.runtime.executor.datatransfer.DataTransferFactory;
+import edu.snu.nemo.runtime.executor.task.TaskExecutor;
import org.apache.commons.lang3.SerializationUtils;
import org.apache.reef.tang.annotations.Parameter;
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/TaskExecutor.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/TaskExecutor.java
deleted file mode 100644
index 4a9b1ee..0000000
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/TaskExecutor.java
+++ /dev/null
@@ -1,757 +0,0 @@
-/*
- * Copyright (C) 2017 Seoul National University
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package edu.snu.nemo.runtime.executor;
-
-import edu.snu.nemo.common.ContextImpl;
-import edu.snu.nemo.common.Pair;
-import edu.snu.nemo.common.dag.DAG;
-import edu.snu.nemo.common.exception.BlockFetchException;
-import edu.snu.nemo.common.exception.BlockWriteException;
-import edu.snu.nemo.common.ir.Readable;
-import edu.snu.nemo.common.ir.vertex.*;
-import edu.snu.nemo.common.ir.vertex.transform.Transform;
-import edu.snu.nemo.runtime.common.plan.Task;
-import edu.snu.nemo.runtime.common.plan.StageEdge;
-import edu.snu.nemo.runtime.common.plan.RuntimeEdge;
-import edu.snu.nemo.runtime.common.state.TaskState;
-import edu.snu.nemo.runtime.executor.data.DataUtil;
-import edu.snu.nemo.runtime.executor.datatransfer.*;
-
-import java.util.*;
-import java.util.concurrent.*;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.stream.Collectors;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/**
- * Executes a task.
- */
-public final class TaskExecutor {
- // Static variables
- private static final Logger LOG = LoggerFactory.getLogger(TaskExecutor.class.getName());
- private static final String ITERATORID_PREFIX = "ITERATOR_";
- private static final AtomicInteger ITERATORID_GENERATOR = new AtomicInteger(0);
-
- // From Task
- private final DAG<IRVertex, RuntimeEdge<IRVertex>> irVertexDag;
- private final String taskId;
- private final int taskIdx;
- private final TaskStateManager taskStateManager;
- private final List<StageEdge> stageIncomingEdges;
- private final List<StageEdge> stageOutgoingEdges;
- private Map<String, Readable> irVertexIdToReadable;
-
- // Other parameters
- private final DataTransferFactory channelFactory;
- private final MetricCollector metricCollector;
-
- // Data structures
- private final Map<InputReader, List<IRVertexDataHandler>> inputReaderToDataHandlersMap;
- private final Map<String, Iterator> idToSrcIteratorMap;
- private final Map<String, List<IRVertexDataHandler>> srcIteratorIdToDataHandlersMap;
- private final Map<String, List<IRVertexDataHandler>> iteratorIdToDataHandlersMap;
- private final LinkedBlockingQueue<Pair<String, DataUtil.IteratorWithNumBytes>> partitionQueue;
- private List<IRVertexDataHandler> irVertexDataHandlers;
- private Map<OutputCollectorImpl, List<IRVertexDataHandler>> outputToChildrenDataHandlersMap;
- private final Set<String> finishedVertexIds;
-
- // For metrics
- private long serBlockSize;
- private long encodedBlockSize;
-
- // Misc
- private boolean isExecuted;
- private String irVertexIdPutOnHold;
- private int numPartitions;
-
-
- /**
- * Constructor.
- * @param task Task with information needed during execution.
- * @param irVertexDag A DAG of vertices.
- * @param taskStateManager State manager for this Task.
- * @param channelFactory For reading from/writing to data to other Stages.
- * @param metricMessageSender For sending metric with execution stats to Master.
- */
- public TaskExecutor(final Task task,
- final DAG<IRVertex, RuntimeEdge<IRVertex>> irVertexDag,
- final TaskStateManager taskStateManager,
- final DataTransferFactory channelFactory,
- final MetricMessageSender metricMessageSender) {
- // Information from the Task.
- this.irVertexDag = irVertexDag;
- this.taskId = task.getTaskId();
- this.taskIdx = task.getTaskIdx();
- this.stageIncomingEdges = task.getTaskIncomingEdges();
- this.stageOutgoingEdges = task.getTaskOutgoingEdges();
- this.irVertexIdToReadable = task.getIrVertexIdToReadable();
-
- // Other parameters.
- this.taskStateManager = taskStateManager;
- this.channelFactory = channelFactory;
- this.metricCollector = new MetricCollector(metricMessageSender);
-
- // Initialize data structures.
- this.inputReaderToDataHandlersMap = new ConcurrentHashMap<>();
- this.idToSrcIteratorMap = new HashMap<>();
- this.srcIteratorIdToDataHandlersMap = new HashMap<>();
- this.iteratorIdToDataHandlersMap = new ConcurrentHashMap<>();
- this.partitionQueue = new LinkedBlockingQueue<>();
- this.outputToChildrenDataHandlersMap = new HashMap<>();
- this.irVertexDataHandlers = new ArrayList<>();
- this.finishedVertexIds = new HashSet<>();
-
- // Metrics
- this.serBlockSize = 0;
- this.encodedBlockSize = 0;
-
- // Misc
- this.isExecuted = false;
- this.irVertexIdPutOnHold = null;
- this.numPartitions = 0;
-
- initialize();
- }
-
- /**
- * Initializes this Task before execution.
- * 1) Create and connect reader/writers for both inter-Task data and intra-Task data.
- * 2) Prepares Transforms if needed.
- */
- private void initialize() {
- // Initialize data handlers for each IRVertex.
- irVertexDag.topologicalDo(irVertex -> irVertexDataHandlers.add(new IRVertexDataHandler(irVertex)));
-
- // Initialize data transfer.
- // Construct a pointer-based DAG of irVertexDataHandlers that are used for data transfer.
- // 'Pointer-based' means that it isn't Map/List-based in getting the data structure or parent/children
- // to avoid element-wise extra overhead of calculating hash values(HashMap) or iterating Lists.
- irVertexDag.topologicalDo(irVertex -> {
- final Set<StageEdge> inEdgesFromOtherStages = getInEdgesFromOtherStages(irVertex);
- final Set<StageEdge> outEdgesToOtherStages = getOutEdgesToOtherStages(irVertex);
- final IRVertexDataHandler dataHandler = getIRVertexDataHandler(irVertex);
-
- // Set data handlers of children irVertices.
- // This forms a pointer-based DAG of irVertexDataHandlers.
- final List<IRVertexDataHandler> childrenDataHandlers = new ArrayList<>();
- irVertexDag.getChildren(irVertex.getId()).forEach(child ->
- childrenDataHandlers.add(getIRVertexDataHandler(child)));
- dataHandler.setChildrenDataHandler(childrenDataHandlers);
-
- // Add InputReaders for inter-stage data transfer
- inEdgesFromOtherStages.forEach(stageEdge -> {
- final InputReader inputReader = channelFactory.createReader(
- taskIdx, stageEdge.getSrcVertex(), stageEdge);
-
- // For InputReaders that have side input, collect them separately.
- if (inputReader.isSideInputReader()) {
- dataHandler.addSideInputFromOtherStages(inputReader);
- } else {
- inputReaderToDataHandlersMap.putIfAbsent(inputReader, new ArrayList<>());
- inputReaderToDataHandlersMap.get(inputReader).add(dataHandler);
- }
- });
-
- // Add OutputWriters for inter-stage data transfer
- outEdgesToOtherStages.forEach(stageEdge -> {
- final OutputWriter outputWriter = channelFactory.createWriter(
- irVertex, taskIdx, stageEdge.getDstVertex(), stageEdge);
- dataHandler.addOutputWriter(outputWriter);
- });
-
- // Add InputPipes for intra-stage data transfer
- addInputFromThisStage(irVertex, dataHandler);
-
- // Add OutputPipe for intra-stage data transfer
- setOutputCollector(irVertex, dataHandler);
- });
-
- // Prepare Transforms if needed.
- irVertexDag.topologicalDo(irVertex -> {
- if (irVertex instanceof OperatorVertex) {
- final Transform transform = ((OperatorVertex) irVertex).getTransform();
- final Map<Transform, Object> sideInputMap = new HashMap<>();
- final IRVertexDataHandler dataHandler = getIRVertexDataHandler(irVertex);
- // Check and collect side inputs.
- if (!dataHandler.getSideInputFromOtherStages().isEmpty()) {
- sideInputFromOtherStages(irVertex, sideInputMap);
- }
- if (!dataHandler.getSideInputFromThisStage().isEmpty()) {
- sideInputFromThisStage(irVertex, sideInputMap);
- }
-
- final Transform.Context transformContext = new ContextImpl(sideInputMap);
- final OutputCollectorImpl outputCollector = dataHandler.getOutputCollector();
- transform.prepare(transformContext, outputCollector);
- }
- });
- }
-
- /**
- * Collect all inter-stage incoming edges of this vertex.
- *
- * @param irVertex the IRVertex whose inter-stage incoming edges to be collected.
- * @return the collected incoming edges.
- */
- private Set<StageEdge> getInEdgesFromOtherStages(final IRVertex irVertex) {
- return stageIncomingEdges.stream().filter(
- stageInEdge -> stageInEdge.getDstVertex().getId().equals(irVertex.getId()))
- .collect(Collectors.toSet());
- }
-
- /**
- * Collect all inter-stage outgoing edges of this vertex.
- *
- * @param irVertex the IRVertex whose inter-stage outgoing edges to be collected.
- * @return the collected outgoing edges.
- */
- private Set<StageEdge> getOutEdgesToOtherStages(final IRVertex irVertex) {
- return stageOutgoingEdges.stream().filter(
- stageInEdge -> stageInEdge.getSrcVertex().getId().equals(irVertex.getId()))
- .collect(Collectors.toSet());
- }
-
- /**
- * Add input OutputCollectors to each {@link IRVertex}.
- * Input OutputCollector denotes all the OutputCollectors of intra-Stage dependencies.
- *
- * @param irVertex the IRVertex to add input OutputCollectors to.
- */
- private void addInputFromThisStage(final IRVertex irVertex, final IRVertexDataHandler dataHandler) {
- List<IRVertex> parentVertices = irVertexDag.getParents(irVertex.getId());
- if (parentVertices != null) {
- parentVertices.forEach(parent -> {
- final OutputCollectorImpl parentOutputCollector = getIRVertexDataHandler(parent).getOutputCollector();
- if (parentOutputCollector.hasSideInputFor(irVertex.getId())) {
- dataHandler.addSideInputFromThisStage(parentOutputCollector);
- } else {
- dataHandler.addInputFromThisStages(parentOutputCollector);
- }
- });
- }
- }
-
- /**
- * Add outputCollectors to each {@link IRVertex}.
- * @param irVertex the IRVertex to add output outputCollectors to.
- */
- private void setOutputCollector(final IRVertex irVertex, final IRVertexDataHandler dataHandler) {
- final OutputCollectorImpl outputCollector = new OutputCollectorImpl();
- irVertexDag.getOutgoingEdgesOf(irVertex).forEach(outEdge -> {
- if (outEdge.isSideInput()) {
- outputCollector.setSideInputRuntimeEdge(outEdge);
- outputCollector.setAsSideInputFor(irVertex.getId());
- }
- });
-
- dataHandler.setOutputCollector(outputCollector);
- }
-
- /**
- * Check that this irVertex has OutputWriter for inter-stage data.
- *
- * @param irVertex the irVertex to check whether it has OutputWriters.
- * @return true if the irVertex has OutputWriters.
- */
- private boolean hasOutputWriter(final IRVertex irVertex) {
- return !getIRVertexDataHandler(irVertex).getOutputWriters().isEmpty();
- }
-
- private void setIRVertexPutOnHold(final MetricCollectionBarrierVertex irVertex) {
- irVertexIdPutOnHold = irVertex.getId();
- }
-
- /**
- * Finalize the output write of this Task.
- * As element-wise output write is done and the block is in memory,
- * flush the block into the designated data store and commit it.
- *
- * @param irVertex the IRVertex with OutputWriter to flush and commit output block.
- */
- private void writeAndCloseOutputWriters(final IRVertex irVertex) {
- final List<Long> writtenBytesList = new ArrayList<>();
- final Map<String, Object> metric = new HashMap<>();
- metricCollector.beginMeasurement(irVertex.getId(), metric);
- final long writeStartTime = System.currentTimeMillis();
-
- getIRVertexDataHandler(irVertex).getOutputWriters().forEach(outputWriter -> {
- outputWriter.close();
- final Optional<Long> writtenBytes = outputWriter.getWrittenBytes();
- writtenBytes.ifPresent(writtenBytesList::add);
- });
-
- final long writeEndTime = System.currentTimeMillis();
- metric.put("OutputWriteTime(ms)", writeEndTime - writeStartTime);
- putWrittenBytesMetric(writtenBytesList, metric);
- metricCollector.endMeasurement(irVertex.getId(), metric);
- }
-
- /**
- * Get input iterator from BoundedSource and bind it with id.
- */
- private void prepareInputFromSource() {
- irVertexDag.topologicalDo(irVertex -> {
- if (irVertex instanceof SourceVertex) {
- try {
- final String iteratorId = generateIteratorId();
- final Readable readable = irVertexIdToReadable.get(irVertex.getId());
- if (readable == null) {
- throw new RuntimeException(irVertex.toString());
- }
- final Iterator iterator = readable.read().iterator();
- idToSrcIteratorMap.putIfAbsent(iteratorId, iterator);
- srcIteratorIdToDataHandlersMap.putIfAbsent(iteratorId, new ArrayList<>());
- srcIteratorIdToDataHandlersMap.get(iteratorId).add(getIRVertexDataHandler(irVertex));
- } catch (final BlockFetchException ex) {
- taskStateManager.onTaskStateChanged(TaskState.State.FAILED_RECOVERABLE,
- Optional.empty(), Optional.of(TaskState.RecoverableFailureCause.INPUT_READ_FAILURE));
- LOG.error("{} Execution Failed (Recoverable: input read failure)! Exception: {}",
- taskId, ex.toString());
- } catch (final Exception e) {
- taskStateManager.onTaskStateChanged(TaskState.State.FAILED_UNRECOVERABLE,
- Optional.empty(), Optional.empty());
- LOG.error("{} Execution Failed! Exception: {}", taskId, e.toString());
- throw new RuntimeException(e);
- }
- }
- });
- }
-
- /**
- * Get input iterator from other stages received in the form of CompletableFuture
- * and bind it with id.
- */
- private void prepareInputFromOtherStages() {
- inputReaderToDataHandlersMap.forEach((inputReader, dataHandlers) -> {
- final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> futures = inputReader.read();
- numPartitions += futures.size();
-
- // Add consumers which will push iterator when the futures are complete.
- futures.forEach(compFuture -> compFuture.whenComplete((iterator, exception) -> {
- if (exception != null) {
- throw new BlockFetchException(exception);
- }
-
- final String iteratorId = generateIteratorId();
- if (iteratorIdToDataHandlersMap.containsKey(iteratorId)) {
- throw new RuntimeException("iteratorIdToDataHandlersMap already contains " + iteratorId);
- } else {
- iteratorIdToDataHandlersMap.computeIfAbsent(iteratorId, absentIteratorId -> dataHandlers);
- try {
- partitionQueue.put(Pair.of(iteratorId, iterator));
- } catch (final InterruptedException e) {
- Thread.currentThread().interrupt();
- throw new BlockFetchException(e);
- }
- }
- }));
- });
- }
-
- /**
- * Check whether all vertices in this Task are finished.
- *
- * @return true if all vertices are finished.
- */
- private boolean finishedAllVertices() {
- // Total number of Tasks
- int vertexNum = irVertexDataHandlers.size();
- int finishedVertexNum = finishedVertexIds.size();
- return finishedVertexNum == vertexNum;
- }
-
- /**
- * Initialize the very first map of OutputCollector-children irVertex DAG.
- * In each map entry, the OutputCollector contains input data to be propagated through
- * the children irVertex DAG.
- */
- private void initializeOutputToChildrenDataHandlersMap() {
- srcIteratorIdToDataHandlersMap.values().forEach(dataHandlers ->
- dataHandlers.forEach(dataHandler -> {
- outputToChildrenDataHandlersMap.putIfAbsent(dataHandler.getOutputCollector(), dataHandler.getChildren());
- }));
- iteratorIdToDataHandlersMap.values().forEach(dataHandlers ->
- dataHandlers.forEach(dataHandler -> {
- outputToChildrenDataHandlersMap.putIfAbsent(dataHandler.getOutputCollector(), dataHandler.getChildren());
- }));
- }
-
- /**
- * Update the map of OutputCollector-children irVertex DAG.
- */
- private void updateOutputToChildrenDataHandlersMap() {
- Map<OutputCollectorImpl, List<IRVertexDataHandler>> currentMap = outputToChildrenDataHandlersMap;
- Map<OutputCollectorImpl, List<IRVertexDataHandler>> updatedMap = new HashMap<>();
-
- currentMap.values().forEach(dataHandlers ->
- dataHandlers.forEach(dataHandler -> {
- updatedMap.putIfAbsent(dataHandler.getOutputCollector(), dataHandler.getChildren());
- })
- );
-
- outputToChildrenDataHandlersMap = updatedMap;
- }
-
- /**
- * Update the map of OutputCollector-children irVertex DAG.
- *
- * @param irVertex the IRVertex with the transform to close.
- */
- private void closeTransform(final IRVertex irVertex) {
- if (irVertex instanceof OperatorVertex) {
- Transform transform = ((OperatorVertex) irVertex).getTransform();
- transform.close();
- }
- }
-
- /**
- * As a preprocessing of side input data, get inter stage side input
- * and form a map of source transform-side input.
- *
- * @param irVertex the IRVertex which receives side input from other stages.
- * @param sideInputMap the map of source transform-side input to build.
- */
- private void sideInputFromOtherStages(final IRVertex irVertex, final Map<Transform, Object> sideInputMap) {
- getIRVertexDataHandler(irVertex).getSideInputFromOtherStages().forEach(sideInputReader -> {
- try {
- final DataUtil.IteratorWithNumBytes sideInputIterator = sideInputReader.read().get(0).get();
- final Object sideInput = getSideInput(sideInputIterator);
- final RuntimeEdge inEdge = sideInputReader.getRuntimeEdge();
- final Transform srcTransform;
- if (inEdge instanceof StageEdge) {
- srcTransform = ((OperatorVertex) ((StageEdge) inEdge).getSrcVertex()).getTransform();
- } else {
- srcTransform = ((OperatorVertex) inEdge.getSrc()).getTransform();
- }
- sideInputMap.put(srcTransform, sideInput);
-
- // Collect metrics on block size if possible.
- try {
- serBlockSize += sideInputIterator.getNumSerializedBytes();
- } catch (final DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) {
- serBlockSize = -1;
- }
- try {
- encodedBlockSize += sideInputIterator.getNumEncodedBytes();
- } catch (final DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) {
- encodedBlockSize = -1;
- }
- } catch (final InterruptedException e) {
- Thread.currentThread().interrupt();
- throw new BlockFetchException(e);
- } catch (final ExecutionException e1) {
- throw new RuntimeException("Failed while reading side input from other stages " + e1);
- }
- });
- }
-
- /**
- * As a preprocessing of side input data, get intra stage side input
- * and form a map of source transform-side input.
- * Assumption: intra stage side input denotes a data element initially received
- * via side input reader from other stages.
- *
- * @param irVertex the IRVertex which receives the data element marked as side input.
- * @param sideInputMap the map of source transform-side input to build.
- */
- private void sideInputFromThisStage(final IRVertex irVertex, final Map<Transform, Object> sideInputMap) {
- getIRVertexDataHandler(irVertex).getSideInputFromThisStage().forEach(input -> {
- // because sideInput is only 1 element in the outputCollector
- Object sideInput = input.remove();
- final RuntimeEdge inEdge = input.getSideInputRuntimeEdge();
- final Transform srcTransform;
- if (inEdge instanceof StageEdge) {
- srcTransform = ((OperatorVertex) ((StageEdge) inEdge).getSrcVertex()).getTransform();
- } else {
- srcTransform = ((OperatorVertex) inEdge.getSrc()).getTransform();
- }
- sideInputMap.put(srcTransform, sideInput);
- });
- }
-
- /**
- * Executes the task.
- */
- public void execute() {
- final Map<String, Object> metric = new HashMap<>();
- metricCollector.beginMeasurement(taskId, metric);
- long boundedSrcReadStartTime = 0;
- long boundedSrcReadEndTime = 0;
- long inputReadStartTime = 0;
- long inputReadEndTime = 0;
- if (isExecuted) {
- throw new RuntimeException("Task {" + taskId + "} execution called again!");
- }
- isExecuted = true;
- taskStateManager.onTaskStateChanged(TaskState.State.EXECUTING, Optional.empty(), Optional.empty());
- LOG.info("{} Executing!", taskId);
-
- // Prepare input data from bounded source.
- boundedSrcReadStartTime = System.currentTimeMillis();
- prepareInputFromSource();
- boundedSrcReadEndTime = System.currentTimeMillis();
- metric.put("BoundedSourceReadTime(ms)", boundedSrcReadEndTime - boundedSrcReadStartTime);
-
- // Prepare input data from other stages.
- inputReadStartTime = System.currentTimeMillis();
- prepareInputFromOtherStages();
-
- // Execute the IRVertex DAG.
- try {
- srcIteratorIdToDataHandlersMap.forEach((srcIteratorId, dataHandlers) -> {
- Iterator iterator = idToSrcIteratorMap.get(srcIteratorId);
- iterator.forEachRemaining(element -> {
- for (final IRVertexDataHandler dataHandler : dataHandlers) {
- runTask(dataHandler, element);
- }
- });
- });
-
- // Process data from other stages.
- for (int currPartition = 0; currPartition < numPartitions; currPartition++) {
- Pair<String, DataUtil.IteratorWithNumBytes> idToIteratorPair = partitionQueue.take();
- final String iteratorId = idToIteratorPair.left();
- final DataUtil.IteratorWithNumBytes iterator = idToIteratorPair.right();
- List<IRVertexDataHandler> dataHandlers = iteratorIdToDataHandlersMap.get(iteratorId);
- iterator.forEachRemaining(element -> {
- for (final IRVertexDataHandler dataHandler : dataHandlers) {
- runTask(dataHandler, element);
- }
- });
-
- // Collect metrics on block size if possible.
- try {
- serBlockSize += iterator.getNumSerializedBytes();
- } catch (final DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) {
- serBlockSize = -1;
- } catch (final IllegalStateException e) {
- LOG.error("Failed to get the number of bytes of serialized data - the data is not ready yet ", e);
- }
- try {
- encodedBlockSize += iterator.getNumEncodedBytes();
- } catch (final DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) {
- encodedBlockSize = -1;
- } catch (final IllegalStateException e) {
- LOG.error("Failed to get the number of bytes of encoded data - the data is not ready yet ", e);
- }
- }
- inputReadEndTime = System.currentTimeMillis();
- metric.put("InputReadTime(ms)", inputReadEndTime - inputReadStartTime);
-
- // Process intra-Task data.
- // Intra-Task data comes from outputCollectors of this Task's vertices.
- initializeOutputToChildrenDataHandlersMap();
- while (!finishedAllVertices()) {
- outputToChildrenDataHandlersMap.forEach((outputCollector, childrenDataHandlers) -> {
- // Get the vertex that has this outputCollector as its output outputCollector
- final IRVertex outputProducer = irVertexDataHandlers.stream()
- .filter(dataHandler -> dataHandler.getOutputCollector() == outputCollector)
- .findFirst().get().getIRVertex();
-
- // Before consuming the output of outputProducer as input,
- // close transform if it is OperatorTransform.
- closeTransform(outputProducer);
-
- // Set outputProducer as finished.
- finishedVertexIds.add(outputProducer.getId());
-
- while (!outputCollector.isEmpty()) {
- final Object element = outputCollector.remove();
-
- // Pass outputProducer's output to its children tasks recursively.
- if (!childrenDataHandlers.isEmpty()) {
- for (final IRVertexDataHandler childDataHandler : childrenDataHandlers) {
- runTask(childDataHandler, element);
- }
- }
-
- // Write element-wise to OutputWriters if any and close the OutputWriters.
- if (hasOutputWriter(outputProducer)) {
- // If outputCollector isn't empty(if closeTransform produced some output),
- // write them element-wise to OutputWriters.
- List<OutputWriter> outputWritersOfTask =
- getIRVertexDataHandler(outputProducer).getOutputWriters();
- outputWritersOfTask.forEach(outputWriter -> outputWriter.write(element));
- }
- }
-
- if (hasOutputWriter(outputProducer)) {
- writeAndCloseOutputWriters(outputProducer);
- }
- });
- updateOutputToChildrenDataHandlersMap();
- }
- } catch (final BlockWriteException ex2) {
- taskStateManager.onTaskStateChanged(TaskState.State.FAILED_RECOVERABLE,
- Optional.empty(), Optional.of(TaskState.RecoverableFailureCause.OUTPUT_WRITE_FAILURE));
- LOG.error("{} Execution Failed (Recoverable: output write failure)! Exception: {}",
- taskId, ex2.toString());
- } catch (final Exception e) {
- taskStateManager.onTaskStateChanged(TaskState.State.FAILED_UNRECOVERABLE,
- Optional.empty(), Optional.empty());
- LOG.error("{} Execution Failed! Exception: {}",
- taskId, e.toString());
- throw new RuntimeException(e);
- }
-
- // Put Task-unit metrics.
- final boolean available = serBlockSize >= 0;
- putReadBytesMetric(available, serBlockSize, encodedBlockSize, metric);
- metricCollector.endMeasurement(taskId, metric);
- if (irVertexIdPutOnHold == null) {
- taskStateManager.onTaskStateChanged(TaskState.State.COMPLETE, Optional.empty(), Optional.empty());
- } else {
- taskStateManager.onTaskStateChanged(TaskState.State.ON_HOLD,
- Optional.of(irVertexIdPutOnHold),
- Optional.empty());
- }
- LOG.info("{} Complete!", taskId);
- }
-
- /**
- * Recursively executes a vertex with the input data element.
- *
- * @param dataHandler IRVertexDataHandler of a vertex to execute.
- * @param dataElement input data element to process.
- */
- private void runTask(final IRVertexDataHandler dataHandler, final Object dataElement) {
- final IRVertex irVertex = dataHandler.getIRVertex();
- final OutputCollectorImpl outputCollector = dataHandler.getOutputCollector();
-
- // Process element-wise depending on the vertex type
- if (irVertex instanceof SourceVertex) {
- if (dataElement == null) { // null used for Beam VoidCoders
- final List<Object> nullForVoidCoder = Collections.singletonList(dataElement);
- outputCollector.emit(nullForVoidCoder);
- } else {
- outputCollector.emit(dataElement);
- }
- } else if (irVertex instanceof OperatorVertex) {
- final Transform transform = ((OperatorVertex) irVertex).getTransform();
- transform.onData(dataElement);
- } else if (irVertex instanceof MetricCollectionBarrierVertex) {
- if (dataElement == null) { // null used for Beam VoidCoders
- final List<Object> nullForVoidCoder = Collections.singletonList(dataElement);
- outputCollector.emit(nullForVoidCoder);
- } else {
- outputCollector.emit(dataElement);
- }
- setIRVertexPutOnHold((MetricCollectionBarrierVertex) irVertex);
- } else {
- throw new UnsupportedOperationException("This type of IRVertex is not supported");
- }
-
- // For the produced output
- while (!outputCollector.isEmpty()) {
- final Object element = outputCollector.remove();
-
- // Pass output to its children recursively.
- List<IRVertexDataHandler> childrenDataHandlers = dataHandler.getChildren();
- if (!childrenDataHandlers.isEmpty()) {
- for (final IRVertexDataHandler childDataHandler : childrenDataHandlers) {
- runTask(childDataHandler, element);
- }
- }
-
- // Write element-wise to OutputWriters if any
- if (hasOutputWriter(irVertex)) {
- List<OutputWriter> outputWritersOfTask = dataHandler.getOutputWriters();
- outputWritersOfTask.forEach(outputWriter -> outputWriter.write(element));
- }
- }
- }
-
- /**
- * Generate a unique iterator id.
- *
- * @return the iterator id.
- */
- private String generateIteratorId() {
- return ITERATORID_PREFIX + ITERATORID_GENERATOR.getAndIncrement();
- }
-
- private IRVertexDataHandler getIRVertexDataHandler(final IRVertex irVertex) {
- return irVertexDataHandlers.stream()
- .filter(dataHandler -> dataHandler.getIRVertex() == irVertex)
- .findFirst().get();
- }
-
- /**
- * Puts read bytes metric if the input data size is known.
- *
- * @param serializedBytes size in serialized (encoded and optionally post-processed (e.g. compressed)) form
- * @param encodedBytes size in encoded form
- * @param metricMap the metric map to put written bytes metric.
- */
- private static void putReadBytesMetric(final boolean available,
- final long serializedBytes,
- final long encodedBytes,
- final Map<String, Object> metricMap) {
- if (available) {
- if (serializedBytes != encodedBytes) {
- metricMap.put("ReadBytes(raw)", serializedBytes);
- }
- metricMap.put("ReadBytes", encodedBytes);
- }
- }
-
- /**
- * Puts written bytes metric if the output data size is known.
- *
- * @param writtenBytesList the list of written bytes.
- * @param metricMap the metric map to put written bytes metric.
- */
- private static void putWrittenBytesMetric(final List<Long> writtenBytesList,
- final Map<String, Object> metricMap) {
- if (!writtenBytesList.isEmpty()) {
- long totalWrittenBytes = 0;
- for (final Long writtenBytes : writtenBytesList) {
- totalWrittenBytes += writtenBytes;
- }
- metricMap.put("WrittenBytes", totalWrittenBytes);
- }
- }
-
- /**
- * Get sideInput from data from {@link InputReader}.
- *
- * @param iterator data from {@link InputReader#read()}
- * @return The corresponding sideInput
- */
- private static Object getSideInput(final DataUtil.IteratorWithNumBytes iterator) {
- final List copy = new ArrayList();
- iterator.forEachRemaining(copy::add);
- if (copy.size() == 1) {
- return copy.get(0);
- } else {
- if (copy.get(0) instanceof Iterable) {
- final List collect = new ArrayList();
- copy.forEach(element -> ((Iterable) element).iterator().forEachRemaining(collect::add));
- return collect;
- } else if (copy.get(0) instanceof Map) {
- final Map collect = new HashMap();
- copy.forEach(element -> {
- final Set keySet = ((Map) element).keySet();
- keySet.forEach(key -> collect.put(key, ((Map) element).get(key)));
- });
- return collect;
- } else {
- return copy;
- }
- }
- }
-}
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/IRVertexDataHandler.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/IRVertexDataHandler.java
deleted file mode 100644
index 84b7a8e..0000000
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/IRVertexDataHandler.java
+++ /dev/null
@@ -1,162 +0,0 @@
-/*
- * Copyright (C) 2018 Seoul National University
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package edu.snu.nemo.runtime.executor.datatransfer;
-
-import edu.snu.nemo.common.ir.vertex.IRVertex;
-
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * Per-Task data handler.
- * This is a wrapper class that handles data transfer of a Task.
- * As Task input is processed element-wise, Task output element percolates down
- * through the DAG of children TaskDataHandlers.
- */
-public final class IRVertexDataHandler {
- private final IRVertex irVertex;
- private List<IRVertexDataHandler> children;
- private final List<OutputCollectorImpl> inputFromThisStage;
- private final List<InputReader> sideInputFromOtherStages;
- private final List<OutputCollectorImpl> sideInputFromThisStage;
- private OutputCollectorImpl outputCollector;
- private final List<OutputWriter> outputWriters;
-
- /**
- * IRVertexDataHandler Constructor.
- *
- * @param irVertex Task of this IRVertexDataHandler.
- */
- public IRVertexDataHandler(final IRVertex irVertex) {
- this.irVertex = irVertex;
- this.children = new ArrayList<>();
- this.inputFromThisStage = new ArrayList<>();
- this.sideInputFromOtherStages = new ArrayList<>();
- this.sideInputFromThisStage = new ArrayList<>();
- this.outputCollector = null;
- this.outputWriters = new ArrayList<>();
- }
-
- /**
- * Get the irVertex that owns this IRVertexDataHandler.
- *
- * @return irVertex of this IRVertexDataHandler.
- */
- public IRVertex getIRVertex() {
- return irVertex;
- }
-
- /**
- * Get a DAG of children tasks' TaskDataHandlers.
- *
- * @return DAG of children tasks' TaskDataHandlers.
- */
- public List<IRVertexDataHandler> getChildren() {
- return children;
- }
-
- /**
- * Get side input from other Task.
- *
- * @return InputReader that has side input.
- */
- public List<InputReader> getSideInputFromOtherStages() {
- return sideInputFromOtherStages;
- }
-
- /**
- * Get intra-Task side input from parent tasks.
- * Just like normal intra-Task inputs, intra-Task side inputs are
- * collected in parent tasks' OutputCollectors.
- *
- * @return OutputCollectors of all parent tasks which are marked as having side input.
- */
- public List<OutputCollectorImpl> getSideInputFromThisStage() {
- return sideInputFromThisStage;
- }
-
- /**
- * Get OutputCollector of this irVertex.
- *
- * @return OutputCollector of this irVertex.
- */
- public OutputCollectorImpl getOutputCollector() {
- return outputCollector;
- }
-
- /**
- * Get OutputWriters of this irVertex.
- *
- * @return OutputWriters of this irVertex.
- */
- public List<OutputWriter> getOutputWriters() {
- return outputWriters;
- }
-
- /**
- * Set a DAG of children tasks' DataHandlers.
- *
- * @param childrenDataHandler list of children TaskDataHandlers.
- */
- public void setChildrenDataHandler(final List<IRVertexDataHandler> childrenDataHandler) {
- children = childrenDataHandler;
- }
-
- /**
- * Add OutputCollector of a parent irVertex that will provide intra-stage input.
- *
- * @param input OutputCollector of a parent irVertex.
- */
- public void addInputFromThisStages(final OutputCollectorImpl input) {
- inputFromThisStage.add(input);
- }
-
- /**
- * Add InputReader that will provide inter-stage side input.
- *
- * @param sideInputReader InputReader that will provide inter-stage side input.
- */
- public void addSideInputFromOtherStages(final InputReader sideInputReader) {
- sideInputFromOtherStages.add(sideInputReader);
- }
-
- /**
- * Add OutputCollector of a parent irVertex that will provide intra-stage side input.
- *
- * @param ocAsSideInput OutputCollector of a parent irVertex with side input.
- */
- public void addSideInputFromThisStage(final OutputCollectorImpl ocAsSideInput) {
- sideInputFromThisStage.add(ocAsSideInput);
- }
-
- /**
- * Set OutputCollector of this irVertex.
- *
- * @param oc OutputCollector of this irVertex.
- */
- public void setOutputCollector(final OutputCollectorImpl oc) {
- outputCollector = oc;
- }
-
- /**
- * Add OutputWriter of this irVertex.
- *
- * @param outputWriter OutputWriter of this irVertex.
- */
- public void addOutputWriter(final OutputWriter outputWriter) {
- outputWriters.add(outputWriter);
- }
-}
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/InputReader.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/InputReader.java
index a01e58c..176acdd 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/InputReader.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/InputReader.java
@@ -154,8 +154,8 @@ public final class InputReader extends DataTransfer {
return RuntimeIdGenerator.generateBlockId(duplicateEdgeId, taskIdx);
}
- public String getSrcIrVertexId() {
- return srcVertex.getId();
+ public IRVertex getSrcIrVertex() {
+ return srcVertex;
}
public boolean isSideInputReader() {
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputCollectorImpl.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputCollectorImpl.java
index 0588769..16697d6 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputCollectorImpl.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputCollectorImpl.java
@@ -16,11 +16,9 @@
package edu.snu.nemo.runtime.executor.datatransfer;
import edu.snu.nemo.common.ir.OutputCollector;
-import edu.snu.nemo.runtime.common.plan.RuntimeEdge;
import java.util.ArrayDeque;
-import java.util.ArrayList;
-import java.util.List;
+import java.util.Queue;
/**
* OutputCollector implementation.
@@ -28,17 +26,13 @@ import java.util.List;
* @param <O> output type.
*/
public final class OutputCollectorImpl<O> implements OutputCollector<O> {
- private final ArrayDeque<O> outputQueue;
- private RuntimeEdge sideInputRuntimeEdge;
- private List<String> sideInputReceivers;
+ private final Queue<O> outputQueue;
/**
* Constructor of a new OutputCollectorImpl.
*/
public OutputCollectorImpl() {
- this.outputQueue = new ArrayDeque<>();
- this.sideInputRuntimeEdge = null;
- this.sideInputReceivers = new ArrayList<>();
+ this.outputQueue = new ArrayDeque<>(1);
}
@Override
@@ -69,50 +63,4 @@ public final class OutputCollectorImpl<O> implements OutputCollector<O> {
public boolean isEmpty() {
return outputQueue.isEmpty();
}
-
- /**
- * Return the size of this OutputCollector.
- *
- * @return the total number of elements in this OutputCollector.
- */
- public int size() {
- return outputQueue.size();
- }
-
- /**
- * Mark this edge as side input so that TaskExecutor can retrieve
- * source transform using it.
- *
- * @param edge the RuntimeEdge to mark as side input.
- */
- public void setSideInputRuntimeEdge(final RuntimeEdge edge) {
- sideInputRuntimeEdge = edge;
- }
-
- /**
- * Get the RuntimeEdge marked as side input.
- *
- * @return the RuntimeEdge marked as side input.
- */
- public RuntimeEdge getSideInputRuntimeEdge() {
- return sideInputRuntimeEdge;
- }
-
- /**
- * Set this OutputCollector as having side input for the given child task.
- *
- * @param physicalTaskId the id of child task whose side input will be put into this OutputCollector.
- */
- public void setAsSideInputFor(final String physicalTaskId) {
- sideInputReceivers.add(physicalTaskId);
- }
-
- /**
- * Check if this OutputCollector has side input for the given child task.
- *
- * @return true if it contains side input for child task of the given id.
- */
- public boolean hasSideInputFor(final String physicalTaskId) {
- return sideInputReceivers.contains(physicalTaskId);
- }
}
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/DataFetcher.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/DataFetcher.java
new file mode 100644
index 0000000..3dbc689
--- /dev/null
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/DataFetcher.java
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.runtime.executor.task;
+
+import edu.snu.nemo.common.ir.vertex.IRVertex;
+
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * An abstraction for fetching data from task-external sources.
+ */
+abstract class DataFetcher {
+ private final IRVertex dataSource;
+ private final VertexHarness child;
+ private final Map<String, Object> metricMap;
+ private final boolean isToSideInput;
+ private final boolean isFromSideInput;
+
+ DataFetcher(final IRVertex dataSource,
+ final VertexHarness child,
+ final Map<String, Object> metricMap,
+ final boolean isFromSideInput,
+ final boolean isToSideInput) {
+ this.dataSource = dataSource;
+ this.child = child;
+ this.metricMap = metricMap;
+ this.isToSideInput = isToSideInput;
+ this.isFromSideInput = isFromSideInput;
+ }
+
+ /**
+ * Can block until the next data element becomes available.
+ *
+ * @return null if there's no more data element.
+ * @throws IOException while fetching data
+ */
+ abstract Object fetchDataElement() throws IOException;
+
+ protected Map<String, Object> getMetricMap() {
+ return metricMap;
+ }
+
+ VertexHarness getChild() {
+ return child;
+ }
+
+ public IRVertex getDataSource() {
+ return dataSource;
+ }
+
+ boolean isFromSideInput() {
+ return isFromSideInput;
+ }
+
+ boolean isToSideInput() {
+ return isToSideInput;
+ }
+}
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/ParentTaskDataFetcher.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/ParentTaskDataFetcher.java
new file mode 100644
index 0000000..2abb3b7
--- /dev/null
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/ParentTaskDataFetcher.java
@@ -0,0 +1,142 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.runtime.executor.task;
+
+import edu.snu.nemo.common.exception.BlockFetchException;
+import edu.snu.nemo.common.ir.vertex.IRVertex;
+import edu.snu.nemo.runtime.executor.data.DataUtil;
+import edu.snu.nemo.runtime.executor.datatransfer.InputReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.concurrent.NotThreadSafe;
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.LinkedBlockingQueue;
+
+/**
+ * Fetches data from parent tasks.
+ */
+@NotThreadSafe
+class ParentTaskDataFetcher extends DataFetcher {
+ private static final Logger LOG = LoggerFactory.getLogger(ParentTaskDataFetcher.class);
+
+ private final InputReader readersForParentTask;
+ private final LinkedBlockingQueue<DataUtil.IteratorWithNumBytes> dataQueue;
+
+ // Non-finals (lazy fetching)
+ private boolean hasFetchStarted;
+ private int expectedNumOfIterators;
+ private DataUtil.IteratorWithNumBytes currentIterator;
+ private int currentIteratorIndex;
+ private boolean noElementAtAll = true;
+
+ ParentTaskDataFetcher(final IRVertex dataSource,
+ final InputReader readerForParentTask,
+ final VertexHarness child,
+ final Map<String, Object> metricMap,
+ final boolean isFromSideInput,
+ final boolean isToSideInput) {
+ super(dataSource, child, metricMap, isFromSideInput, isToSideInput);
+ this.readersForParentTask = readerForParentTask;
+ this.hasFetchStarted = false;
+ this.dataQueue = new LinkedBlockingQueue<>();
+ }
+
+ private void handleMetric(final DataUtil.IteratorWithNumBytes iterator) {
+ long serBytes = 0;
+ long encodedBytes = 0;
+ try {
+ serBytes += iterator.getNumSerializedBytes();
+ } catch (final DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) {
+ serBytes = -1;
+ } catch (final IllegalStateException e) {
+ LOG.error("Failed to get the number of bytes of serialized data - the data is not ready yet ", e);
+ }
+ try {
+ encodedBytes += iterator.getNumEncodedBytes();
+ } catch (final DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) {
+ encodedBytes = -1;
+ } catch (final IllegalStateException e) {
+ LOG.error("Failed to get the number of bytes of encoded data - the data is not ready yet ", e);
+ }
+ if (serBytes != encodedBytes) {
+ getMetricMap().put("ReadBytes(raw)", serBytes);
+ }
+ getMetricMap().put("ReadBytes", encodedBytes);
+ }
+
+ /**
+ * Blocking call.
+ */
+ private void fetchInBackground() {
+ final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> futures = readersForParentTask.read();
+ this.expectedNumOfIterators = futures.size();
+
+ futures.forEach(compFuture -> compFuture.whenComplete((iterator, exception) -> {
+ if (exception != null) {
+ throw new BlockFetchException(exception);
+ }
+
+ try {
+ dataQueue.put(iterator); // can block here
+ } catch (final InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new BlockFetchException(e);
+ }
+ }));
+ }
+
+ @Override
+ Object fetchDataElement() throws IOException {
+ try {
+ if (!hasFetchStarted) {
+ fetchInBackground();
+ hasFetchStarted = true;
+ this.currentIterator = dataQueue.take();
+ this.currentIteratorIndex = 1;
+ }
+
+ if (this.currentIterator.hasNext()) {
+ noElementAtAll = false;
+ return this.currentIterator.next();
+ } else {
+ // This iterator is done, proceed to the next iterator
+ if (currentIteratorIndex == expectedNumOfIterators) {
+ // No more iterator left
+ if (noElementAtAll) {
+ // This shouldn't normally happen, except for cases such as when Beam's VoidCoder is used.
+ noElementAtAll = false;
+ return Void.TYPE;
+ } else {
+ // This whole fetcher's done
+ return null;
+ }
+ } else {
+ handleMetric(currentIterator);
+ // Try the next iterator
+ this.currentIteratorIndex += 1;
+ this.currentIterator = dataQueue.take();
+ return fetchDataElement();
+ }
+ }
+ } catch (InterruptedException exception) {
+ throw new IOException(exception);
+ }
+ }
+}
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/SourceVertexDataFetcher.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/SourceVertexDataFetcher.java
new file mode 100644
index 0000000..998df63
--- /dev/null
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/SourceVertexDataFetcher.java
@@ -0,0 +1,58 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.runtime.executor.task;
+
+import edu.snu.nemo.common.ir.Readable;
+import edu.snu.nemo.common.ir.vertex.IRVertex;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.Map;
+
+/**
+ * Fetches data from a data source.
+ */
+class SourceVertexDataFetcher extends DataFetcher {
+ private final Readable readable;
+
+ // Non-finals (lazy fetching)
+ private Iterator iterator;
+
+ SourceVertexDataFetcher(final IRVertex dataSource,
+ final Readable readable,
+ final VertexHarness child,
+ final Map<String, Object> metricMap,
+ final boolean isFromSideInput,
+ final boolean isToSideInput) {
+ super(dataSource, child, metricMap, isFromSideInput, isToSideInput);
+ this.readable = readable;
+ }
+
+ @Override
+ Object fetchDataElement() throws IOException {
+ if (iterator == null) {
+ final long start = System.currentTimeMillis();
+ iterator = this.readable.read().iterator();
+ getMetricMap().put("BoundedSourceReadTime(ms)", System.currentTimeMillis() - start);
+ }
+
+ if (iterator.hasNext()) {
+ return iterator.next();
+ } else {
+ return null;
+ }
+ }
+}
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
new file mode 100644
index 0000000..27931a0
--- /dev/null
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
@@ -0,0 +1,450 @@
+/*
+ * Copyright (C) 2017 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.runtime.executor.task;
+
+import com.google.common.collect.Lists;
+import edu.snu.nemo.common.ContextImpl;
+import edu.snu.nemo.common.Pair;
+import edu.snu.nemo.common.dag.DAG;
+import edu.snu.nemo.common.ir.Readable;
+import edu.snu.nemo.common.ir.vertex.*;
+import edu.snu.nemo.common.ir.vertex.transform.Transform;
+import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.plan.Task;
+import edu.snu.nemo.runtime.common.plan.StageEdge;
+import edu.snu.nemo.runtime.common.plan.RuntimeEdge;
+import edu.snu.nemo.runtime.common.state.TaskState;
+import edu.snu.nemo.runtime.executor.MetricCollector;
+import edu.snu.nemo.runtime.executor.MetricMessageSender;
+import edu.snu.nemo.runtime.executor.TaskStateManager;
+import edu.snu.nemo.runtime.executor.datatransfer.*;
+
+import java.io.IOException;
+import java.util.*;
+import java.util.stream.Collectors;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.concurrent.NotThreadSafe;
+
+/**
+ * Executes a task.
+ * Should be accessed by a single thread.
+ */
+@NotThreadSafe
+public final class TaskExecutor {
+ private static final Logger LOG = LoggerFactory.getLogger(TaskExecutor.class.getName());
+ private static final int NONE_FINISHED = -1;
+
+ // Essential information
+ private boolean isExecuted;
+ private final String taskId;
+ private final TaskStateManager taskStateManager;
+ private final List<DataFetcher> dataFetchers;
+ private final List<VertexHarness> sortedHarnesses;
+ private final Map sideInputMap;
+
+ // Metrics information
+ private final Map<String, Object> metricMap;
+ private final MetricCollector metricCollector;
+
+ // Dynamic optimization
+ private String idOfVertexPutOnHold;
+
  /**
   * Constructor.
   * @param task Task with information needed during execution.
   * @param irVertexDag A DAG of vertices.
   * @param taskStateManager State manager for this Task.
   * @param dataTransferFactory For reading from/writing to data to other tasks.
   * @param metricMessageSender For sending metric with execution stats to Master.
   */
  public TaskExecutor(final Task task,
                      final DAG<IRVertex, RuntimeEdge<IRVertex>> irVertexDag,
                      final TaskStateManager taskStateManager,
                      final DataTransferFactory dataTransferFactory,
                      final MetricMessageSender metricMessageSender) {
    // Essential information
    this.isExecuted = false;
    this.taskId = task.getTaskId();
    this.taskStateManager = taskStateManager;

    // Metrics information
    this.metricMap = new HashMap<>();
    this.metricCollector = new MetricCollector(metricMessageSender);

    // Dynamic optimization
    // Assigning null is very bad, but we are keeping this for now
    this.idOfVertexPutOnHold = null;

    // Prepare data structures
    // sideInputMap is shared (via ContextImpl) with every transform of this task.
    this.sideInputMap = new HashMap();
    final Pair<List<DataFetcher>, List<VertexHarness>> pair = prepare(task, irVertexDag, dataTransferFactory);
    this.dataFetchers = pair.left();
    this.sortedHarnesses = pair.right();
  }
+
  /**
   * Converts the DAG of vertices into pointer-based DAG of vertex harnesses.
   * This conversion is necessary for constructing concrete data channels for each vertex's inputs and outputs.
   *
   * - Source vertex read: Explicitly handled (SourceVertexDataFetcher)
   * - Sink vertex write: Implicitly handled within the vertex
   *
   * - Parent-task read: Explicitly handled (ParentTaskDataFetcher)
   * - Children-task write: Explicitly handled (VertexHarness)
   *
   * - Intra-task read: Implicitly handled when performing Intra-task writes
   * - Intra-task write: Explicitly handled (VertexHarness)
   *
   * For element-wise data processing, we traverse vertex harnesses from the roots to the leaves for each element.
   * This means that overheads associated with jumping from one harness to the other should be minimal.
   * For example, we should never perform an expensive hash operation to traverse the harnesses.
   *
   * @param task task.
   * @param irVertexDag dag.
   * @param dataTransferFactory for creating readers/writers to other tasks.
   * @return fetchers and harnesses.
   */
  private Pair<List<DataFetcher>, List<VertexHarness>> prepare(final Task task,
                                                               final DAG<IRVertex, RuntimeEdge<IRVertex>> irVertexDag,
                                                               final DataTransferFactory dataTransferFactory) {
    final int taskIndex = RuntimeIdGenerator.getIndexFromTaskId(task.getTaskId());

    // Traverse in a reverse-topological order to ensure that each visited vertex's children vertices exist.
    final List<IRVertex> reverseTopologicallySorted = Lists.reverse(irVertexDag.getTopologicalSort());

    // Create a harness for each vertex
    final List<DataFetcher> dataFetcherList = new ArrayList<>();
    final Map<String, VertexHarness> vertexIdToHarness = new HashMap<>();
    reverseTopologicallySorted.forEach(irVertex -> {
      final List<VertexHarness> children = getChildrenHarnesses(irVertex, irVertexDag, vertexIdToHarness);
      final Optional<Readable> sourceReader = getSourceVertexReader(irVertex, task.getIrVertexIdToReadable());
      // Sanity check: a Readable exists if and only if the vertex is a SourceVertex.
      if (sourceReader.isPresent() != irVertex instanceof SourceVertex) {
        throw new IllegalStateException(irVertex.toString());
      }

      // For each child (in the same order), whether the edge to it is a side input edge.
      final List<Boolean> isToSideInputs = children.stream()
          .map(VertexHarness::getIRVertex)
          .map(childVertex -> irVertexDag.getEdgeBetween(irVertex.getId(), childVertex.getId()))
          .map(RuntimeEdge::isSideInput)
          .collect(Collectors.toList());

      // Handle writes
      final List<OutputWriter> childrenTaskWriters = getChildrenTaskWriters(
          taskIndex, irVertex, task.getTaskOutgoingEdges(), dataTransferFactory); // Children-task write
      final VertexHarness vertexHarness = new VertexHarness(irVertex, new OutputCollectorImpl(), children,
          isToSideInputs, childrenTaskWriters, new ContextImpl(sideInputMap)); // Intra-vertex write
      prepareTransform(vertexHarness);
      vertexIdToHarness.put(irVertex.getId(), vertexHarness);

      // Handle reads
      final boolean isToSideInput = isToSideInputs.stream().anyMatch(bool -> bool);
      if (irVertex instanceof SourceVertex) {
        dataFetcherList.add(new SourceVertexDataFetcher(irVertex, sourceReader.get(), vertexHarness, metricMap,
            false, isToSideInput)); // Source vertex read
      }
      final List<InputReader> parentTaskReaders =
          getParentTaskReaders(taskIndex, irVertex, task.getTaskIncomingEdges(), dataTransferFactory);
      parentTaskReaders.forEach(parentTaskReader -> {
        final boolean isFromSideInput = parentTaskReader.isSideInputReader();
        dataFetcherList.add(new ParentTaskDataFetcher(parentTaskReader.getSrcIrVertex(), parentTaskReader,
            vertexHarness, metricMap, isFromSideInput, isToSideInput)); // Parent-task read
      });
    });

    // Harnesses in topological order, for ordered finalization later.
    final List<VertexHarness> sortedHarnessList = irVertexDag.getTopologicalSort()
        .stream()
        .map(vertex -> vertexIdToHarness.get(vertex.getId()))
        .collect(Collectors.toList());

    return Pair.of(dataFetcherList, sortedHarnessList);
  }
+
  /**
   * Recursively process a data element down the DAG dependency.
   *
   * For an OperatorVertex, the transform emits its outputs into the harness's output collector
   * (wired up in prepareTransform); those outputs are then drained and routed to children below.
   *
   * @param vertexHarness VertexHarness of a vertex to execute.
   * @param dataElement input data element to process.
   */
  private void processElementRecursively(final VertexHarness vertexHarness, final Object dataElement) {
    final IRVertex irVertex = vertexHarness.getIRVertex();
    final OutputCollectorImpl outputCollector = vertexHarness.getOutputCollector();
    if (irVertex instanceof SourceVertex) {
      // Source elements pass through unchanged.
      outputCollector.emit(dataElement);
    } else if (irVertex instanceof OperatorVertex) {
      final Transform transform = ((OperatorVertex) irVertex).getTransform();
      transform.onData(dataElement);
    } else if (irVertex instanceof MetricCollectionBarrierVertex) {
      // Pass the element through, and mark this task to go ON_HOLD for dynamic optimization.
      outputCollector.emit(dataElement);
      setIRVertexPutOnHold((MetricCollectionBarrierVertex) irVertex);
    } else {
      throw new UnsupportedOperationException("This type of IRVertex is not supported");
    }

    // Given a single input element, a vertex can produce many output elements.
    // Here, we recursively process all of the output elements.
    while (!outputCollector.isEmpty()) {
      final Object element = outputCollector.remove();
      handleOutputElement(vertexHarness, element); // Recursion
    }
  }
+
+ /**
+ * Execute a task, while handling unrecoverable errors and exceptions.
+ */
+ public void execute() {
+ try {
+ doExecute();
+ } catch (Throwable throwable) {
+ // ANY uncaught throwable is reported to the master
+ taskStateManager.onTaskStateChanged(TaskState.State.FAILED_UNRECOVERABLE, Optional.empty(), Optional.empty());
+ throwable.printStackTrace();
+ }
+ }
+
  /**
   * The task is executed in the following three phases.
   * - Phase 1: Consume task-external side-input data
   * - Phase 2: Consume task-external input data
   * - Phase 3: Finalize task-internal states and data elements
   */
  private void doExecute() {
    // Housekeeping stuff
    if (isExecuted) {
      throw new RuntimeException("Task {" + taskId + "} execution called again");
    }
    LOG.info("{} started", taskId);
    taskStateManager.onTaskStateChanged(TaskState.State.EXECUTING, Optional.empty(), Optional.empty());
    metricCollector.beginMeasurement(taskId, metricMap);

    // Phase 1: Consume task-external side-input related data.
    final Map<Boolean, List<DataFetcher>> sideInputRelated = dataFetchers.stream()
        .collect(Collectors.partitioningBy(fetcher -> fetcher.isFromSideInput() || fetcher.isToSideInput()));
    if (!handleDataFetchers(sideInputRelated.get(true))) {
      // Recoverable failure already reported; stop here.
      return;
    }
    // Harnesses reachable from non-side-input fetchers must not be finalized until Phase 3.
    final Set<VertexHarness> finalizeLater = sideInputRelated.get(false).stream()
        .map(DataFetcher::getChild)
        .flatMap(vertex -> getAllReachables(vertex).stream())
        .collect(Collectors.toSet());
    for (final VertexHarness vertexHarness : sortedHarnesses) {
      if (!finalizeLater.contains(vertexHarness)) {
        finalizeVertex(vertexHarness); // finalize early to materialize intra-task side inputs.
      }
    }

    // Phase 2: Consume task-external input data.
    if (!handleDataFetchers(sideInputRelated.get(false))) {
      return;
    }

    // Phase 3: Finalize task-internal states and elements
    for (final VertexHarness vertexHarness : sortedHarnesses) {
      if (finalizeLater.contains(vertexHarness)) {
        finalizeVertex(vertexHarness);
      }
    }

    // Miscellaneous: Metrics, DynOpt, etc
    metricCollector.endMeasurement(taskId, metricMap);
    if (idOfVertexPutOnHold == null) {
      taskStateManager.onTaskStateChanged(TaskState.State.COMPLETE, Optional.empty(), Optional.empty());
      LOG.info("{} completed", taskId);
    } else {
      // A MetricCollectionBarrierVertex was encountered: hand control to dynamic optimization.
      taskStateManager.onTaskStateChanged(TaskState.State.ON_HOLD,
          Optional.of(idOfVertexPutOnHold),
          Optional.empty());
      LOG.info("{} on hold", taskId);
    }
  }
+
+ private List<VertexHarness> getAllReachables(final VertexHarness src) {
+ final List<VertexHarness> result = new ArrayList<>();
+ result.add(src);
+ result.addAll(src.getNonSideInputChildren().stream()
+ .flatMap(child -> getAllReachables(child).stream()).collect(Collectors.toList()));
+ result.addAll(src.getSideInputChildren().stream()
+ .flatMap(child -> getAllReachables(child).stream()).collect(Collectors.toList()));
+ return result;
+ }
+
+ private void finalizeVertex(final VertexHarness vertexHarness) {
+ closeTransform(vertexHarness);
+ while (!vertexHarness.getOutputCollector().isEmpty()) {
+ final Object element = vertexHarness.getOutputCollector().remove();
+ handleOutputElement(vertexHarness, element);
+ }
+ finalizeOutputWriters(vertexHarness);
+ }
+
+ private void handleOutputElement(final VertexHarness vertexHarness, final Object element) {
+ vertexHarness.getWritersToChildrenTasks().forEach(outputWriter -> outputWriter.write(element));
+ if (vertexHarness.getSideInputChildren().size() > 0) {
+ sideInputMap.put(((OperatorVertex) vertexHarness.getIRVertex()).getTransform().getTag(), element);
+ }
+ vertexHarness.getNonSideInputChildren().forEach(child -> processElementRecursively(child, element));
+ }
+
+ /**
+ * @param fetchers to handle.
+ * @return false if IOException.
+ */
+ private boolean handleDataFetchers(final List<DataFetcher> fetchers) {
+ final List<DataFetcher> availableFetchers = new ArrayList<>(fetchers);
+ int finishedFetcherIndex = NONE_FINISHED;
+ while (!availableFetchers.isEmpty()) { // empty means we've consumed all task-external input data
+ for (int i = 0; i < availableFetchers.size(); i++) {
+ final DataFetcher dataFetcher = fetchers.get(i);
+ final Object element;
+ try {
+ element = dataFetcher.fetchDataElement();
+ } catch (IOException e) {
+ taskStateManager.onTaskStateChanged(TaskState.State.FAILED_RECOVERABLE,
+ Optional.empty(), Optional.of(TaskState.RecoverableFailureCause.INPUT_READ_FAILURE));
+ LOG.error("{} Execution Failed (Recoverable: input read failure)! Exception: {}", taskId, e.toString());
+ return false;
+ }
+
+ if (element == null) {
+ finishedFetcherIndex = i;
+ break;
+ } else {
+ if (dataFetcher.isFromSideInput()) {
+ sideInputMap.put(((OperatorVertex) dataFetcher.getDataSource()).getTransform().getTag(), element);
+ } else {
+ processElementRecursively(dataFetcher.getChild(), element);
+ }
+ }
+ }
+
+ // Remove the finished fetcher from the list
+ if (finishedFetcherIndex != NONE_FINISHED) {
+ availableFetchers.remove(finishedFetcherIndex);
+ }
+ }
+ return true;
+ }
+
+ ////////////////////////////////////////////// Helper methods for setting up initial data structures
+
+ private Optional<Readable> getSourceVertexReader(final IRVertex irVertex,
+ final Map<String, Readable> irVertexIdToReadable) {
+ if (irVertex instanceof SourceVertex) {
+ final Readable readable = irVertexIdToReadable.get(irVertex.getId());
+ if (readable == null) {
+ throw new IllegalStateException(irVertex.toString());
+ }
+ return Optional.of(readable);
+ } else {
+ return Optional.empty();
+ }
+ }
+
+ private List<InputReader> getParentTaskReaders(final int taskIndex,
+ final IRVertex irVertex,
+ final List<StageEdge> inEdgesFromParentTasks,
+ final DataTransferFactory dataTransferFactory) {
+ return inEdgesFromParentTasks
+ .stream()
+ .filter(inEdge -> inEdge.getDstVertex().getId().equals(irVertex.getId()))
+ .map(inEdgeForThisVertex -> dataTransferFactory
+ .createReader(taskIndex, inEdgeForThisVertex.getSrcVertex(), inEdgeForThisVertex))
+ .collect(Collectors.toList());
+ }
+
+ private List<OutputWriter> getChildrenTaskWriters(final int taskIndex,
+ final IRVertex irVertex,
+ final List<StageEdge> outEdgesToChildrenTasks,
+ final DataTransferFactory dataTransferFactory) {
+ return outEdgesToChildrenTasks
+ .stream()
+ .filter(outEdge -> outEdge.getSrcVertex().getId().equals(irVertex.getId()))
+ .map(outEdgeForThisVertex -> dataTransferFactory
+ .createWriter(irVertex, taskIndex, outEdgeForThisVertex.getDstVertex(), outEdgeForThisVertex))
+ .collect(Collectors.toList());
+ }
+
+ private List<VertexHarness> getChildrenHarnesses(final IRVertex irVertex,
+ final DAG<IRVertex, RuntimeEdge<IRVertex>> irVertexDag,
+ final Map<String, VertexHarness> vertexIdToHarness) {
+ final List<VertexHarness> childrenHandlers = irVertexDag.getChildren(irVertex.getId())
+ .stream()
+ .map(IRVertex::getId)
+ .map(vertexIdToHarness::get)
+ .collect(Collectors.toList());
+ if (childrenHandlers.stream().anyMatch(harness -> harness == null)) {
+ // Sanity check: there shouldn't be a null harness.
+ throw new IllegalStateException(childrenHandlers.toString());
+ }
+ return childrenHandlers;
+ }
+
+ ////////////////////////////////////////////// Transform-specific helper methods
+
+ private void prepareTransform(final VertexHarness vertexHarness) {
+ final IRVertex irVertex = vertexHarness.getIRVertex();
+ if (irVertex instanceof OperatorVertex) {
+ final Transform transform = ((OperatorVertex) irVertex).getTransform();
+ transform.prepare(vertexHarness.getContext(), vertexHarness.getOutputCollector());
+ }
+ }
+
+ private void closeTransform(final VertexHarness vertexHarness) {
+ final IRVertex irVertex = vertexHarness.getIRVertex();
+ if (irVertex instanceof OperatorVertex) {
+ Transform transform = ((OperatorVertex) irVertex).getTransform();
+ transform.close();
+ }
+ }
+
+ ////////////////////////////////////////////// Misc
+
  /**
   * Records the id of a MetricCollectionBarrierVertex; a non-null id makes doExecute
   * end the task in the ON_HOLD state (for dynamic optimization) instead of COMPLETE.
   * @param irVertex the barrier vertex encountered during processing.
   */
  private void setIRVertexPutOnHold(final MetricCollectionBarrierVertex irVertex) {
    idOfVertexPutOnHold = irVertex.getId();
  }
+
  /**
   * Finalize the output write of this vertex.
   * As element-wise output write is done and the block is in memory,
   * flush the block into the designated data store and commit it.
   * @param vertexHarness harness.
   */
  private void finalizeOutputWriters(final VertexHarness vertexHarness) {
    final List<Long> writtenBytesList = new ArrayList<>();
    // Per-vertex metric map, reported separately from the task-level metricMap.
    final Map<String, Object> metric = new HashMap<>();
    final IRVertex irVertex = vertexHarness.getIRVertex();

    metricCollector.beginMeasurement(irVertex.getId(), metric);
    final long writeStartTime = System.currentTimeMillis();

    vertexHarness.getWritersToChildrenTasks().forEach(outputWriter -> {
      outputWriter.close();
      final Optional<Long> writtenBytes = outputWriter.getWrittenBytes();
      writtenBytes.ifPresent(writtenBytesList::add);
    });

    final long writeEndTime = System.currentTimeMillis();
    metric.put("OutputWriteTime(ms)", writeEndTime - writeStartTime);
    if (!writtenBytesList.isEmpty()) {
      long totalWrittenBytes = 0;
      for (final Long writtenBytes : writtenBytesList) {
        totalWrittenBytes += writtenBytes;
      }
      // NOTE(review): "WrittenBytes" is put into the task-level metricMap while the other
      // metrics above go into the per-vertex 'metric' map -- confirm this asymmetry is intended.
      metricMap.put("WrittenBytes", totalWrittenBytes);
    }
    metricCollector.endMeasurement(irVertex.getId(), metric);
  }
+
+}
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java
new file mode 100644
index 0000000..2d915c4
--- /dev/null
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java
@@ -0,0 +1,108 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.runtime.executor.task;
+
+import edu.snu.nemo.common.ir.vertex.IRVertex;
+import edu.snu.nemo.common.ir.vertex.transform.Transform;
+import edu.snu.nemo.runtime.executor.datatransfer.OutputCollectorImpl;
+import edu.snu.nemo.runtime.executor.datatransfer.OutputWriter;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Captures the relationship between a non-source IRVertex's outputCollector, and children vertices.
+ */
+final class VertexHarness {
+ // IRVertex and transform-specific information
+ private final IRVertex irVertex;
+ private final OutputCollectorImpl outputCollector;
+ private final Transform.Context context;
+
+ // These lists can be empty
+ private final List<VertexHarness> sideInputChildren;
+ private final List<VertexHarness> nonSideInputChildren;
+ private final List<OutputWriter> writersToChildrenTasks;
+
+ VertexHarness(final IRVertex irVertex,
+ final OutputCollectorImpl outputCollector,
+ final List<VertexHarness> children,
+ final List<Boolean> isSideInputs,
+ final List<OutputWriter> writersToChildrenTasks,
+ final Transform.Context context) {
+ this.irVertex = irVertex;
+ this.outputCollector = outputCollector;
+ if (children.size() != isSideInputs.size()) {
+ throw new IllegalStateException(irVertex.toString());
+ }
+ final List<VertexHarness> sides = new ArrayList<>();
+ final List<VertexHarness> nonSides = new ArrayList<>();
+ for (int i = 0; i < children.size(); i++) {
+ final VertexHarness child = children.get(i);
+ if (isSideInputs.get(0)) {
+ sides.add(child);
+ } else {
+ nonSides.add(child);
+ }
+ }
+ this.sideInputChildren = sides;
+ this.nonSideInputChildren = nonSides;
+ this.writersToChildrenTasks = writersToChildrenTasks;
+ this.context = context;
+ }
+
+ /**
+ * @return irVertex of this VertexHarness.
+ */
+ IRVertex getIRVertex() {
+ return irVertex;
+ }
+
+ /**
+ * @return OutputCollector of this irVertex.
+ */
+ OutputCollectorImpl getOutputCollector() {
+ return outputCollector;
+ }
+
+ /**
+ * @return list of non-sideinput children. (empty if none exists)
+ */
+ List<VertexHarness> getNonSideInputChildren() {
+ return nonSideInputChildren;
+ }
+
+ /**
+ * @return list of sideinput children. (empty if none exists)
+ */
+ List<VertexHarness> getSideInputChildren() {
+ return sideInputChildren;
+ }
+
+ /**
+ * @return OutputWriters of this irVertex. (empty if none exists)
+ */
+ List<OutputWriter> getWritersToChildrenTasks() {
+ return writersToChildrenTasks;
+ }
+
+ /**
+ * @return context.
+ */
+ Transform.Context getContext() {
+ return context;
+ }
+}
diff --git a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java
new file mode 100644
index 0000000..f2c0082
--- /dev/null
+++ b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java
@@ -0,0 +1,426 @@
+/*
+ * Copyright (C) 2017 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.runtime.executor.task;
+
+import edu.snu.nemo.common.Pair;
+import edu.snu.nemo.common.ir.OutputCollector;
+import edu.snu.nemo.common.coder.Coder;
+import edu.snu.nemo.common.dag.DAG;
+import edu.snu.nemo.common.dag.DAGBuilder;
+import edu.snu.nemo.common.ir.Readable;
+import edu.snu.nemo.common.ir.vertex.InMemorySourceVertex;
+import edu.snu.nemo.common.ir.vertex.OperatorVertex;
+import edu.snu.nemo.common.ir.vertex.transform.Transform;
+import edu.snu.nemo.common.ir.edge.executionproperty.DataStoreProperty;
+import edu.snu.nemo.common.ir.executionproperty.ExecutionPropertyMap;
+import edu.snu.nemo.common.ir.vertex.IRVertex;
+import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import edu.snu.nemo.runtime.common.plan.Task;
+import edu.snu.nemo.runtime.common.plan.StageEdge;
+import edu.snu.nemo.runtime.common.plan.RuntimeEdge;
+import edu.snu.nemo.runtime.executor.MetricMessageSender;
+import edu.snu.nemo.runtime.executor.TaskStateManager;
+import edu.snu.nemo.runtime.executor.data.DataUtil;
+import edu.snu.nemo.runtime.executor.datatransfer.DataTransferFactory;
+import edu.snu.nemo.runtime.executor.datatransfer.InputReader;
+import edu.snu.nemo.runtime.executor.datatransfer.OutputWriter;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+
+import java.io.IOException;
+import java.util.*;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.*;
+
+/**
+ * Tests {@link TaskExecutor}.
+ */
+@RunWith(PowerMockRunner.class)
+@PrepareForTest({InputReader.class, OutputWriter.class, DataTransferFactory.class,
+ TaskStateManager.class, StageEdge.class})
+public final class TaskExecutorTest {
+ private static final int DATA_SIZE = 100;
+ private static final String CONTAINER_TYPE = "CONTAINER_TYPE";
+ private static final int SOURCE_PARALLELISM = 5;
+ private List<Integer> elements;
+ private Map<String, List> vertexIdToOutputData;
+ private DataTransferFactory dataTransferFactory;
+ private TaskStateManager taskStateManager;
+ private MetricMessageSender metricMessageSender;
+ private AtomicInteger stageId;
+
  /**
   * @return a fresh task id belonging to a fresh stage (the stage counter is bumped per call,
   *         so each test's task lives in its own stage).
   */
  private String generateTaskId() {
    return RuntimeIdGenerator.generateTaskId(0,
        RuntimeIdGenerator.generateStageId(stageId.getAndIncrement()));
  }
+
  /**
   * Sets up the fixtures shared by all tests: the input elements, a mocked TaskStateManager,
   * a mocked DataTransferFactory (whose readers feed the elements in and whose writers record
   * outputs into vertexIdToOutputData), and a no-op MetricMessageSender.
   * @throws Exception from mock setup.
   */
  @Before
  public void setUp() throws Exception {
    elements = getRangedNumList(0, DATA_SIZE);
    stageId = new AtomicInteger(1);

    // Mock a TaskStateManager. It accumulates the state change into a list.
    taskStateManager = mock(TaskStateManager.class);

    // Mock a DataTransferFactory.
    vertexIdToOutputData = new HashMap<>();
    dataTransferFactory = mock(DataTransferFactory.class);
    when(dataTransferFactory.createReader(anyInt(), any(), any())).then(new ParentTaskReaderAnswer());
    when(dataTransferFactory.createWriter(any(), anyInt(), any(), any())).then(new ChildTaskWriterAnswer());

    // Mock a MetricMessageSender.
    metricMessageSender = mock(MetricMessageSender.class);
    doNothing().when(metricMessageSender).send(anyString(), anyString());
    doNothing().when(metricMessageSender).close();
  }
+
+ private boolean checkEqualElements(final List<Integer> left, final List<Integer> right) {
+ Collections.sort(left);
+ Collections.sort(right);
+ return left.equals(right);
+ }
+
  /**
   * Test source vertex data fetching: a single source vertex whose Readable yields 'elements'
   * should forward all of them to the child-task writer.
   */
  @Test(timeout=5000)
  public void testSourceVertexDataFetching() throws Exception {
    final IRVertex sourceIRVertex = new InMemorySourceVertex<>(elements);

    // A Readable backed directly by the in-memory elements; locations are irrelevant here.
    final Readable readable = new Readable() {
      @Override
      public Iterable read() throws IOException {
        return elements;
      }
      @Override
      public List<String> getLocations() {
        throw new UnsupportedOperationException();
      }
    };
    final Map<String, Readable> vertexIdToReadable = new HashMap<>();
    vertexIdToReadable.put(sourceIRVertex.getId(), readable);

    // Single-vertex DAG: source only.
    final DAG<IRVertex, RuntimeEdge<IRVertex>> taskDag =
        new DAGBuilder<IRVertex, RuntimeEdge<IRVertex>>()
            .addVertex(sourceIRVertex)
            .buildWithoutSourceSinkCheck();

    final Task task =
        new Task(
            "testSourceVertexDataFetching",
            generateTaskId(),
            0,
            CONTAINER_TYPE,
            new byte[0],
            Collections.emptyList(),
            Collections.singletonList(mockStageEdgeFrom(sourceIRVertex)),
            vertexIdToReadable);

    // Execute the task.
    final TaskExecutor taskExecutor = new TaskExecutor(
        task, taskDag, taskStateManager, dataTransferFactory, metricMessageSender);
    taskExecutor.execute();

    // Check the output.
    assertTrue(checkEqualElements(elements, vertexIdToOutputData.get(sourceIRVertex.getId())));
  }
+
+ /**
+ * Test parent task data fetching.
+ */
+ @Test(timeout=5000)
+ public void testParentTaskDataFetching() throws Exception {
+ final IRVertex vertex = new OperatorVertex(new RelayTransform());
+
+ final DAG<IRVertex, RuntimeEdge<IRVertex>> taskDag = new DAGBuilder<IRVertex, RuntimeEdge<IRVertex>>()
+ .addVertex(vertex)
+ .buildWithoutSourceSinkCheck();
+
+ final Task task = new Task(
+ "testSourceVertexDataFetching",
+ generateTaskId(),
+ 0,
+ CONTAINER_TYPE,
+ new byte[0],
+ Collections.singletonList(mockStageEdgeTo(vertex)),
+ Collections.singletonList(mockStageEdgeFrom(vertex)),
+ Collections.emptyMap());
+
+ // Execute the task.
+ final TaskExecutor taskExecutor = new TaskExecutor(
+ task, taskDag, taskStateManager, dataTransferFactory, metricMessageSender);
+ taskExecutor.execute();
+
+ // Check the output.
+ assertTrue(checkEqualElements(elements, vertexIdToOutputData.get(vertex.getId())));
+ }
+
+ /**
+ * The DAG of the task to test will looks like:
+ * parent task -> task (vertex 1 -> task 2) -> child task
+ *
+ * The output data from task 1 will be split according to source parallelism through {@link ParentTaskReaderAnswer}.
+ * Because of this, task 1 will process multiple partitions and emit data in multiple times also.
+ * On the other hand, task 2 will receive the output data once and produce a single output.
+ */
+ @Test(timeout=5000)
+ public void testTwoOperators() throws Exception {
+ final IRVertex operatorIRVertex1 = new OperatorVertex(new RelayTransform());
+ final IRVertex operatorIRVertex2 = new OperatorVertex(new RelayTransform());
+
+ final DAG<IRVertex, RuntimeEdge<IRVertex>> taskDag = new DAGBuilder<IRVertex, RuntimeEdge<IRVertex>>()
+ .addVertex(operatorIRVertex1)
+ .addVertex(operatorIRVertex2)
+ .connectVertices(createEdge(operatorIRVertex1, operatorIRVertex2, false))
+ .buildWithoutSourceSinkCheck();
+
+ final Task task = new Task(
+ "testSourceVertexDataFetching",
+ generateTaskId(),
+ 0,
+ CONTAINER_TYPE,
+ new byte[0],
+ Collections.singletonList(mockStageEdgeTo(operatorIRVertex1)),
+ Collections.singletonList(mockStageEdgeFrom(operatorIRVertex2)),
+ Collections.emptyMap());
+
+ // Execute the task.
+ final TaskExecutor taskExecutor = new TaskExecutor(
+ task, taskDag, taskStateManager, dataTransferFactory, metricMessageSender);
+ taskExecutor.execute();
+
+ // Check the output.
+ assertTrue(checkEqualElements(elements, vertexIdToOutputData.get(operatorIRVertex2.getId())));
+ }
+
+ @Test(timeout=5000)
+ public void testTwoOperatorsWithSideInput() throws Exception {
+ final Object tag = new Object();
+ final Transform singleListTransform = new CreateSingleListTransform();
+ final IRVertex operatorIRVertex1 = new OperatorVertex(singleListTransform);
+ final IRVertex operatorIRVertex2 = new OperatorVertex(new SideInputPairTransform(singleListTransform.getTag()));
+
+ final DAG<IRVertex, RuntimeEdge<IRVertex>> taskDag = new DAGBuilder<IRVertex, RuntimeEdge<IRVertex>>()
+ .addVertex(operatorIRVertex1)
+ .addVertex(operatorIRVertex2)
+ .connectVertices(createEdge(operatorIRVertex1, operatorIRVertex2, true))
+ .buildWithoutSourceSinkCheck();
+
+ final Task task = new Task(
+        "testTwoOperatorsWithSideInput",
+ generateTaskId(),
+ 0,
+ CONTAINER_TYPE,
+ new byte[0],
+ Arrays.asList(mockStageEdgeTo(operatorIRVertex1), mockStageEdgeTo(operatorIRVertex2)),
+ Collections.singletonList(mockStageEdgeFrom(operatorIRVertex2)),
+ Collections.emptyMap());
+
+ // Execute the task.
+ final TaskExecutor taskExecutor = new TaskExecutor(
+ task, taskDag, taskStateManager, dataTransferFactory, metricMessageSender);
+ taskExecutor.execute();
+
+ // Check the output.
+ final List<Pair<List<Integer>, Integer>> pairs = vertexIdToOutputData.get(operatorIRVertex2.getId());
+ final List<Integer> values = pairs.stream().map(Pair::right).collect(Collectors.toList());
+ assertTrue(checkEqualElements(elements, values));
+ assertTrue(pairs.stream().map(Pair::left).allMatch(sideInput -> checkEqualElements(sideInput, values)));
+ }
+
+ private RuntimeEdge<IRVertex> createEdge(final IRVertex src,
+ final IRVertex dst,
+ final boolean isSideInput) {
+ final String runtimeIREdgeId = "Runtime edge between operator tasks";
+ final Coder coder = Coder.DUMMY_CODER;
+ ExecutionPropertyMap edgeProperties = new ExecutionPropertyMap(runtimeIREdgeId);
+ edgeProperties.put(DataStoreProperty.of(DataStoreProperty.Value.MemoryStore));
+ return new RuntimeEdge<>(runtimeIREdgeId, edgeProperties, src, dst, coder, isSideInput);
+
+ }
+
+ private StageEdge mockStageEdgeFrom(final IRVertex irVertex) {
+ final StageEdge edge = mock(StageEdge.class);
+ when(edge.getSrcVertex()).thenReturn(irVertex);
+ when(edge.getDstVertex()).thenReturn(new OperatorVertex(new RelayTransform()));
+ return edge;
+ }
+
+ private StageEdge mockStageEdgeTo(final IRVertex irVertex) {
+ final StageEdge edge = mock(StageEdge.class);
+ when(edge.getSrcVertex()).thenReturn(new OperatorVertex(new RelayTransform()));
+ when(edge.getDstVertex()).thenReturn(irVertex);
+ return edge;
+ }
+
+ /**
+   * Represents the answer that returns an inter-stage {@link InputReader},
+   * which will have multiple iterables according to the source parallelism.
+ */
+ private class ParentTaskReaderAnswer implements Answer<InputReader> {
+ @Override
+ public InputReader answer(final InvocationOnMock invocationOnMock) throws Throwable {
+ final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> inputFutures = new ArrayList<>(SOURCE_PARALLELISM);
+ final int elementsPerSource = DATA_SIZE / SOURCE_PARALLELISM;
+ for (int i = 0; i < SOURCE_PARALLELISM; i++) {
+ inputFutures.add(CompletableFuture.completedFuture(
+ DataUtil.IteratorWithNumBytes.of(elements.subList(i * elementsPerSource, (i + 1) * elementsPerSource)
+ .iterator())));
+ }
+ final InputReader inputReader = mock(InputReader.class);
+ when(inputReader.read()).thenReturn(inputFutures);
+ when(inputReader.isSideInputReader()).thenReturn(false);
+ when(inputReader.getSourceParallelism()).thenReturn(SOURCE_PARALLELISM);
+ return inputReader;
+ }
+ }
+
+ /**
+   * Represents the answer that returns an {@link OutputWriter},
+   * which will store the data into the map between vertex id and output data.
+ */
+ private class ChildTaskWriterAnswer implements Answer<OutputWriter> {
+ @Override
+ public OutputWriter answer(final InvocationOnMock invocationOnMock) throws Throwable {
+ final Object[] args = invocationOnMock.getArguments();
+ final IRVertex vertex = (IRVertex) args[0];
+ final OutputWriter outputWriter = mock(OutputWriter.class);
+ doAnswer(new Answer() {
+ @Override
+ public Object answer(final InvocationOnMock invocationOnMock) throws Throwable {
+ final Object[] args = invocationOnMock.getArguments();
+ final Object dataToWrite = args[0];
+ vertexIdToOutputData.computeIfAbsent(vertex.getId(), emptyTaskId -> new ArrayList<>());
+ vertexIdToOutputData.get(vertex.getId()).add(dataToWrite);
+ return null;
+ }
+ }).when(outputWriter).write(any());
+ return outputWriter;
+ }
+ }
+
+ /**
+ * Simple identity function for testing.
+ * @param <T> input/output type.
+ */
+ private class RelayTransform<T> implements Transform<T, T> {
+ private OutputCollector<T> outputCollector;
+
+ @Override
+ public void prepare(final Context context, final OutputCollector<T> outputCollector) {
+ this.outputCollector = outputCollector;
+ }
+
+ @Override
+ public void onData(final Object element) {
+ outputCollector.emit((T) element);
+ }
+
+ @Override
+ public void close() {
+ // Do nothing.
+ }
+ }
+
+ /**
+   * A {@link Transform} that collects all input elements into a single list, which is emitted on close.
+ * @param <T> input type.
+ */
+ private class CreateSingleListTransform<T> implements Transform<T, List<T>> {
+ private List<T> list;
+ private OutputCollector<List<T>> outputCollector;
+ private final Object tag = new Object();
+
+ @Override
+ public void prepare(final Context context, final OutputCollector<List<T>> outputCollector) {
+ this.list = new ArrayList<>();
+ this.outputCollector = outputCollector;
+ }
+
+ @Override
+ public void onData(final Object element) {
+ list.add((T) element);
+ }
+
+ @Override
+ public void close() {
+ outputCollector.emit(list);
+ }
+
+ @Override
+ public Object getTag() {
+ return tag;
+ }
+ }
+
+ /**
+ * Pairs data element with a side input.
+ * @param <T> input/output type.
+ */
+ private class SideInputPairTransform<T> implements Transform<T, T> {
+ private final Object sideInputTag;
+ private Context context;
+ private OutputCollector<T> outputCollector;
+
+ public SideInputPairTransform(final Object sideInputTag) {
+ this.sideInputTag = sideInputTag;
+ }
+
+ @Override
+ public void prepare(final Context context, final OutputCollector<T> outputCollector) {
+ this.context = context;
+ this.outputCollector = outputCollector;
+ }
+
+ @Override
+ public void onData(final Object element) {
+ final Object sideInput = context.getSideInputs().get(sideInputTag);
+ outputCollector.emit((T) Pair.of(sideInput, element));
+ }
+
+ @Override
+ public void close() {
+ // Do nothing.
+ }
+ }
+
+ /**
+   * Gets a list of integers in the given range.
+ * @param start value of the range (inclusive).
+ * @param end value of the range (exclusive).
+ * @return the list of elements.
+ */
+ private List<Integer> getRangedNumList(final int start, final int end) {
+ final List<Integer> numList = new ArrayList<>(end - start);
+ IntStream.range(start, end).forEach(number -> numList.add(number));
+ return numList;
+ }
+}
diff --git a/tests/src/test/java/edu/snu/nemo/tests/runtime/executor/TaskExecutorTest.java b/tests/src/test/java/edu/snu/nemo/tests/runtime/executor/TaskExecutorTest.java
deleted file mode 100644
index ffbfef6..0000000
--- a/tests/src/test/java/edu/snu/nemo/tests/runtime/executor/TaskExecutorTest.java
+++ /dev/null
@@ -1,280 +0,0 @@
-/*
- * Copyright (C) 2017 Seoul National University
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package edu.snu.nemo.tests.runtime.executor;
-
-import edu.snu.nemo.common.ir.OutputCollector;
-import edu.snu.nemo.common.coder.Coder;
-import edu.snu.nemo.common.dag.DAG;
-import edu.snu.nemo.common.dag.DAGBuilder;
-import edu.snu.nemo.common.ir.Readable;
-import edu.snu.nemo.common.ir.vertex.OperatorVertex;
-import edu.snu.nemo.common.ir.vertex.transform.Transform;
-import edu.snu.nemo.common.ir.edge.executionproperty.DataStoreProperty;
-import edu.snu.nemo.common.ir.executionproperty.ExecutionPropertyMap;
-import edu.snu.nemo.common.ir.vertex.IRVertex;
-import edu.snu.nemo.compiler.optimizer.examples.EmptyComponents;
-import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
-import edu.snu.nemo.runtime.common.plan.Task;
-import edu.snu.nemo.runtime.common.plan.StageEdge;
-import edu.snu.nemo.runtime.common.plan.RuntimeEdge;
-import edu.snu.nemo.runtime.executor.MetricMessageSender;
-import edu.snu.nemo.runtime.executor.TaskExecutor;
-import edu.snu.nemo.runtime.executor.TaskStateManager;
-import edu.snu.nemo.runtime.executor.data.DataUtil;
-import edu.snu.nemo.runtime.executor.datatransfer.DataTransferFactory;
-import edu.snu.nemo.runtime.executor.datatransfer.InputReader;
-import edu.snu.nemo.runtime.executor.datatransfer.OutputWriter;
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
-import org.powermock.core.classloader.annotations.PrepareForTest;
-import org.powermock.modules.junit4.PowerMockRunner;
-
-import java.util.*;
-import java.util.concurrent.CompletableFuture;
-
-import static edu.snu.nemo.tests.runtime.RuntimeTestUtil.getRangedNumList;
-import static org.junit.Assert.assertEquals;
-import static org.mockito.ArgumentMatchers.anyInt;
-import static org.mockito.ArgumentMatchers.anyString;
-import static org.mockito.Matchers.any;
-import static org.mockito.Mockito.*;
-
-/**
- * Tests {@link TaskExecutor}.
- */
-@RunWith(PowerMockRunner.class)
-@PrepareForTest({InputReader.class, OutputWriter.class, DataTransferFactory.class,
- TaskStateManager.class, StageEdge.class})
-public final class TaskExecutorTest {
- private static final int DATA_SIZE = 100;
- private static final String CONTAINER_TYPE = "CONTAINER_TYPE";
- private static final int SOURCE_PARALLELISM = 5;
- private List elements;
- private Map<String, List<Object>> vertexIdToOutputData;
- private DataTransferFactory dataTransferFactory;
- private TaskStateManager taskStateManager;
- private MetricMessageSender metricMessageSender;
-
- @Before
- public void setUp() throws Exception {
- elements = getRangedNumList(0, DATA_SIZE);
-
- // Mock a TaskStateManager. It accumulates the state change into a list.
- taskStateManager = mock(TaskStateManager.class);
-
- // Mock a DataTransferFactory.
- vertexIdToOutputData = new HashMap<>();
- dataTransferFactory = mock(DataTransferFactory.class);
- when(dataTransferFactory.createReader(anyInt(), any(), any())).then(new InterStageReaderAnswer());
- when(dataTransferFactory.createWriter(any(), anyInt(), any(), any())).then(new WriterAnswer());
-
- // Mock a MetricMessageSender.
- metricMessageSender = mock(MetricMessageSender.class);
- doNothing().when(metricMessageSender).send(anyString(), anyString());
- doNothing().when(metricMessageSender).close();
- }
-
- /**
- * Test the {@link edu.snu.nemo.common.ir.vertex.SourceVertex} processing in {@link TaskExecutor}.
- */
- @Test(timeout=5000)
- public void testSourceVertex() throws Exception {
- final IRVertex sourceIRVertex = new EmptyComponents.EmptySourceVertex("empty");
- final String stageId = RuntimeIdGenerator.generateStageId(0);
-
- final Readable readable = new Readable() {
- @Override
- public Iterable read() throws Exception {
- return elements;
- }
- @Override
- public List<String> getLocations() {
- throw new UnsupportedOperationException();
- }
- };
- final Map<String, Readable> vertexIdToReadable = new HashMap<>();
- vertexIdToReadable.put(sourceIRVertex.getId(), readable);
-
- final DAG<IRVertex, RuntimeEdge<IRVertex>> taskDag =
- new DAGBuilder<IRVertex, RuntimeEdge<IRVertex>>().addVertex(sourceIRVertex).buildWithoutSourceSinkCheck();
- final StageEdge stageOutEdge = mock(StageEdge.class);
- when(stageOutEdge.getSrcVertex()).thenReturn(sourceIRVertex);
- final String taskId = RuntimeIdGenerator.generateTaskId(0, stageId);
- final Task task =
- new Task(
- "testSourceVertex",
- taskId,
- 0,
- CONTAINER_TYPE,
- new byte[0],
- Collections.emptyList(),
- Collections.singletonList(stageOutEdge),
- vertexIdToReadable);
-
- // Execute the task.
- final TaskExecutor taskExecutor = new TaskExecutor(
- task, taskDag, taskStateManager, dataTransferFactory, metricMessageSender);
- taskExecutor.execute();
-
- // Check the output.
- assertEquals(100, vertexIdToOutputData.get(sourceIRVertex.getId()).size());
- assertEquals(elements.get(0), vertexIdToOutputData.get(sourceIRVertex.getId()).get(0));
- }
-
- /**
- * Test the {@link edu.snu.nemo.common.ir.vertex.OperatorVertex} processing in {@link TaskExecutor}.
- *
- * The DAG of the task to test will looks like:
- * operator task 1 -> operator task 2
- *
- * The output data from upstream stage will be split
- * according to source parallelism through {@link InterStageReaderAnswer}.
- * Because of this, the operator task 1 will process multiple partitions and emit data in multiple times also.
- * On the other hand, operator task 2 will receive the output data once and produce a single output.
- */
- @Test(timeout=5000)
- public void testOperatorVertex() throws Exception {
- final IRVertex operatorIRVertex1 = new OperatorVertex(new SimpleTransform());
- final IRVertex operatorIRVertex2 = new OperatorVertex(new SimpleTransform());
- final String runtimeIREdgeId = "Runtime edge between operator tasks";
-
- final String stageId = RuntimeIdGenerator.generateStageId(1);
-
- final Coder coder = Coder.DUMMY_CODER;
- ExecutionPropertyMap edgeProperties = new ExecutionPropertyMap(runtimeIREdgeId);
- edgeProperties.put(DataStoreProperty.of(DataStoreProperty.Value.MemoryStore));
- final DAG<IRVertex, RuntimeEdge<IRVertex>> taskDag = new DAGBuilder<IRVertex, RuntimeEdge<IRVertex>>()
- .addVertex(operatorIRVertex1)
- .addVertex(operatorIRVertex2)
- .connectVertices(new RuntimeEdge<IRVertex>(
- runtimeIREdgeId, edgeProperties, operatorIRVertex1, operatorIRVertex2, coder))
- .buildWithoutSourceSinkCheck();
- final String taskId = RuntimeIdGenerator.generateTaskId(0, stageId);
- final StageEdge stageInEdge = mock(StageEdge.class);
- when(stageInEdge.getDstVertex()).thenReturn(operatorIRVertex1);
- final StageEdge stageOutEdge = mock(StageEdge.class);
- when(stageOutEdge.getSrcVertex()).thenReturn(operatorIRVertex2);
- final Task task =
- new Task(
- "testSourceVertex",
- taskId,
- 0,
- CONTAINER_TYPE,
- new byte[0],
- Collections.singletonList(stageInEdge),
- Collections.singletonList(stageOutEdge),
- Collections.emptyMap());
-
- // Execute the task.
- final TaskExecutor taskExecutor = new TaskExecutor(
- task, taskDag, taskStateManager, dataTransferFactory, metricMessageSender);
- taskExecutor.execute();
-
- // Check the output.
- assertEquals(100, vertexIdToOutputData.get(operatorIRVertex2.getId()).size());
- }
-
- /**
- * Represents the answer return an intra-stage {@link InputReader},
- * which will have a single iterable from the upstream task.
- */
- private class IntraStageReaderAnswer implements Answer<InputReader> {
- @Override
- public InputReader answer(final InvocationOnMock invocationOnMock) throws Throwable {
- // Read the data.
- final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> inputFutures = new ArrayList<>();
- inputFutures.add(CompletableFuture.completedFuture(
- DataUtil.IteratorWithNumBytes.of(elements.iterator())));
- final InputReader inputReader = mock(InputReader.class);
- when(inputReader.read()).thenReturn(inputFutures);
- when(inputReader.isSideInputReader()).thenReturn(false);
- when(inputReader.getSourceParallelism()).thenReturn(1);
- return inputReader;
- }
- }
-
- /**
- * Represents the answer return an inter-stage {@link InputReader},
- * which will have multiple iterable according to the source parallelism.
- */
- private class InterStageReaderAnswer implements Answer<InputReader> {
- @Override
- public InputReader answer(final InvocationOnMock invocationOnMock) throws Throwable {
- final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> inputFutures = new ArrayList<>(SOURCE_PARALLELISM);
- final int elementsPerSource = DATA_SIZE / SOURCE_PARALLELISM;
- for (int i = 0; i < SOURCE_PARALLELISM; i++) {
- inputFutures.add(CompletableFuture.completedFuture(
- DataUtil.IteratorWithNumBytes.of(elements.subList(i * elementsPerSource, (i + 1) * elementsPerSource)
- .iterator())));
- }
- final InputReader inputReader = mock(InputReader.class);
- when(inputReader.read()).thenReturn(inputFutures);
- when(inputReader.isSideInputReader()).thenReturn(false);
- when(inputReader.getSourceParallelism()).thenReturn(SOURCE_PARALLELISM);
- return inputReader;
- }
- }
-
- /**
- * Represents the answer return a {@link OutputWriter},
- * which will stores the data to the map between task id and output data.
- */
- private class WriterAnswer implements Answer<OutputWriter> {
- @Override
- public OutputWriter answer(final InvocationOnMock invocationOnMock) throws Throwable {
- final Object[] args = invocationOnMock.getArguments();
- final IRVertex vertex = (IRVertex) args[0];
- final OutputWriter outputWriter = mock(OutputWriter.class);
- doAnswer(new Answer() {
- @Override
- public Object answer(final InvocationOnMock invocationOnMock) throws Throwable {
- final Object[] args = invocationOnMock.getArguments();
- final Object dataToWrite = args[0];
- vertexIdToOutputData.computeIfAbsent(vertex.getId(), emptyTaskId -> new ArrayList<>());
- vertexIdToOutputData.get(vertex.getId()).add(dataToWrite);
- return null;
- }
- }).when(outputWriter).write(any());
- return outputWriter;
- }
- }
-
- /**
- * Simple {@link Transform} for testing.
- * @param <T> input/output type.
- */
- private class SimpleTransform<T> implements Transform<T, T> {
- private OutputCollector<T> outputCollector;
-
- @Override
- public void prepare(final Context context, final OutputCollector<T> outputCollector) {
- this.outputCollector = outputCollector;
- }
-
- @Override
- public void onData(final Object element) {
- outputCollector.emit((T) element);
- }
-
- @Override
- public void close() {
- // Do nothing.
- }
- }
-}
--
To stop receiving notification emails like this one, please contact
jangho@apache.org.