You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@seatunnel.apache.org by GitBox <gi...@apache.org> on 2022/08/10 07:08:14 UTC

[GitHub] [incubator-seatunnel] Hisoka-X commented on a diff in pull request #2366: [ST-Engine][TaskExecutionService]Add dynamic thread sharing optimization

Hisoka-X commented on code in PR #2366:
URL: https://github.com/apache/incubator-seatunnel/pull/2366#discussion_r942091513


##########
seatunnel-engine/seatunnel-engine-server/src/main/java/org/apache/seatunnel/engine/server/TaskExecutionService.java:
##########
@@ -18,134 +18,149 @@
 package org.apache.seatunnel.engine.server;
 
 import static com.hazelcast.jet.impl.util.ExceptionUtil.withTryCatch;
+import static com.hazelcast.jet.impl.util.Util.uncheckRun;
+import static java.util.Collections.emptyList;
 import static java.util.concurrent.Executors.newCachedThreadPool;
+import static java.util.stream.Collectors.partitioningBy;
+import static java.util.stream.Collectors.toList;
 
+import org.apache.seatunnel.engine.common.utils.NonCompletableFuture;
+import org.apache.seatunnel.engine.server.execution.ExecutionState;
 import org.apache.seatunnel.engine.server.execution.ProgressState;
 import org.apache.seatunnel.engine.server.execution.Task;
+import org.apache.seatunnel.engine.server.execution.TaskCallTimer;
 import org.apache.seatunnel.engine.server.execution.TaskExecutionContext;
+import org.apache.seatunnel.engine.server.execution.TaskExecutionState;
 import org.apache.seatunnel.engine.server.execution.TaskGroup;
+import org.apache.seatunnel.engine.server.execution.TaskTracker;
 
-import com.hazelcast.jet.impl.util.NonCompletableFuture;
 import com.hazelcast.logging.ILogger;
-import com.hazelcast.spi.impl.NodeEngine;
 import com.hazelcast.spi.impl.NodeEngineImpl;
 import com.hazelcast.spi.properties.HazelcastProperties;
 import lombok.NonNull;
+import lombok.SneakyThrows;
 
-import java.util.HashMap;
+import java.util.Collection;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
+import java.util.concurrent.LinkedBlockingDeque;
 import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
 
 /**
  * This class is responsible for the execution of the Task
  */
 public class TaskExecutionService {
 
     private final String hzInstanceName;
-    private final NodeEngine nodeEngine;
+    private final NodeEngineImpl nodeEngine;
     private final ILogger logger;
     private volatile boolean isShutdown;
-    private final ExecutorService blockingTaskletExecutor = newCachedThreadPool(new BlockingTaskThreadFactory());
+    private final LinkedBlockingDeque<TaskTracker> threadShareTaskQueue = new LinkedBlockingDeque<>();
+    private final ExecutorService executorService = newCachedThreadPool(new BlockingTaskThreadFactory());
+    private final RunBusWorkSupplier runBusWorkSupplier = new RunBusWorkSupplier(executorService, threadShareTaskQueue);
     // key: TaskID
-    private final ConcurrentMap<Long, TaskExecutionContext> executionContexts = new ConcurrentHashMap<>();
+    private final ConcurrentMap<Long, ConcurrentMap<Long, TaskExecutionContext>> executionContexts = new ConcurrentHashMap<>();
 
     public TaskExecutionService(NodeEngineImpl nodeEngine, HazelcastProperties properties) {
         this.hzInstanceName = nodeEngine.getHazelcastInstance().getName();
         this.nodeEngine = nodeEngine;
         this.logger = nodeEngine.getLoggingService().getLogger(TaskExecutionService.class);
     }
 
+    public void start() {
+        runBusWorkSupplier.runNewBusWork(false);
+    }
+
     public void shutdown() {
         isShutdown = true;
-        blockingTaskletExecutor.shutdownNow();
+        executorService.shutdownNow();
     }
 
-    public TaskExecutionContext getExecutionContext(long taskId) {
-        return executionContexts.get(taskId);
+    public ConcurrentMap<Long, TaskExecutionContext> getExecutionContext(long taskGroupId) {
+        return executionContexts.get(taskGroupId);
     }
 
-    /**
-     * Submit a TaskGroup and run the Task in it
-     */
-    public Map<Long, TaskExecutionContext> submitTask(
-        TaskGroup taskGroup
-    ) {
-        Map<Long, TaskExecutionContext> contextMap = new HashMap<>(taskGroup.getTasks().size());
-        taskGroup.getTasks().forEach(task -> {
-            contextMap.put(task.getTaskID(), submitTask(task));
-        });
-        return contextMap;
+    private void submitThreadShareTask(TaskGroupExecutionTracker taskGroupExecutionTracker, List<Task> tasks) {
+        tasks.stream()
+            .map(t -> new TaskTracker(t, taskGroupExecutionTracker))
+            .forEach(threadShareTaskQueue::add);
     }
 
-    public TaskExecutionContext submitTask(Task task) {
-        CompletableFuture<Void> cancellationFuture = new CompletableFuture<Void>();
-        TaskletTracker taskletTracker = new TaskletTracker(task, cancellationFuture);
-        taskletTracker.taskletFutures =
-            blockingTaskletExecutor.submit(new BlockingWorker(taskletTracker));
+    private void submitBlockingTask(TaskGroupExecutionTracker taskGroupExecutionTracker, List<Task> tasks) {
 
-        TaskExecutionContext taskExecutionContext = new TaskExecutionContext(
-            taskletTracker.future,
-            cancellationFuture,
-            this
-        );
-
-        executionContexts.put(task.getTaskID(), taskExecutionContext);
-        return taskExecutionContext;
+        CountDownLatch startedLatch = new CountDownLatch(tasks.size());
+        taskGroupExecutionTracker.blockingFutures = tasks
+            .stream()
+            .map(t -> new BlockingWorker(new TaskTracker(t, taskGroupExecutionTracker), startedLatch))
+            .map(executorService::submit)
+            .collect(toList());
 
+        // Do not return from this method until all workers have started. Otherwise
+        // on cancellation there is a race where the executor might not have started
+        // the worker yet. This would result in taskletDone() never being called for
+        // a worker.
+        uncheckRun(startedLatch::await);
     }
 
-    private final class TaskletTracker {
-        final NonCompletableFuture future = new NonCompletableFuture();
-        final Task task;
-        volatile Future<?> taskletFutures;
-
-        TaskletTracker(Task task, CompletableFuture<Void> cancellationFuture) {
-            this.task = task;
-
-            cancellationFuture.whenComplete(withTryCatch(logger, (r, e) -> {
-                if (e == null) {
-                    e = new IllegalStateException("cancellationFuture should be completed exceptionally");
-                }
-                future.internalCompleteExceptionally(e);
-                taskletFutures.cancel(true);
-            }));
-        }
-
-        @Override
-        public String toString() {
-            return "Tracking " + task;
+    public CompletableFuture<TaskExecutionState> submitTaskGroup(
+        TaskGroup taskGroup,
+        CompletableFuture<Void> cancellationFuture
+    ) {
+        Collection<Task> tasks = taskGroup.getTasks();
+        final TaskGroupExecutionTracker executionTracker = new TaskGroupExecutionTracker(cancellationFuture, taskGroup);
+        try {
+            ConcurrentMap<Long, TaskExecutionContext> taskExecutionContextMap = new ConcurrentHashMap<>();
+            final Map<Boolean, List<Task>> byCooperation =
+                tasks.stream()
+                    .peek(x -> {
+                        TaskExecutionContext taskExecutionContext = new TaskExecutionContext(x, nodeEngine);
+                        x.setTaskExecutionContext(taskExecutionContext);
+                        taskExecutionContextMap.put(x.getTaskID(), taskExecutionContext);
+                    })
+                    .collect(partitioningBy(Task::isThreadsShare));
+            submitThreadShareTask(executionTracker, byCooperation.get(true));
+            submitBlockingTask(executionTracker, byCooperation.get(false));
+            executionContexts.put(taskGroup.getId(), taskExecutionContextMap);
+        } catch (Throwable t) {
+            executionTracker.future.complete(new TaskExecutionState(taskGroup.getId(), ExecutionState.FAILED, t));
         }
+        return new NonCompletableFuture<>(executionTracker.future);

Review Comment:
   I don't understand why `NonCompletableFuture` is used here. If a user wants to cancel this job, how would they do it?



##########
seatunnel-engine/seatunnel-engine-server/src/main/java/org/apache/seatunnel/engine/server/TaskExecutionService.java:
##########
@@ -156,7 +171,162 @@ private final class BlockingTaskThreadFactory implements ThreadFactory {
         @Override
         public Thread newThread(@NonNull Runnable r) {
             return new Thread(r,
-                String.format("hz.%s.seaTunnel.blocking.thread-%d", hzInstanceName, seq.getAndIncrement()));
+                String.format("hz.%s.seaTunnel.task.thread-%d", hzInstanceName, seq.getAndIncrement()));
+        }
+    }
+
+    /**
+     * BusWork is used to poll the task call method,
+     * When a task times out, a new BusWork will be created to take over the execution of the task
+     */
+    public final class BusWork implements Runnable {
+
+        AtomicBoolean keep = new AtomicBoolean(true);
+        public AtomicReference<TaskTracker> exclusiveTaskTracker = new AtomicReference<>();
+        final TaskCallTimer timer;
+        public LinkedBlockingDeque<TaskTracker> taskqueue;
+
+        @SuppressWarnings("checkstyle:MagicNumber")
+        public BusWork(LinkedBlockingDeque<TaskTracker> taskqueue, RunBusWorkSupplier runBusWorkSupplier) {

Review Comment:
   Suggestion: use `PriorityBlockingQueue` to support an unfair task execution order.



##########
seatunnel-engine/seatunnel-engine-server/src/main/java/org/apache/seatunnel/engine/server/TaskExecutionService.java:
##########
@@ -156,7 +171,162 @@ private final class BlockingTaskThreadFactory implements ThreadFactory {
         @Override
         public Thread newThread(@NonNull Runnable r) {
             return new Thread(r,
-                String.format("hz.%s.seaTunnel.blocking.thread-%d", hzInstanceName, seq.getAndIncrement()));
+                String.format("hz.%s.seaTunnel.task.thread-%d", hzInstanceName, seq.getAndIncrement()));
+        }
+    }
+
+    /**
+     * BusWork is used to poll the task call method,
+     * When a task times out, a new BusWork will be created to take over the execution of the task
+     */
+    public final class BusWork implements Runnable {
+
+        AtomicBoolean keep = new AtomicBoolean(true);
+        public AtomicReference<TaskTracker> exclusiveTaskTracker = new AtomicReference<>();
+        final TaskCallTimer timer;
+        public LinkedBlockingDeque<TaskTracker> taskqueue;
+
+        @SuppressWarnings("checkstyle:MagicNumber")
+        public BusWork(LinkedBlockingDeque<TaskTracker> taskqueue, RunBusWorkSupplier runBusWorkSupplier) {
+            logger.info(String.format("Created new BusWork : %s", this.hashCode()));
+            this.taskqueue = taskqueue;
+            this.timer = new TaskCallTimer(50, keep, runBusWorkSupplier, this);
+        }
+
+        @SneakyThrows
+        @Override
+        public void run() {
+            while (keep.get()) {
+                TaskTracker taskTracker = null != exclusiveTaskTracker.get() ?
+                    exclusiveTaskTracker.get() :
+                    taskqueue.takeFirst();
+                TaskGroupExecutionTracker taskGroupExecutionTracker = taskTracker.taskGroupExecutionTracker;
+                if (taskGroupExecutionTracker.executionCompletedExceptionally()) {
+                    taskGroupExecutionTracker.taskDone();
+                    if (null != exclusiveTaskTracker.get()) {
+                        // If it's exclusive need to end the work
+                        break;
+                    } else {
+                        // No action required and don't put back
+                        continue;
+                    }
+                }
+                //start timer, if it's exclusive, don't need to start
+                if (null == exclusiveTaskTracker.get()) {
+                    timer.timerStart(taskTracker);
+                }
+                ProgressState call = null;
+                try {
+                    //run task
+                    call = taskTracker.task.call();
+                    synchronized (timer) {

Review Comment:
   `timer` is a field of `BusWork`, but a `BusWork` object is always executed by a single thread, so a thread-safety problem should not be possible here. Is there a detail I am missing?



##########
seatunnel-engine/seatunnel-engine-server/src/main/java/org/apache/seatunnel/engine/server/TaskExecutionService.java:
##########
@@ -156,7 +171,162 @@ private final class BlockingTaskThreadFactory implements ThreadFactory {
         @Override
         public Thread newThread(@NonNull Runnable r) {
             return new Thread(r,
-                String.format("hz.%s.seaTunnel.blocking.thread-%d", hzInstanceName, seq.getAndIncrement()));
+                String.format("hz.%s.seaTunnel.task.thread-%d", hzInstanceName, seq.getAndIncrement()));
+        }
+    }
+
+    /**
+     * BusWork is used to poll the task call method,
+     * When a task times out, a new BusWork will be created to take over the execution of the task
+     */
+    public final class BusWork implements Runnable {

Review Comment:
   I couldn't understand what `BusWork` is until I had read all the code in this class. Maybe rename it?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@seatunnel.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org