You are viewing a plain text version of this content. The canonical version is available at the original archive link (the hyperlink was not preserved in this plain-text rendering).
Posted to commits@flink.apache.org by ro...@apache.org on 2021/05/10 12:26:45 UTC

[flink] branch release-1.13 updated: [FLINK-21181][runtime] Wait for Invokable cancellation before releasing network resources

This is an automated email from the ASF dual-hosted git repository.

roman pushed a commit to branch release-1.13
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.13 by this push:
     new ce7c78c  [FLINK-21181][runtime] Wait for Invokable cancellation before releasing network resources
ce7c78c is described below

commit ce7c78ca72ce86df0b4a28fbd3233b89da3238e1
Author: Roman Khachatryan <kh...@gmail.com>
AuthorDate: Wed Apr 28 00:02:33 2021 +0200

    [FLINK-21181][runtime] Wait for Invokable cancellation before releasing network resources
---
 .../program/PerJobMiniClusterFactoryTest.java      |   7 +-
 .../runtime/webmonitor/WebFrontendITCase.java      |  15 ++-
 .../iterative/task/AbstractIterativeTask.java      |  12 +-
 .../runtime/iterative/task/IterationHeadTask.java  |   1 +
 .../iterative/task/IterationIntermediateTask.java  |  61 +++++-----
 .../task/IterationSynchronizationSinkTask.java     |   9 ++
 .../runtime/iterative/task/IterationTailTask.java  |  81 +++++++-------
 .../flink/runtime/iterative/task/Terminable.java   |   2 +
 .../runtime/jobgraph/tasks/AbstractInvokable.java  |   6 +-
 .../apache/flink/runtime/operators/BatchTask.java  |   7 +-
 .../flink/runtime/operators/DataSinkTask.java      |   9 +-
 .../flink/runtime/operators/DataSourceTask.java    |   8 +-
 .../org/apache/flink/runtime/taskmanager/Task.java |  31 ++++--
 .../jobmaster/TestingAbstractInvokables.java       |   4 +-
 .../CoordinatorEventsExactlyOnceITCase.java        |   3 +-
 .../TaskExecutorOperatorEventHandlingTest.java     |   4 +-
 .../apache/flink/runtime/taskmanager/TaskTest.java |  18 ++-
 .../testtasks/OnceBlockingNoOpInvokable.java       |   5 +-
 .../runtime/testutils/CancelableInvokable.java     |  21 +++-
 .../apache/flink/runtime/jobmanager/Tasks.scala    |   8 +-
 .../flink/streaming/runtime/tasks/StreamTask.java  |  10 +-
 .../streaming/runtime/tasks/StreamTaskTest.java    | 123 +++++++++++++++++++++
 22 files changed, 343 insertions(+), 102 deletions(-)

diff --git a/flink-clients/src/test/java/org/apache/flink/client/program/PerJobMiniClusterFactoryTest.java b/flink-clients/src/test/java/org/apache/flink/client/program/PerJobMiniClusterFactoryTest.java
index f97395c..7061f64 100644
--- a/flink-clients/src/test/java/org/apache/flink/client/program/PerJobMiniClusterFactoryTest.java
+++ b/flink-clients/src/test/java/org/apache/flink/client/program/PerJobMiniClusterFactoryTest.java
@@ -34,7 +34,9 @@ import org.junit.After;
 import org.junit.Test;
 
 import java.util.Map;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
 
 import static org.apache.flink.core.testutils.CommonTestUtils.assertThrows;
 import static org.hamcrest.CoreMatchers.is;
@@ -185,7 +187,7 @@ public class PerJobMiniClusterFactoryTest extends TestLogger {
         }
 
         @Override
-        public void invoke() throws Exception {
+        public void doInvoke() throws Exception {
             synchronized (lock) {
                 while (running) {
                     lock.wait();
@@ -194,11 +196,12 @@ public class PerJobMiniClusterFactoryTest extends TestLogger {
         }
 
         @Override
-        public void cancel() {
+        public Future<Void> cancel() {
             synchronized (lock) {
                 running = false;
                 lock.notifyAll();
             }
+            return CompletableFuture.completedFuture(null);
         }
     }
 }
diff --git a/flink-runtime-web/src/test/java/org/apache/flink/runtime/webmonitor/WebFrontendITCase.java b/flink-runtime-web/src/test/java/org/apache/flink/runtime/webmonitor/WebFrontendITCase.java
index ec65969..d1c3f87 100644
--- a/flink-runtime-web/src/test/java/org/apache/flink/runtime/webmonitor/WebFrontendITCase.java
+++ b/flink-runtime-web/src/test/java/org/apache/flink/runtime/webmonitor/WebFrontendITCase.java
@@ -60,7 +60,9 @@ import java.time.Duration;
 import java.util.Collection;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Future;
 import java.util.stream.Collectors;
 
 import static org.hamcrest.CoreMatchers.containsString;
@@ -443,6 +445,8 @@ public class WebFrontendITCase extends TestLogger {
 
         private volatile boolean isRunning = true;
 
+        private final CompletableFuture<Void> terminationFuture = new CompletableFuture<>();
+
         public BlockingInvokable(Environment environment) {
             super(environment);
         }
@@ -450,14 +454,19 @@ public class WebFrontendITCase extends TestLogger {
         @Override
         public void invoke() throws Exception {
             latch.countDown();
-            while (isRunning) {
-                Thread.sleep(100);
+            try {
+                while (isRunning) {
+                    Thread.sleep(100);
+                }
+            } finally {
+                terminationFuture.complete(null);
             }
         }
 
         @Override
-        public void cancel() {
+        public Future<Void> cancel() {
             this.isRunning = false;
+            return terminationFuture;
         }
 
         public static void reset() {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/AbstractIterativeTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/AbstractIterativeTask.java
index a6d6e2d..b4c1ea3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/AbstractIterativeTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/AbstractIterativeTask.java
@@ -62,6 +62,7 @@ import org.slf4j.LoggerFactory;
 import java.io.IOException;
 import java.io.Serializable;
 import java.util.Map;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.Future;
 
 /** The abstract base class for all tasks able to participate in an iteration. */
@@ -88,6 +89,8 @@ public abstract class AbstractIterativeTask<S extends Function, OT> extends Batc
 
     private volatile boolean terminationRequested;
 
+    private final CompletableFuture<Void> terminationCompletionFuture = new CompletableFuture<>();
+
     // --------------------------------------------------------------------------------------------
 
     /**
@@ -311,9 +314,14 @@ public abstract class AbstractIterativeTask<S extends Function, OT> extends Batc
     }
 
     @Override
-    public void cancel() throws Exception {
+    public void terminationCompleted() {
+        this.terminationCompletionFuture.complete(null);
+    }
+
+    @Override
+    public Future<Void> cancel() throws Exception {
         requestTermination();
-        super.cancel();
+        return this.terminationCompletionFuture;
     }
 
     // -----------------------------------------------------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationHeadTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationHeadTask.java
index 2309630..2655154 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationHeadTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationHeadTask.java
@@ -446,6 +446,7 @@ public class IterationHeadTask<X, Y, S extends Function, OT> extends AbstractIte
             if (solutionSet != null) {
                 solutionSet.close();
             }
+            terminationCompleted();
         }
     }
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationIntermediateTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationIntermediateTask.java
index 2de3067..65b27c5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationIntermediateTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationIntermediateTask.java
@@ -101,45 +101,50 @@ public class IterationIntermediateTask<S extends Function, OT>
     @Override
     public void run() throws Exception {
 
-        SuperstepKickoffLatch nextSuperstepLatch =
-                SuperstepKickoffLatchBroker.instance().get(brokerKey());
+        try {
+            SuperstepKickoffLatch nextSuperstepLatch =
+                    SuperstepKickoffLatchBroker.instance().get(brokerKey());
 
-        while (this.running && !terminationRequested()) {
+            while (this.running && !terminationRequested()) {
 
-            if (log.isInfoEnabled()) {
-                log.info(formatLogString("starting iteration [" + currentIteration() + "]"));
-            }
+                if (log.isInfoEnabled()) {
+                    log.info(formatLogString("starting iteration [" + currentIteration() + "]"));
+                }
 
-            super.run();
+                super.run();
 
-            // check if termination was requested
-            verifyEndOfSuperstepState();
+                // check if termination was requested
+                verifyEndOfSuperstepState();
 
-            if (isWorksetUpdate && isWorksetIteration) {
-                long numCollected = worksetUpdateOutputCollector.getElementsCollectedAndReset();
-                worksetAggregator.aggregate(numCollected);
-            }
+                if (isWorksetUpdate && isWorksetIteration) {
+                    long numCollected = worksetUpdateOutputCollector.getElementsCollectedAndReset();
+                    worksetAggregator.aggregate(numCollected);
+                }
 
-            if (log.isInfoEnabled()) {
-                log.info(formatLogString("finishing iteration [" + currentIteration() + "]"));
-            }
+                if (log.isInfoEnabled()) {
+                    log.info(formatLogString("finishing iteration [" + currentIteration() + "]"));
+                }
 
-            // let the successors know that the end of this superstep data is reached
-            sendEndOfSuperstep();
+                // let the successors know that the end of this superstep data is reached
+                sendEndOfSuperstep();
 
-            if (isWorksetUpdate) {
-                // notify iteration head if responsible for workset update
-                worksetBackChannel.notifyOfEndOfSuperstep();
-            }
+                if (isWorksetUpdate) {
+                    // notify iteration head if responsible for workset update
+                    worksetBackChannel.notifyOfEndOfSuperstep();
+                }
 
-            boolean terminated =
-                    nextSuperstepLatch.awaitStartOfSuperstepOrTermination(currentIteration() + 1);
+                boolean terminated =
+                        nextSuperstepLatch.awaitStartOfSuperstepOrTermination(
+                                currentIteration() + 1);
 
-            if (terminated) {
-                requestTermination();
-            } else {
-                incrementIterationCounter();
+                if (terminated) {
+                    requestTermination();
+                } else {
+                    incrementIterationCounter();
+                }
             }
+        } finally {
+            terminationCompleted();
         }
     }
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java
index bf49c7d..8a2903a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java
@@ -40,6 +40,7 @@ import org.slf4j.LoggerFactory;
 import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.atomic.AtomicBoolean;
 
 /**
@@ -73,6 +74,8 @@ public class IterationSynchronizationSinkTask extends AbstractInvokable implemen
 
     private final AtomicBoolean terminated = new AtomicBoolean(false);
 
+    private final CompletableFuture<Void> terminationCompletionFuture = new CompletableFuture<>();
+
     // --------------------------------------------------------------------------------------------
 
     /**
@@ -175,6 +178,7 @@ public class IterationSynchronizationSinkTask extends AbstractInvokable implemen
                 currentIteration++;
             }
         }
+        terminationCompleted();
     }
 
     private boolean checkForConvergence() {
@@ -275,4 +279,9 @@ public class IterationSynchronizationSinkTask extends AbstractInvokable implemen
     public void requestTermination() {
         terminated.set(true);
     }
+
+    @Override
+    public void terminationCompleted() {
+        terminationCompletionFuture.complete(null);
+    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationTailTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationTailTask.java
index fe9cff6..a6aec7e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationTailTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationTailTask.java
@@ -115,45 +115,50 @@ public class IterationTailTask<S extends Function, OT> extends AbstractIterative
     @Override
     public void run() throws Exception {
 
-        SuperstepKickoffLatch nextSuperStepLatch =
-                SuperstepKickoffLatchBroker.instance().get(brokerKey());
-
-        while (this.running && !terminationRequested()) {
-
-            if (log.isInfoEnabled()) {
-                log.info(formatLogString("starting iteration [" + currentIteration() + "]"));
-            }
-
-            super.run();
-
-            // check if termination was requested
-            verifyEndOfSuperstepState();
-
-            if (isWorksetUpdate && isWorksetIteration) {
-                // aggregate workset update element count
-                long numCollected = worksetUpdateOutputCollector.getElementsCollectedAndReset();
-                worksetAggregator.aggregate(numCollected);
-            }
-
-            if (log.isInfoEnabled()) {
-                log.info(formatLogString("finishing iteration [" + currentIteration() + "]"));
-            }
-
-            if (isWorksetUpdate) {
-                // notify iteration head if responsible for workset update
-                worksetBackChannel.notifyOfEndOfSuperstep();
-            } else if (isSolutionSetUpdate) {
-                // notify iteration head if responsible for solution set update
-                solutionSetUpdateBarrier.notifySolutionSetUpdate();
-            }
-
-            boolean terminate =
-                    nextSuperStepLatch.awaitStartOfSuperstepOrTermination(currentIteration() + 1);
-            if (terminate) {
-                requestTermination();
-            } else {
-                incrementIterationCounter();
+        try {
+            SuperstepKickoffLatch nextSuperStepLatch =
+                    SuperstepKickoffLatchBroker.instance().get(brokerKey());
+
+            while (this.running && !terminationRequested()) {
+
+                if (log.isInfoEnabled()) {
+                    log.info(formatLogString("starting iteration [" + currentIteration() + "]"));
+                }
+
+                super.run();
+
+                // check if termination was requested
+                verifyEndOfSuperstepState();
+
+                if (isWorksetUpdate && isWorksetIteration) {
+                    // aggregate workset update element count
+                    long numCollected = worksetUpdateOutputCollector.getElementsCollectedAndReset();
+                    worksetAggregator.aggregate(numCollected);
+                }
+
+                if (log.isInfoEnabled()) {
+                    log.info(formatLogString("finishing iteration [" + currentIteration() + "]"));
+                }
+
+                if (isWorksetUpdate) {
+                    // notify iteration head if responsible for workset update
+                    worksetBackChannel.notifyOfEndOfSuperstep();
+                } else if (isSolutionSetUpdate) {
+                    // notify iteration head if responsible for solution set update
+                    solutionSetUpdateBarrier.notifySolutionSetUpdate();
+                }
+
+                boolean terminate =
+                        nextSuperStepLatch.awaitStartOfSuperstepOrTermination(
+                                currentIteration() + 1);
+                if (terminate) {
+                    requestTermination();
+                } else {
+                    incrementIterationCounter();
+                }
             }
+        } finally {
+            terminationCompleted();
         }
     }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/Terminable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/Terminable.java
index ee06eb9..1293682 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/Terminable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/Terminable.java
@@ -26,4 +26,6 @@ public interface Terminable {
     boolean terminationRequested();
 
     void requestTermination();
+
+    void terminationCompleted();
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/AbstractInvokable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/AbstractInvokable.java
index 6651523..453bd85 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/AbstractInvokable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/AbstractInvokable.java
@@ -32,6 +32,7 @@ import org.apache.flink.util.FlinkException;
 import org.apache.flink.util.SerializedValue;
 
 import java.io.IOException;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.Future;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -104,9 +105,12 @@ public abstract class AbstractInvokable {
      * execution failure. It can be overwritten to respond to shut down the user code properly.
      *
      * @throws Exception thrown if any exception occurs during the execution of the user code
+     * @return a future that is completed when this {@link AbstractInvokable} is fully terminated.
+     *     Note that it may never complete if the invokable is stuck.
      */
-    public void cancel() throws Exception {
+    public Future<Void> cancel() throws Exception {
         // The default implementation does nothing.
+        return CompletableFuture.completedFuture(null);
     }
 
     /**
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/BatchTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/BatchTask.java
index a282ec1d..85c5e9e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/BatchTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/BatchTask.java
@@ -74,6 +74,8 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Future;
 
 import static java.util.Collections.emptyList;
 
@@ -186,6 +188,7 @@ public class BatchTask<S extends Function, OT> extends AbstractInvokable
     protected Map<String, Accumulator<?, ?>> accumulatorMap;
 
     private OperatorMetricGroup metrics;
+    private final CompletableFuture<Void> terminationFuture = new CompletableFuture<>();
 
     // --------------------------------------------------------------------------------------------
     //                                  Constructor
@@ -361,6 +364,7 @@ public class BatchTask<S extends Function, OT> extends AbstractInvokable
 
             clearReaders(inputReaders);
             clearWriters(eventualOutputs);
+            terminationFuture.complete(null);
         }
 
         if (this.running) {
@@ -375,7 +379,7 @@ public class BatchTask<S extends Function, OT> extends AbstractInvokable
     }
 
     @Override
-    public void cancel() throws Exception {
+    public Future<Void> cancel() throws Exception {
         this.running = false;
 
         if (LOG.isDebugEnabled()) {
@@ -389,6 +393,7 @@ public class BatchTask<S extends Function, OT> extends AbstractInvokable
         } finally {
             closeLocalStrategiesAndCaches();
         }
+        return terminationFuture;
     }
 
     // --------------------------------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSinkTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSinkTask.java
index 9f3a427..15d9ca8 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSinkTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSinkTask.java
@@ -54,6 +54,9 @@ import org.apache.commons.lang3.tuple.Pair;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Future;
+
 /**
  * DataSinkTask which is executed by a task manager. The task hands the data to an output format.
  *
@@ -88,6 +91,8 @@ public class DataSinkTask<IT> extends AbstractInvokable {
 
     private volatile boolean cleanupCalled;
 
+    private final CompletableFuture<Void> terminationFuture = new CompletableFuture<>();
+
     /**
      * Create an Invokable task and set its environment.
      *
@@ -289,6 +294,7 @@ public class DataSinkTask<IT> extends AbstractInvokable {
             }
 
             BatchTask.clearReaders(new MutableReader<?>[] {inputReader});
+            terminationFuture.complete(null);
         }
 
         if (!this.taskCanceled) {
@@ -299,7 +305,7 @@ public class DataSinkTask<IT> extends AbstractInvokable {
     }
 
     @Override
-    public void cancel() throws Exception {
+    public Future<Void> cancel() throws Exception {
         this.taskCanceled = true;
         OutputFormat<IT> format = this.format;
         if (format != null) {
@@ -320,6 +326,7 @@ public class DataSinkTask<IT> extends AbstractInvokable {
         }
 
         LOG.debug(getLogString("Cancelling data sink operator"));
+        return terminationFuture;
     }
 
     /**
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSourceTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSourceTask.java
index eb05e73..41ecc15 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSourceTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSourceTask.java
@@ -54,6 +54,8 @@ import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.List;
 import java.util.NoSuchElementException;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Future;
 
 /**
  * DataSourceTask which is executed by a task manager. The task reads data and uses an {@link
@@ -85,6 +87,8 @@ public class DataSourceTask<OT> extends AbstractInvokable {
     // cancel flag
     private volatile boolean taskCanceled = false;
 
+    private final CompletableFuture<Void> terminationFuture = new CompletableFuture<>();
+
     /**
      * Create an Invokable task and set its environment.
      *
@@ -251,6 +255,7 @@ public class DataSourceTask<OT> extends AbstractInvokable {
                 ((RichInputFormat) this.format).closeInputFormat();
                 LOG.debug(getLogString("Rich Source detected. Closing the InputFormat."));
             }
+            terminationFuture.complete(null);
         }
 
         if (!this.taskCanceled) {
@@ -261,9 +266,10 @@ public class DataSourceTask<OT> extends AbstractInvokable {
     }
 
     @Override
-    public void cancel() throws Exception {
+    public Future<Void> cancel() throws Exception {
         this.taskCanceled = true;
         LOG.debug(getLogString("Cancelling data source operator"));
+        return terminationFuture;
     }
 
     /**
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
index 8056ca4..ce22e9e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
@@ -102,9 +102,12 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Executor;
 import java.util.concurrent.Future;
 import java.util.concurrent.RejectedExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
 import java.util.function.Consumer;
@@ -1173,6 +1176,10 @@ public class Task
                                 new TaskCanceler(
                                         LOG,
                                         this::closeNetworkResources,
+                                        taskCancellationTimeout > 0
+                                                ? taskCancellationTimeout
+                                                : TaskManagerOptions.TASK_CANCELLATION_TIMEOUT
+                                                        .defaultValue(),
                                         invokable,
                                         executingThread,
                                         taskNameWithSubtask);
@@ -1550,6 +1557,9 @@ public class Task
 
         private final Logger logger;
         private final Runnable networkResourcesCloser;
+        /** Time to wait after cancellation and interruption before releasing network resources. */
+        private final long taskCancellationTimeout;
+
         private final AbstractInvokable invokable;
         private final Thread executer;
         private final String taskName;
@@ -1557,11 +1567,13 @@ public class Task
         TaskCanceler(
                 Logger logger,
                 Runnable networkResourcesCloser,
+                long taskCancellationTimeout,
                 AbstractInvokable invokable,
                 Thread executer,
                 String taskName) {
             this.logger = logger;
             this.networkResourcesCloser = networkResourcesCloser;
+            this.taskCancellationTimeout = taskCancellationTimeout;
             this.invokable = invokable;
             this.executer = executer;
             this.taskName = taskName;
@@ -1573,7 +1585,17 @@ public class Task
                 // the user-defined cancel method may throw errors.
                 // we need do continue despite that
                 try {
-                    invokable.cancel();
+                    Future<Void> cancellationFuture = invokable.cancel();
+                    // Wait for any active actions to complete (e.g. timers, mailbox actions)
+                    // Before that, interrupt to notify them about cancellation
+                    if (invokable.shouldInterruptOnCancel()) {
+                        executer.interrupt();
+                    }
+                    try {
+                        cancellationFuture.get(taskCancellationTimeout, TimeUnit.MILLISECONDS);
+                    } catch (ExecutionException | TimeoutException | InterruptedException e) {
+                        logger.debug("Error while waiting the task to terminate {}.", taskName, e);
+                    }
                 } catch (Throwable t) {
                     ExceptionUtils.rethrowIfFatalError(t);
                     logger.error("Error while canceling the task {}.", taskName, t);
@@ -1583,15 +1605,8 @@ public class Task
                 // in order to unblock async Threads, which produce/consume the
                 // intermediate streams outside of the main Task Thread (like
                 // the Kafka consumer).
-                //
-                // Don't do this before cancelling the invokable. Otherwise we
-                // will get misleading errors in the logs.
                 networkResourcesCloser.run();
 
-                // send the initial interruption signal, if requested
-                if (invokable.shouldInterruptOnCancel()) {
-                    executer.interrupt();
-                }
             } catch (Throwable t) {
                 ExceptionUtils.rethrowIfFatalError(t);
                 logger.error("Error in the task canceler for task {}.", taskName, t);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/TestingAbstractInvokables.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/TestingAbstractInvokables.java
index db8de8a..44b6084 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/TestingAbstractInvokables.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/TestingAbstractInvokables.java
@@ -26,6 +26,7 @@ import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.types.IntValue;
 
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Future;
 
 /** {@link AbstractInvokable} for testing purposes. */
 public class TestingAbstractInvokables {
@@ -106,8 +107,9 @@ public class TestingAbstractInvokables {
         }
 
         @Override
-        public void cancel() {
+        public Future<Void> cancel() {
             gotCanceledFuture.complete(true);
+            return CompletableFuture.completedFuture(null);
         }
 
         public static void resetGotCanceledFuture() {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/coordination/CoordinatorEventsExactlyOnceITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/coordination/CoordinatorEventsExactlyOnceITCase.java
index 2337115..196b546 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/coordination/CoordinatorEventsExactlyOnceITCase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/coordination/CoordinatorEventsExactlyOnceITCase.java
@@ -586,8 +586,9 @@ public class CoordinatorEventsExactlyOnceITCase extends TestLogger {
         }
 
         @Override
-        public void cancel() throws Exception {
+        public Future<Void> cancel() throws Exception {
             running = false;
+            return CompletableFuture.completedFuture(null);
         }
 
         @Override
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorOperatorEventHandlingTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorOperatorEventHandlingTest.java
index f84d513..f5c1029 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorOperatorEventHandlingTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorOperatorEventHandlingTest.java
@@ -197,7 +197,7 @@ public class TaskExecutorOperatorEventHandlingTest extends TestLogger {
         }
 
         @Override
-        public void invoke() throws InterruptedException {
+        public void doInvoke() throws InterruptedException {
             waitUntilCancelled();
         }
 
@@ -216,7 +216,7 @@ public class TaskExecutorOperatorEventHandlingTest extends TestLogger {
         }
 
         @Override
-        public void invoke() throws Exception {
+        public void doInvoke() throws Exception {
             getEnvironment()
                     .getOperatorCoordinatorEventGateway()
                     .sendOperatorEventToCoordinator(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java
index d812508..d7eecb0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java
@@ -63,6 +63,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Future;
 import java.util.concurrent.LinkedBlockingDeque;
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -1144,8 +1145,9 @@ public class TaskTest extends TestLogger {
         public void invoke() {}
 
         @Override
-        public void cancel() {
+        public Future<Void> cancel() {
             fail("This should not be called");
+            return null;
         }
     }
 
@@ -1191,7 +1193,9 @@ public class TaskTest extends TestLogger {
         }
 
         @Override
-        public void cancel() {}
+        public Future<Void> cancel() {
+            return CompletableFuture.completedFuture(null);
+        }
     }
 
     private static final class InvokableBlockingWithTrigger extends AbstractInvokable {
@@ -1313,11 +1317,12 @@ public class TaskTest extends TestLogger {
         }
 
         @Override
-        public void cancel() throws Exception {
+        public Future<Void> cancel() throws Exception {
             synchronized (this) {
                 triggerLatch.trigger();
                 wait();
             }
+            return CompletableFuture.completedFuture(null);
         }
     }
 
@@ -1339,11 +1344,12 @@ public class TaskTest extends TestLogger {
         }
 
         @Override
-        public void cancel() {
+        public Future<Void> cancel() {
             synchronized (lock) {
                 // do nothing but a placeholder
                 triggerLatch.trigger();
             }
+            return CompletableFuture.completedFuture(null);
         }
     }
 
@@ -1367,7 +1373,9 @@ public class TaskTest extends TestLogger {
         }
 
         @Override
-        public void cancel() {}
+        public Future<Void> cancel() {
+            return CompletableFuture.completedFuture(null);
+        }
     }
 
     // ------------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/testtasks/OnceBlockingNoOpInvokable.java b/flink-runtime/src/test/java/org/apache/flink/runtime/testtasks/OnceBlockingNoOpInvokable.java
index c053104..fee64ee 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/testtasks/OnceBlockingNoOpInvokable.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/testtasks/OnceBlockingNoOpInvokable.java
@@ -21,7 +21,9 @@ package org.apache.flink.runtime.testtasks;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Future;
 import java.util.concurrent.atomic.AtomicInteger;
 
 /**
@@ -62,11 +64,12 @@ public class OnceBlockingNoOpInvokable extends AbstractInvokable {
     }
 
     @Override
-    public void cancel() throws Exception {
+    public Future<Void> cancel() throws Exception {
         synchronized (lock) {
             running = false;
             lock.notifyAll();
         }
+        return CompletableFuture.completedFuture(null);
     }
 
     public static void waitUntilOpsAreRunning() throws InterruptedException {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/CancelableInvokable.java b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/CancelableInvokable.java
index bca5fc1..b4ac0a1 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/CancelableInvokable.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/CancelableInvokable.java
@@ -21,6 +21,9 @@ package org.apache.flink.runtime.testutils;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Future;
+
 /**
  * An {@link AbstractInvokable} that blocks at some point until cancelled.
  *
@@ -31,13 +34,29 @@ public abstract class CancelableInvokable extends AbstractInvokable {
 
     private volatile boolean canceled;
 
+    private final CompletableFuture<Void> terminationFuture = new CompletableFuture<>();
+
     protected CancelableInvokable(Environment environment) {
         super(environment);
     }
 
     @Override
-    public void cancel() {
+    public void invoke() throws Exception {
+        try {
+            doInvoke();
+            terminationFuture.complete(null);
+        } catch (Throwable t) { // incl. Errors, so that the cancel() future always completes
+            terminationFuture.completeExceptionally(t);
+            throw t;
+        }
+    }
+
+    protected abstract void doInvoke() throws Exception;
+
+    @Override
+    public Future<Void> cancel() {
         canceled = true;
+        return terminationFuture;
     }
 
     protected void waitUntilCancelled() throws InterruptedException {
diff --git a/flink-runtime/src/test/scala/org/apache/flink/runtime/jobmanager/Tasks.scala b/flink-runtime/src/test/scala/org/apache/flink/runtime/jobmanager/Tasks.scala
index 84f166b..56f0c5e 100644
--- a/flink-runtime/src/test/scala/org/apache/flink/runtime/jobmanager/Tasks.scala
+++ b/flink-runtime/src/test/scala/org/apache/flink/runtime/jobmanager/Tasks.scala
@@ -34,7 +34,7 @@ object Tasks {
         getEnvironment.getInputGate(0),
         classOf[IntValue],
         getEnvironment.getTaskManagerInfo.getTmpDirectories)
-      
+
       val writer = new RecordWriterBuilder[IntValue]().build(
         getEnvironment.getWriter(0))
 
@@ -77,7 +77,7 @@ object Tasks {
         getEnvironment.getInputGate(0),
         classOf[IntValue],
         getEnvironment.getTaskManagerInfo.getTmpDirectories)
-      
+
       val reader2 = new RecordReader[IntValue](
         getEnvironment.getInputGate(1),
         classOf[IntValue],
@@ -98,12 +98,12 @@ object Tasks {
         env.getInputGate(0),
         classOf[IntValue],
         getEnvironment.getTaskManagerInfo.getTmpDirectories)
-      
+
       val reader2 = new RecordReader[IntValue](
         env.getInputGate(1),
         classOf[IntValue],
         getEnvironment.getTaskManagerInfo.getTmpDirectories)
-      
+
       val reader3 = new RecordReader[IntValue](
         env.getInputGate(2),
         classOf[IntValue],
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index fedd2cb..9998743 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -250,6 +250,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>> extends Ab
 
     private long latestAsyncCheckpointStartDelayNanos;
 
+    private final CompletableFuture<Void> terminationFuture = new CompletableFuture<>();
+
     // ------------------------------------------------------------------------
 
     /**
@@ -759,7 +761,10 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>> extends Ab
 
         suppressedException = runAndSuppressThrowable(mailboxProcessor::close, suppressedException);
 
-        if (suppressedException != null) {
+        if (suppressedException == null) {
+            terminationFuture.complete(null);
+        } else {
+            terminationFuture.completeExceptionally(suppressedException);
             throw suppressedException;
         }
     }
@@ -769,7 +774,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>> extends Ab
     }
 
     @Override
-    public final void cancel() throws Exception {
+    public final Future<Void> cancel() throws Exception {
         isRunning = false;
         canceled = true;
 
@@ -793,6 +798,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>> extends Ab
                                 }
                             });
         }
+        return terminationFuture;
     }
 
     public MailboxExecutorFactory getMailboxExecutorFactory() {
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
index 9c46bfe..835151c 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
@@ -41,6 +41,7 @@ import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
 import org.apache.flink.runtime.concurrent.FutureUtils;
 import org.apache.flink.runtime.concurrent.TestingUncaughtExceptionHandler;
+import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
 import org.apache.flink.runtime.execution.CancelTaskException;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.execution.ExecutionState;
@@ -59,6 +60,7 @@ import org.apache.flink.runtime.operators.testutils.ExpectedTestException;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
+import org.apache.flink.runtime.shuffle.PartitionDescriptorBuilder;
 import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
@@ -90,6 +92,7 @@ import org.apache.flink.runtime.taskmanager.TaskExecutionState;
 import org.apache.flink.runtime.taskmanager.TaskManagerActions;
 import org.apache.flink.runtime.taskmanager.TestTaskBuilder;
 import org.apache.flink.runtime.util.FatalExitExceptionHandler;
+import org.apache.flink.runtime.util.NettyShuffleDescriptorBuilder;
 import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
@@ -139,6 +142,7 @@ import java.io.Closeable;
 import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.io.StreamCorruptedException;
+import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
@@ -157,6 +161,7 @@ import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Consumer;
 
 import static java.util.Arrays.asList;
+import static java.util.Collections.singletonList;
 import static org.apache.flink.api.common.typeinfo.BasicTypeInfo.STRING_TYPE_INFO;
 import static org.apache.flink.configuration.StateBackendOptions.STATE_BACKEND;
 import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.UNKNOWN_TASK_CHECKPOINT_NOTIFICATION_FAILURE;
@@ -193,6 +198,31 @@ public class StreamTaskTest extends TestLogger {
     @Rule public final Timeout timeoutPerTest = Timeout.seconds(30);
 
     @Test
+    public void testCancellationWaitsForActiveTimers() throws Exception {
+        StreamTaskWithBlockingTimer.reset();
+        ResultPartitionDeploymentDescriptor descriptor =
+                new ResultPartitionDeploymentDescriptor(
+                        PartitionDescriptorBuilder.newBuilder().build(),
+                        NettyShuffleDescriptorBuilder.newBuilder().buildLocal(),
+                        1,
+                        false);
+        Task task =
+                new TestTaskBuilder(new NettyShuffleEnvironmentBuilder().build())
+                        .setInvokable(StreamTaskWithBlockingTimer.class)
+                        .setResultPartitions(singletonList(descriptor))
+                        .build();
+        task.startTaskThread();
+
+        StreamTaskWithBlockingTimer.timerStarted.join();
+        task.cancelExecution();
+
+        task.getTerminationFuture().join();
+        // explicitly check for exceptions as they are ignored after cancellation
+        StreamTaskWithBlockingTimer.timerFinished.join();
+        checkState(task.getExecutionState() == ExecutionState.CANCELED);
+    }
+
+    @Test
     public void testSavepointSuspendCompleted() throws Exception {
         testSyncSavepointWithEndInput(
                 StreamTask::notifyCheckpointCompleteAsync, CheckpointType.SAVEPOINT_SUSPEND, false);
@@ -2605,4 +2635,97 @@ public class StreamTaskTest extends TestLogger {
         @Override
         public void processElement(StreamRecord<T> element) throws Exception {}
     }
+
+    /**
+     * A {@link StreamTask} that registers a single timer which waits for cancellation and then
+     * emits some data. The assumption is that output remains available until the future returned
+     * from {@link AbstractInvokable#cancel()} is completed. Public access to allow reflection in
+     * {@link Task}.
+     */
+    public static class StreamTaskWithBlockingTimer extends StreamTask {
+        static volatile CompletableFuture<Void> timerStarted;
+        static volatile CompletableFuture<Void> timerFinished;
+        static volatile CompletableFuture<Void> invokableCancelled;
+
+        public static void reset() {
+            timerStarted = new CompletableFuture<>();
+            timerFinished = new CompletableFuture<>();
+            invokableCancelled = new CompletableFuture<>();
+        }
+
+        // public access to allow reflection in Task
+        public StreamTaskWithBlockingTimer(Environment env) throws Exception {
+            super(env);
+            super.inputProcessor = getInputProcessor();
+            getProcessingTimeServiceFactory()
+                    .createProcessingTimeService(mainMailboxExecutor)
+                    .registerTimer(0, unused -> onProcessingTime());
+        }
+
+        @Override
+        protected void cancelTask() throws Exception {
+            super.cancelTask();
+            invokableCancelled.complete(null);
+        }
+
+        private void onProcessingTime() {
+            try {
+                timerStarted.complete(null);
+                waitForCancellation();
+                emit();
+                timerFinished.complete(null);
+            } catch (Throwable e) { // assertion is Error
+                timerFinished.completeExceptionally(e);
+            }
+        }
+
+        private void waitForCancellation() {
+            invokableCancelled.join();
+            // allow network resources to be closed mistakenly
+            for (int i = 0; i < 10; i++) {
+                try {
+                    Thread.sleep(50);
+                } catch (InterruptedException e) {
+                    // ignore: can be interrupted by TaskCanceller/Interrupter; restoring the
+                    // flag here would make the remaining sleeps return immediately
+                }
+            }
+        }
+
+        private void emit() throws IOException {
+            checkState(getEnvironment().getAllWriters().length > 0);
+            for (ResultPartitionWriter writer : getEnvironment().getAllWriters()) {
+                assertFalse(writer.isReleased());
+                assertFalse(writer.isFinished());
+                writer.emitRecord(ByteBuffer.allocate(10), 0);
+            }
+        }
+
+        @Override
+        protected void init() {}
+
+        private static StreamInputProcessor getInputProcessor() {
+            return new StreamInputProcessor() {
+
+                @Override
+                public InputStatus processInput() {
+                    return InputStatus.NOTHING_AVAILABLE;
+                }
+
+                @Override
+                public CompletableFuture<Void> prepareSnapshot(
+                        ChannelStateWriter channelStateWriter, long checkpointId) {
+                    return CompletableFuture.completedFuture(null);
+                }
+
+                @Override
+                public CompletableFuture<?> getAvailableFuture() {
+                    return new CompletableFuture<>();
+                }
+
+                @Override
+                public void close() {}
+            };
+        }
+    }
 }