You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@nemo.apache.org by GitBox <gi...@apache.org> on 2018/03/08 03:13:02 UTC

[GitHub] sanha commented on a change in pull request #2: [NEMO-7] Intra-TaskGroup pipelining

sanha commented on a change in pull request #2: [NEMO-7] Intra-TaskGroup pipelining
URL: https://github.com/apache/incubator-nemo/pull/2#discussion_r173052558
 
 

 ##########
 File path: runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/TaskGroupExecutor.java
 ##########
 @@ -145,303 +196,467 @@ private void initializeDataTransfer() {
         .collect(Collectors.toSet());
   }
 
-  // Helper functions to initializes stage-internal edges.
-  private void createLocalReader(final Task task, final RuntimeEdge<Task> internalEdge) {
-    final InputReader inputReader = channelFactory.createLocalReader(taskGroupIdx, internalEdge);
-    addInputReader(task, inputReader);
-  }
-
-  private void createLocalWriter(final Task task, final RuntimeEdge<Task> internalEdge) {
-    final OutputWriter outputWriter = channelFactory.createLocalWriter(task, taskGroupIdx, internalEdge);
-    addOutputWriter(task, outputWriter);
-  }
-
-  // Helper functions to add the initialized reader/writer to the maintained map.
-  private void addInputReader(final Task task, final InputReader inputReader) {
+  /**
+   * Add input OutputCollectors to each {@link Task}.
+   * Input OutputCollector denotes all the OutputCollectors of intra-Stage parent tasks of this task.
+   *
+   * @param task the Task to add input OutputCollectors to.
+   */
+  private void addInputFromThisStage(final Task task) {
+    final TaskDataHandler dataHandler = taskToDataHandlerMap.get(task);
+    List<Task> parentTasks = taskGroupDag.getParents(task.getId());
     final String physicalTaskId = getPhysicalTaskId(task.getId());
-    physicalTaskIdToInputReaderMap.computeIfAbsent(physicalTaskId, readerList -> new ArrayList<>());
-    physicalTaskIdToInputReaderMap.get(physicalTaskId).add(inputReader);
-  }
 
-  private void addOutputWriter(final Task task, final OutputWriter outputWriter) {
-    final String physicalTaskId = getPhysicalTaskId(task.getId());
-    physicalTaskIdToOutputWriterMap.computeIfAbsent(physicalTaskId, readerList -> new ArrayList<>());
-    physicalTaskIdToOutputWriterMap.get(physicalTaskId).add(outputWriter);
+    if (parentTasks != null) {
+      parentTasks.forEach(parent -> {
+        final OutputCollectorImpl parentOutputCollector = taskToDataHandlerMap.get(parent).getOutputCollector();
+        if (parentOutputCollector.hasSideInputFor(physicalTaskId)) {
+          dataHandler.addSideInputFromThisStage(parentOutputCollector);
+        } else {
+          dataHandler.addInputFromThisStages(parentOutputCollector);
+          LOG.info("log: Added Output outputCollector of {} as InputPipe of {} {}",
+              getPhysicalTaskId(parent.getId()), taskGroupId, physicalTaskId);
+        }
+      });
+    }
   }
 
   /**
-   * Executes the task group.
+   * Add output outputCollectors to each {@link Task}.
+   * Output outputCollector denotes the one and only one outputCollector of this task.
+   * Check the outgoing edges that will use this outputCollector,
+   * and set this outputCollector as side input if any one of the edges uses this outputCollector as side input.
+   *
+   * @param task the Task to add output outputCollectors to.
    */
-  public void execute() {
-    LOG.info("{} Execution Started!", taskGroupId);
-    if (isExecutionRequested) {
-      throw new RuntimeException("TaskGroup {" + taskGroupId + "} execution called again!");
-    } else {
-      isExecutionRequested = true;
-    }
-
-    taskGroupStateManager.onTaskGroupStateChanged(
-        TaskGroupState.State.EXECUTING, Optional.empty(), Optional.empty());
+  private void setOutputCollector(final Task task) {
+    final TaskDataHandler dataHandler = taskToDataHandlerMap.get(task);
+    final OutputCollectorImpl outputCollector = new OutputCollectorImpl();
+    final String physicalTaskId = getPhysicalTaskId(task.getId());
 
-    taskGroupDag.topologicalDo(task -> {
-      final String physicalTaskId = getPhysicalTaskId(task.getId());
-      taskGroupStateManager.onTaskStateChanged(physicalTaskId, TaskState.State.EXECUTING, Optional.empty());
-      try {
-        if (task instanceof BoundedSourceTask) {
-          launchBoundedSourceTask((BoundedSourceTask) task);
-          taskGroupStateManager.onTaskStateChanged(physicalTaskId, TaskState.State.COMPLETE, Optional.empty());
-          LOG.info("{} Execution Complete!", taskGroupId);
-        } else if (task instanceof OperatorTask) {
-          launchOperatorTask((OperatorTask) task);
-          taskGroupStateManager.onTaskStateChanged(physicalTaskId, TaskState.State.COMPLETE, Optional.empty());
-          LOG.info("{} Execution Complete!", taskGroupId);
-        } else if (task instanceof MetricCollectionBarrierTask) {
-          launchMetricCollectionBarrierTask((MetricCollectionBarrierTask) task);
-          taskGroupStateManager.onTaskStateChanged(physicalTaskId, TaskState.State.ON_HOLD, Optional.empty());
-          LOG.info("{} Execution Complete!", taskGroupId);
-        } else {
-          throw new UnsupportedOperationException(task.toString());
-        }
-      } catch (final BlockFetchException ex) {
-        taskGroupStateManager.onTaskStateChanged(physicalTaskId, TaskState.State.FAILED_RECOVERABLE,
-            Optional.of(TaskGroupState.RecoverableFailureCause.INPUT_READ_FAILURE));
-        LOG.warn("{} Execution Failed (Recoverable)! Exception: {}",
-            new Object[] {taskGroupId, ex.toString()});
-      } catch (final BlockWriteException ex2) {
-        taskGroupStateManager.onTaskStateChanged(physicalTaskId, TaskState.State.FAILED_RECOVERABLE,
-            Optional.of(TaskGroupState.RecoverableFailureCause.OUTPUT_WRITE_FAILURE));
-        LOG.warn("{} Execution Failed (Recoverable)! Exception: {}",
-            new Object[] {taskGroupId, ex2.toString()});
-      } catch (final Exception e) {
-        taskGroupStateManager.onTaskStateChanged(
-            physicalTaskId, TaskState.State.FAILED_UNRECOVERABLE, Optional.empty());
-        throw new RuntimeException(e);
+    taskGroupDag.getOutgoingEdgesOf(task).forEach(outEdge -> {
+      if (outEdge.isSideInput()) {
+        outputCollector.setSideInputRuntimeEdge(outEdge);
+        outputCollector.setAsSideInputFor(physicalTaskId);
+        LOG.info("log: {} {} Marked as accepting sideInput(edge {})",
+            taskGroupId, physicalTaskId, outEdge.getId());
       }
     });
+
+    dataHandler.setOutputCollector(outputCollector);
+    LOG.info("log: {} {} Added OutputPipe", taskGroupId, physicalTaskId);
   }
 
-  /**
-   * Processes a BoundedSourceTask.
-   *
-   * @param boundedSourceTask the bounded source task to execute
-   * @throws Exception occurred during input read.
-   */
-  private void launchBoundedSourceTask(final BoundedSourceTask boundedSourceTask) throws Exception {
-    final String physicalTaskId = getPhysicalTaskId(boundedSourceTask.getId());
-    final Map<String, Object> metric = new HashMap<>();
-    metricCollector.beginMeasurement(physicalTaskId, metric);
+  private boolean hasOutputWriter(final Task task) {
+    return !taskToDataHandlerMap.get(task).getOutputWriters().isEmpty();
+  }
 
-    final long readStartTime = System.currentTimeMillis();
-    final Readable readable = boundedSourceTask.getReadable();
-    final Iterable readData = readable.read();
-    final long readEndTime = System.currentTimeMillis();
-    metric.put("BoundedSourceReadTime(ms)", readEndTime - readStartTime);
+  private void setTaskPutOnHold(final MetricCollectionBarrierTask task) {
+    final String physicalTaskId = getPhysicalTaskId(task.getId());
+    logicalTaskIdPutOnHold = RuntimeIdGenerator.getLogicalTaskIdIdFromPhysicalTaskId(physicalTaskId);
+  }
 
+  private void writeAndCloseOutputWriters(final Task task) {
+    final String physicalTaskId = getPhysicalTaskId(task.getId());
     final List<Long> writtenBytesList = new ArrayList<>();
-    for (final OutputWriter outputWriter : physicalTaskIdToOutputWriterMap.get(physicalTaskId)) {
-      outputWriter.write(readData);
+    final Map<String, Object> metric = new HashMap<>();
+    metricCollector.beginMeasurement(physicalTaskId, metric);
+    final long writeStartTime = System.currentTimeMillis();
+
+    taskToDataHandlerMap.get(task).getOutputWriters().forEach(outputWriter -> {
+      LOG.info("Write and close outputWriter of task {}", getPhysicalTaskId(task.getId()));
+      outputWriter.write();
       outputWriter.close();
       final Optional<Long> writtenBytes = outputWriter.getWrittenBytes();
       writtenBytes.ifPresent(writtenBytesList::add);
-    }
+    });
+
     final long writeEndTime = System.currentTimeMillis();
-    metric.put("OutputWriteTime(ms)", writeEndTime - readEndTime);
+    metric.put("OutputWriteTime(ms)", writeEndTime - writeStartTime);
     putWrittenBytesMetric(writtenBytesList, metric);
     metricCollector.endMeasurement(physicalTaskId, metric);
   }
 
-  /**
-   * Processes an OperatorTask.
-   * @param operatorTask to execute
-   */
-  private void launchOperatorTask(final OperatorTask operatorTask) {
-    final Map<Transform, Object> sideInputMap = new HashMap<>();
-    final List<DataUtil.IteratorWithNumBytes> sideInputIterators = new ArrayList<>();
-    final String physicalTaskId = getPhysicalTaskId(operatorTask.getId());
+  private void prepareInputFromSource() {
+    taskGroupDag.topologicalDo(task -> {
+      if (task instanceof BoundedSourceTask) {
+        try {
+          final String iteratorId = generateIteratorId();
+          final Iterator iterator = ((BoundedSourceTask) task).getReadable().read().iterator();
+          idToSrcIteratorMap.putIfAbsent(iteratorId, iterator);
+          srcIteratorIdToTasksMap.putIfAbsent(iteratorId, new ArrayList<>());
+          srcIteratorIdToTasksMap.get(iteratorId).add(task);
+        } catch (final BlockFetchException ex) {
+          taskGroupStateManager.onTaskGroupStateChanged(TaskGroupState.State.FAILED_RECOVERABLE,
+              Optional.empty(), Optional.of(TaskGroupState.RecoverableFailureCause.INPUT_READ_FAILURE));
+          LOG.info("{} Execution Failed (Recoverable: input read failure)! Exception: {}",
+              taskGroupId, ex.toString());
+        } catch (final Exception e) {
+          taskGroupStateManager.onTaskGroupStateChanged(TaskGroupState.State.FAILED_UNRECOVERABLE,
+              Optional.empty(), Optional.empty());
+          LOG.info("{} Execution Failed! Exception: {}", taskGroupId, e.toString());
+          throw new RuntimeException(e);
+        }
+      }
+      // TODO #XXX: Support other types of source tasks, e.g., InitializedSourceTask
+    });
+  }
 
-    final Map<String, Object> metric = new HashMap<>();
-    metricCollector.beginMeasurement(physicalTaskId, metric);
-    long accumulatedBlockedReadTime = 0;
-    long accumulatedWriteTime = 0;
-    long accumulatedSerializedBlockSize = 0;
-    long accumulatedEncodedBlockSize = 0;
-    boolean blockSizeAvailable = true;
-
-    final long readStartTime = System.currentTimeMillis();
-    // Check for side inputs
-    physicalTaskIdToInputReaderMap.get(physicalTaskId).stream().filter(InputReader::isSideInputReader)
-        .forEach(inputReader -> {
+  private void prepareInputFromOtherStages() {
+    inputReaders.stream().forEach(inputReader -> {
+      final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> futures = inputReader.read();
+      numPartitions += futures.size();
+
+      // Add consumers which will push iterator when the futures are complete.
+      futures.forEach(compFuture -> compFuture.whenComplete((iterator, exception) -> {
+        if (exception != null) {
+          throw new BlockFetchException(exception);
+        }
+
+        final String iteratorId = generateIteratorId();
+        if (iteratorIdToTasksMap.containsKey(iteratorId)) {
+          throw new RuntimeException("iteratorIdToTasksMap already contains " + iteratorId);
+        } else {
+          iteratorIdToTasksMap.computeIfAbsent(iteratorId, absentIteratorId -> {
+            final List<Task> list = new ArrayList<>();
+            list.addAll(inputReaderToTasksMap.get(inputReader));
+            return Collections.unmodifiableList(list);
+          });
           try {
-            if (!inputReader.isSideInputReader()) {
-              // Trying to get sideInput from a reader that does not handle sideInput.
-              // This is probably a bug. We're not trying to recover but ensure a hard fail.
-              throw new RuntimeException("Trying to get sideInput from non-sideInput reader");
-            }
-            final DataUtil.IteratorWithNumBytes sideInputIterator = inputReader.read().get(0).get();
-            final Object sideInput = getSideInput(sideInputIterator);
-
-            final RuntimeEdge inEdge = inputReader.getRuntimeEdge();
-            final Transform srcTransform;
-            if (inEdge instanceof PhysicalStageEdge) {
-              srcTransform = ((OperatorVertex) ((PhysicalStageEdge) inEdge).getSrcVertex())
-                  .getTransform();
-            } else {
-              srcTransform = ((OperatorTask) inEdge.getSrc()).getTransform();
-            }
-            sideInputMap.put(srcTransform, sideInput);
-            sideInputIterators.add(sideInputIterator);
-          } catch (final InterruptedException | ExecutionException e) {
-            throw new BlockFetchException(e);
+            partitionQueue.put(Pair.of(iteratorId, iterator));
+          } catch (InterruptedException e) {
+            throw new RuntimeException("Interrupted while receiving iterator " + e);
           }
-        });
+        }
+      }));
+    });
+  }
 
-    for (final DataUtil.IteratorWithNumBytes iterator : sideInputIterators) {
-      try {
-        accumulatedSerializedBlockSize += iterator.getNumSerializedBytes();
-        accumulatedEncodedBlockSize += iterator.getNumEncodedBytes();
-      } catch (final DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) {
-        blockSizeAvailable = false;
-        break;
-      }
-    }
+  private boolean finishedAllTasks() {
+    // Total size of this TaskGroup
+    int taskNum = taskToDataHandlerMap.keySet().size();
+    int finishedTaskNum = finishedTaskIds.size();
 
-    final Transform.Context transformContext = new ContextImpl(sideInputMap);
-    final OutputCollectorImpl outputCollector = new OutputCollectorImpl();
+    return finishedTaskNum == taskNum;
+  }
 
-    final Transform transform = operatorTask.getTransform();
-    transform.prepare(transformContext, outputCollector);
-
-    // Check for non-side inputs
-    // This blocking queue contains the pairs having data and source vertex ids.
-    final BlockingQueue<Pair<DataUtil.IteratorWithNumBytes, String>> dataQueue = new LinkedBlockingQueue<>();
-    final AtomicInteger sourceParallelism = new AtomicInteger(0);
-    physicalTaskIdToInputReaderMap.get(physicalTaskId).stream().filter(inputReader -> !inputReader.isSideInputReader())
-        .forEach(inputReader -> {
-          final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> futures = inputReader.read();
-          final String srcIrVtxId = inputReader.getSrcIrVertexId();
-          sourceParallelism.getAndAdd(inputReader.getSourceParallelism());
-          // Add consumers which will push the data to the data queue when it ready to the futures.
-          futures.forEach(compFuture -> compFuture.whenComplete((data, exception) -> {
-            if (exception != null) {
-              throw new BlockFetchException(exception);
-            }
-            dataQueue.add(Pair.of(data, srcIrVtxId));
-          }));
-        });
-    final long readFutureEndTime = System.currentTimeMillis();
-    // Consumes all of the partitions from incoming edges.
-    for (int srcTaskNum = 0; srcTaskNum < sourceParallelism.get(); srcTaskNum++) {
+  private void initializePipeToDstTasksMap() {
+    srcIteratorIdToTasksMap.values().forEach(tasks ->
+        tasks.forEach(task -> {
+          final OutputCollectorImpl outputCollector = taskToDataHandlerMap.get(task).getOutputCollector();
+          final List<Task> dstTasks = taskGroupDag.getChildren(task.getId());
+          outputCollectorToDstTasksMap.putIfAbsent(outputCollector, dstTasks);
+          LOG.info("{} outputCollectorToDstTasksMap: [{}'s OutputPipe, {}]",
+              taskGroupId, getPhysicalTaskId(task.getId()), dstTasks);
+        }));
+    iteratorIdToTasksMap.values().forEach(tasks ->
+        tasks.forEach(task -> {
+          final OutputCollectorImpl outputCollector = taskToDataHandlerMap.get(task).getOutputCollector();
+          final List<Task> dstTasks = taskGroupDag.getChildren(task.getId());
+          outputCollectorToDstTasksMap.putIfAbsent(outputCollector, dstTasks);
+          LOG.info("{} outputCollectorToDstTasksMap: [{}'s OutputPipe, {}]",
+              taskGroupId, getPhysicalTaskId(task.getId()), dstTasks);
+        }));
+  }
+
+  private void updatePipeToDstTasksMap() {
+    Map<OutputCollectorImpl, List<Task>> currentMap = outputCollectorToDstTasksMap;
+    Map<OutputCollectorImpl, List<Task>> updatedMap = new HashMap<>();
+
+    currentMap.values().forEach(tasks ->
+        tasks.forEach(task -> {
+          final OutputCollectorImpl outputCollector = taskToDataHandlerMap.get(task).getOutputCollector();
+          final List<Task> dstTasks = taskGroupDag.getChildren(task.getId());
+          updatedMap.putIfAbsent(outputCollector, dstTasks);
+          LOG.info("{} outputCollectorToDstTasksMap: [{}, {}]",
+              taskGroupId, getPhysicalTaskId(task.getId()), dstTasks);
+        })
+    );
+
+    outputCollectorToDstTasksMap = updatedMap;
+  }
+
+  private void closeTransform(final Task task) {
+    if (task instanceof OperatorTask) {
+      Transform transform = ((OperatorTask) task).getTransform();
+      transform.close();
+      LOG.info("{} {} Closed Transform {}!", taskGroupId, getPhysicalTaskId(task.getId()), transform);
+    }
+  }
+
+  private void sideInputFromOtherStages(final Task task, final Map<Transform, Object> sideInputMap) {
+    taskToDataHandlerMap.get(task).getSideInputFromOtherStages().forEach(sideInputReader -> {
       try {
-        // Because the data queue is a blocking queue, we may need to wait some available data to be pushed.
-        final long blockedReadStartTime = System.currentTimeMillis();
-        final Pair<DataUtil.IteratorWithNumBytes, String> availableData = dataQueue.take();
-        final long blockedReadEndTime = System.currentTimeMillis();
-        accumulatedBlockedReadTime += blockedReadEndTime - blockedReadStartTime;
-        transform.onData(availableData.left(), availableData.right());
-        if (blockSizeAvailable) {
-          try {
-            accumulatedSerializedBlockSize += availableData.left().getNumSerializedBytes();
-            accumulatedEncodedBlockSize += availableData.left().getNumEncodedBytes();
-          } catch (final DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) {
-            blockSizeAvailable = false;
-          }
+        final DataUtil.IteratorWithNumBytes sideInputIterator = sideInputReader.read().get(0).get();
+        final Object sideInput = getSideInput(sideInputIterator);
+        final RuntimeEdge inEdge = sideInputReader.getRuntimeEdge();
+        final Transform srcTransform;
+        if (inEdge instanceof PhysicalStageEdge) {
+          srcTransform = ((OperatorVertex) ((PhysicalStageEdge) inEdge).getSrcVertex()).getTransform();
+        } else {
+          srcTransform = ((OperatorTask) inEdge.getSrc()).getTransform();
+        }
+        sideInputMap.put(srcTransform, sideInput);
+
+        // Collect metrics on block size if possible.
+        try {
+          serBlockSize += sideInputIterator.getNumSerializedBytes();
+        } catch (final DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) {
+          serBlockSize = -1;
         }
-      } catch (final InterruptedException e) {
+        try {
+          encodedBlockSize += sideInputIterator.getNumEncodedBytes();
+        } catch (final DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) {
+          encodedBlockSize = -1;
+        }
+
+        LOG.info("log: {} {} read sideInput from InputReader {}",
+            taskGroupId, getPhysicalTaskId(task.getId()), sideInput);
+      } catch (final InterruptedException | ExecutionException e) {
         throw new BlockFetchException(e);
       }
+    });
+  }
+
+  private void sideInputFromThisStage(final Task task, final Map<Transform, Object> sideInputMap) {
+    final String physicalTaskId = getPhysicalTaskId(task.getId());
+    taskToDataHandlerMap.get(task).getSideInputFromThisStage().forEach(input -> {
+      // because sideInput is only 1 element in the outputCollector
+      Object sideInput = input.remove();
+      final RuntimeEdge inEdge = input.getSideInputRuntimeEdge();
+      final Transform srcTransform;
+      if (inEdge instanceof PhysicalStageEdge) {
+        srcTransform = ((OperatorVertex) ((PhysicalStageEdge) inEdge).getSrcVertex()).getTransform();
+      } else {
+        srcTransform = ((OperatorTask) inEdge.getSrc()).getTransform();
+      }
+      sideInputMap.put(srcTransform, sideInput);
+      LOG.info("log: {} {} read sideInput from InputPipe {}", taskGroupId, physicalTaskId, sideInput);
+    });
+  }
 
-      // Check whether there is any output data from the transform and write the output of this task to the writer.
-      final List output = outputCollector.collectOutputList();
-      if (!output.isEmpty() && physicalTaskIdToOutputWriterMap.containsKey(physicalTaskId)) {
-        final long writeStartTime = System.currentTimeMillis();
-        physicalTaskIdToOutputWriterMap.get(physicalTaskId).forEach(outputWriter -> outputWriter.write(output));
-        final long writeEndTime = System.currentTimeMillis();
-        accumulatedWriteTime += writeEndTime - writeStartTime;
-      } // If else, this is a sink task.
+  /**
+   * Executes the task group.
+   */
+  public void execute() {
+    final Map<String, Object> metric = new HashMap<>();
+    metricCollector.beginMeasurement(taskGroupId, metric);
+    long boundedSrcReadStartTime = 0;
+    long boundedSrcReadEndTime = 0;
+    long inputReadStartTime = 0;
+    long inputReadEndTime = 0;
+    if (isExecutionRequested) {
+      throw new RuntimeException("TaskGroup {" + taskGroupId + "} execution called again!");
+    } else {
+      isExecutionRequested = true;
     }
-    transform.close();
+    taskGroupStateManager.onTaskGroupStateChanged(TaskGroupState.State.EXECUTING, Optional.empty(), Optional.empty());
+    LOG.info("{} Executing!", taskGroupId);
+
+    // Prepare input data from bounded source.
+    boundedSrcReadStartTime = System.currentTimeMillis();
+    prepareInputFromSource();
+    boundedSrcReadEndTime = System.currentTimeMillis();
+    metric.put("BoundedSourceReadTime(ms)", boundedSrcReadEndTime - boundedSrcReadStartTime);
+
+    // Prepare input data from other stages.
+    inputReadStartTime = System.currentTimeMillis();
+    prepareInputFromOtherStages();
+
+    // Execute the TaskGroup DAG.
+    try {
+      srcIteratorIdToTasksMap.forEach((srcIteratorId, tasks) -> {
+        Iterator iterator = idToSrcIteratorMap.get(srcIteratorId);
+        iterator.forEachRemaining(element -> {
+          for (final Task task : tasks) {
+            List data = Collections.singletonList(element);
+            runTask(task, data);
+          }
+        });
+      });
 
-    metric.put("InputReadTime(ms)", readFutureEndTime - readStartTime + accumulatedBlockedReadTime);
-    final long transformEndTime = System.currentTimeMillis();
-    metric.put("TransformTime(ms)",
-        transformEndTime - readFutureEndTime - accumulatedWriteTime - accumulatedBlockedReadTime);
+      // Process data from other stages.
+      for (int currPartition = 0; currPartition < numPartitions; currPartition++) {
+        LOG.info("{} Partition {} out of {}", taskGroupId, currPartition, numPartitions);
+
+        Pair<String, DataUtil.IteratorWithNumBytes> idToIteratorPair = partitionQueue.take();
+        final String iteratorId = idToIteratorPair.left();
+        final DataUtil.IteratorWithNumBytes iterator = idToIteratorPair.right();
+        List<Task> dstTasks = iteratorIdToTasksMap.get(iteratorId);
+        iterator.forEachRemaining(element -> {
+          for (final Task task : dstTasks) {
+            List data = Collections.singletonList(element);
+            runTask(task, data);
+          }
+        });
 
-    // Check whether there is any output data from the transform and write the output of this task to the writer.
-    final List<Long> writtenBytesList = new ArrayList<>();
-    final List output = outputCollector.collectOutputList();
-    if (physicalTaskIdToOutputWriterMap.containsKey(physicalTaskId)) {
-      for (final OutputWriter outputWriter : physicalTaskIdToOutputWriterMap.get(physicalTaskId)) {
-        if (!output.isEmpty()) {
-          outputWriter.write(output);
+        // Collect metrics on block size if possible.
+        try {
+          serBlockSize += iterator.getNumSerializedBytes();
+        } catch (final DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) {
+          serBlockSize = -1;
+        } catch (final IllegalStateException e) {
+          LOG.error("IllegalState ", e);
+        }
+        try {
+          encodedBlockSize += iterator.getNumEncodedBytes();
+        } catch (final DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) {
+          encodedBlockSize = -1;
+        } catch (final IllegalStateException e) {
+          LOG.error("IllegalState ", e);
         }
-        outputWriter.close();
-        final Optional<Long> writtenBytes = outputWriter.getWrittenBytes();
-        writtenBytes.ifPresent(writtenBytesList::add);
       }
-    } else {
-      LOG.info("This is a sink task: {}", physicalTaskId);
+      inputReadEndTime = System.currentTimeMillis();
+      metric.put("InputReadTime(ms)", inputReadEndTime - inputReadStartTime);
+      LOG.info("{} Finished processing src data!", taskGroupId);
+
+      // Process intra-TaskGroup data.
+      // Intra-TaskGroup data comes from outputCollectors of this TaskGroup's Tasks.
+      initializePipeToDstTasksMap();
+      while (!finishedAllTasks()) {
+        outputCollectorToDstTasksMap.forEach((outputCollector, dstTasks) -> {
+          // Get the task that has this outputCollector as its output outputCollector
+          Task outputCollectorOwnerTask = taskToDataHandlerMap.values().stream()
+              .filter(dataHandler -> dataHandler.getOutputCollector() == outputCollector)
+              .findFirst().get().getTask();
+          LOG.info("{} outputCollectorOwnerTask {}", taskGroupId, getPhysicalTaskId(outputCollectorOwnerTask.getId()));
+
+          // Before consuming the output of outputCollectorOwnerTask as input,
+          // close transform if it is OperatorTransform.
+          closeTransform(outputCollectorOwnerTask);
+
+          // Set outputCollectorOwnerTask as finished.
+          finishedTaskIds.add(getPhysicalTaskId(outputCollectorOwnerTask.getId()));
+
+          // Pass outputCollectorOwnerTask's output to its children tasks recursively.
+          if (!dstTasks.isEmpty()) {
+            while (!outputCollector.isEmpty()) {
+              // Form input element-wise from the outputCollector
+              final List input = Collections.singletonList(outputCollector.remove());
+              for (final Task task : dstTasks) {
+                runTask(task, input);
+              }
+            }
+          }
+
+          // Write the whole iterable and close the OutputWriters.
+          if (hasOutputWriter(outputCollectorOwnerTask)) {
+            // If outputCollector isn't empty(if closeTransform produced some output),
+            // write them element-wise to OutputWriters.
+            while (!outputCollector.isEmpty()) {
+              final Object element = outputCollector.remove();
+              List<OutputWriter> outputWritersOfTask =
+                  taskToDataHandlerMap.get(outputCollectorOwnerTask).getOutputWriters();
+              outputWritersOfTask.forEach(outputWriter -> outputWriter.writeElement(element));
+              LOG.info("{} {} Write to OutputWriter element {}",
+                  taskGroupId, getPhysicalTaskId(outputCollectorOwnerTask.getId()), element);
+            }
+            writeAndCloseOutputWriters(outputCollectorOwnerTask);
+          }
+        });
+        updatePipeToDstTasksMap();
+      }
+    } catch (final BlockWriteException ex2) {
+      taskGroupStateManager.onTaskGroupStateChanged(TaskGroupState.State.FAILED_RECOVERABLE,
+          Optional.empty(), Optional.of(TaskGroupState.RecoverableFailureCause.OUTPUT_WRITE_FAILURE));
+      LOG.info("{} Execution Failed (Recoverable: output write failure)! Exception: {}",
+          taskGroupId, ex2.toString());
+    } catch (final Exception e) {
+      taskGroupStateManager.onTaskGroupStateChanged(TaskGroupState.State.FAILED_UNRECOVERABLE,
+          Optional.empty(), Optional.empty());
+      LOG.info("{} Execution Failed! Exception: {}",
+          taskGroupId, e.toString());
+      throw new RuntimeException(e);
     }
-    final long writeEndTime = System.currentTimeMillis();
-    metric.put("OutputTime(ms)", writeEndTime - transformEndTime + accumulatedWriteTime);
-    putReadBytesMetric(blockSizeAvailable, accumulatedSerializedBlockSize, accumulatedEncodedBlockSize, metric);
-    putWrittenBytesMetric(writtenBytesList, metric);
 
-    metricCollector.endMeasurement(physicalTaskId, metric);
+    // Put TaskGroup-unit metrics.
+    final boolean available = serBlockSize >= 0;
+    putReadBytesMetric(available, serBlockSize, encodedBlockSize, metric);
+    metricCollector.endMeasurement(taskGroupId, metric);
+    if (logicalTaskIdPutOnHold == null) {
+      taskGroupStateManager.onTaskGroupStateChanged(TaskGroupState.State.COMPLETE, Optional.empty(), Optional.empty());
+    } else {
+      taskGroupStateManager.onTaskGroupStateChanged(TaskGroupState.State.ON_HOLD,
+          Optional.of(logicalTaskIdPutOnHold),
+          Optional.empty());
+    }
+    LOG.info("{} Complete!", taskGroupId);
   }
 
   /**
-   * Pass on the data to the following tasks.
-   * @param task the task to carry on the data.
+   * Runs a single {@link Task} on the given data, dispatching on the concrete Task type.
+   *
+   * @param task to execute
    */
-  private void launchMetricCollectionBarrierTask(final MetricCollectionBarrierTask task) {
+  private void runTask(final Task task, final List<Object> data) {
     final String physicalTaskId = getPhysicalTaskId(task.getId());
-    final Map<String, Object> metric = new HashMap<>();
-    metricCollector.beginMeasurement(physicalTaskId, metric);
-    long accumulatedSerializedBlockSize = 0;
-    long accumulatedEncodedBlockSize = 0;
-    boolean blockSizeAvailable = true;
-
-    final long readStartTime = System.currentTimeMillis();
-    final BlockingQueue<DataUtil.IteratorWithNumBytes> dataQueue = new LinkedBlockingQueue<>();
-    final AtomicInteger sourceParallelism = new AtomicInteger(0);
-    physicalTaskIdToInputReaderMap.get(physicalTaskId).stream().filter(inputReader -> !inputReader.isSideInputReader())
-        .forEach(inputReader -> {
-          sourceParallelism.getAndAdd(inputReader.getSourceParallelism());
-          inputReader.read().forEach(compFuture -> compFuture.thenAccept(dataQueue::add));
+
+    // Process element-wise depending on the Task type
+    if (task instanceof BoundedSourceTask) {
+      OutputCollectorImpl outputCollector = taskToDataHandlerMap.get(task).getOutputCollector();
+
+      if (data.contains(null)) {  // data is [null] used for VoidCoders
+        outputCollector.emit(data);
+      } else {
+        data.forEach(dataElement -> {
+          outputCollector.emit(dataElement);
+          LOG.info("log: {} {} BoundedSourceTask emitting {} to outputCollector",
+              taskGroupId, physicalTaskId, dataElement);
         });
+      }
+    } else if (task instanceof OperatorTask) {
+      final Transform transform = ((OperatorTask) task).getTransform();
+
+      // Consumes the received element from incoming edges.
+      // Calculate the number of inter-TaskGroup data to process.
+      int numElements = data.size();
+      LOG.info("log: {} {}: numElements {}", taskGroupId, physicalTaskId, numElements);
+
+      IntStream.range(0, numElements).forEach(dataNum -> {
+        Object dataElement = data.get(dataNum);
+        LOG.info("log: {} {} OperatorTask applying {} to onData", taskGroupId, physicalTaskId, dataElement);
+        transform.onData(dataElement);
+      });
+    } else if (task instanceof MetricCollectionBarrierTask) {
+      OutputCollectorImpl outputCollector = taskToDataHandlerMap.get(task).getOutputCollector();
 
-    final List data = new ArrayList<>();
-    for (int srcTaskNum = 0; srcTaskNum < sourceParallelism.get(); srcTaskNum++) {
-      try {
-        final DataUtil.IteratorWithNumBytes availableData = dataQueue.take();
-        availableData.forEachRemaining(data::add);
-        if (blockSizeAvailable) {
-          try {
-            accumulatedSerializedBlockSize += availableData.getNumSerializedBytes();
-            accumulatedEncodedBlockSize += availableData.getNumEncodedBytes();
-          } catch (final DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) {
-            blockSizeAvailable = false;
-          }
-        }
-      } catch (final InterruptedException e) {
-        throw new BlockFetchException(e);
+      if (data.contains(null)) {  // data is [null] used for VoidCoders
+        outputCollector.emit(data);
+      } else {
+        data.forEach(dataElement -> {
+          outputCollector.emit(dataElement);
+          LOG.info("log: {} {} MetricCollectionTask emitting {} to outputCollector",
+              taskGroupId, physicalTaskId, dataElement);
+        });
       }
+      setTaskPutOnHold((MetricCollectionBarrierTask) task);
+    } else {
+      throw new UnsupportedOperationException("This type  of Task is not supported");
     }
-    final long readEndTime = System.currentTimeMillis();
-    metric.put("InputReadTime(ms)", readEndTime - readStartTime);
 
-    final List<Long> writtenBytesList = new ArrayList<>();
-    for (final OutputWriter outputWriter : physicalTaskIdToOutputWriterMap.get(physicalTaskId)) {
-      outputWriter.write(data);
-      outputWriter.close();
-      final Optional<Long> writtenBytes = outputWriter.getWrittenBytes();
-      writtenBytes.ifPresent(writtenBytesList::add);
+    // For the produced output
+    OutputCollectorImpl outputCollector = taskToDataHandlerMap.get(task).getOutputCollector();
 
 Review comment:
   Can this local variable (`outputCollector`) be declared `final`?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services