You are viewing a plain-text version of this content. The link to the canonical version is not preserved in this plain-text rendering.
Posted to commits@flink.apache.org by tr...@apache.org on 2019/08/18 14:21:43 UTC

[flink] branch release-1.9 updated (c53fada -> 926e818)

This is an automated email from the ASF dual-hosted git repository.

trohrmann pushed a change to branch release-1.9
in repository https://gitbox.apache.org/repos/asf/flink.git.


    from c53fada  [FLINK-13760][hive] Fix hardcode Scala version dependency in hive connector
     new 65e6dbb  [FLINK-11630] Triggers the termination of all running Tasks when shutting down TaskExecutor
     new 926e818  [hotfix] Introduce TaskDeploymentDescriptorBuilder in tests

The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../flink/runtime/taskexecutor/TaskExecutor.java   |  65 +++++-
 .../TaskDeploymentDescriptorBuilder.java           | 157 ++++++++++++++
 .../runtime/taskexecutor/TaskExecutorTest.java     | 236 ++++++++++++++-------
 3 files changed, 375 insertions(+), 83 deletions(-)
 create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorBuilder.java


[flink] 02/02: [hotfix] Introduce TaskDeploymentDescriptorBuilder in tests

Posted by tr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

trohrmann pushed a commit to branch release-1.9
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 926e818e144ca15c19264ba7489e99db9d8aa225
Author: Andrey Zagrebin <az...@gmail.com>
AuthorDate: Fri Aug 16 10:07:41 2019 +0200

    [hotfix] Introduce TaskDeploymentDescriptorBuilder in tests
---
 .../TaskDeploymentDescriptorBuilder.java           | 157 +++++++++++++++++++++
 .../runtime/taskexecutor/TaskExecutorTest.java     | 133 ++---------------
 2 files changed, 170 insertions(+), 120 deletions(-)

diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorBuilder.java
new file mode 100644
index 0000000..4b9a407
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorBuilder.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.deployment;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore;
+import org.apache.flink.runtime.clusterframework.types.AllocationID;
+import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.MaybeOffloaded;
+import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloaded;
+import org.apache.flink.runtime.executiongraph.DummyJobInformation;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.JobInformation;
+import org.apache.flink.runtime.executiongraph.TaskInformation;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.util.SerializedValue;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Collections;
+
+/**
+ * Builder for {@link TaskDeploymentDescriptor}.
+ */
+public class TaskDeploymentDescriptorBuilder {
+	private JobID jobId;
+	private MaybeOffloaded<JobInformation> serializedJobInformation;
+	private MaybeOffloaded<TaskInformation> serializedTaskInformation;
+	private ExecutionAttemptID executionId;
+	private AllocationID allocationId;
+	private int subtaskIndex;
+	private int attemptNumber;
+	private Collection<ResultPartitionDeploymentDescriptor> producedPartitions;
+	private Collection<InputGateDeploymentDescriptor> inputGates;
+	private int targetSlotNumber;
+
+	@Nullable
+	private JobManagerTaskRestore taskRestore;
+
+	private TaskDeploymentDescriptorBuilder(JobID jobId, String invokableClassName) throws IOException {
+		TaskInformation taskInformation = new TaskInformation(
+			new JobVertexID(),
+			"test task",
+			1,
+			1,
+			invokableClassName,
+			new Configuration());
+
+		this.jobId = jobId;
+		this.serializedJobInformation =
+			new NonOffloaded<>(new SerializedValue<>(new DummyJobInformation(jobId, "DummyJob")));
+		this.serializedTaskInformation = new NonOffloaded<>(new SerializedValue<>(taskInformation));
+		this.executionId = new ExecutionAttemptID();
+		this.allocationId = new AllocationID();
+		this.subtaskIndex = 0;
+		this.attemptNumber = 0;
+		this.producedPartitions = Collections.emptyList();
+		this.inputGates = Collections.emptyList();
+		this.targetSlotNumber = 0;
+		this.taskRestore = null;
+	}
+
+	public TaskDeploymentDescriptorBuilder setSerializedJobInformation(
+			MaybeOffloaded<JobInformation> serializedJobInformation) {
+		this.serializedJobInformation = serializedJobInformation;
+		return this;
+	}
+
+	public TaskDeploymentDescriptorBuilder setSerializedTaskInformation(
+			MaybeOffloaded<TaskInformation> serializedTaskInformation) {
+		this.serializedTaskInformation = serializedTaskInformation;
+		return this;
+	}
+
+	public TaskDeploymentDescriptorBuilder setJobId(JobID jobId) {
+		this.jobId = jobId;
+		return this;
+	}
+
+	public TaskDeploymentDescriptorBuilder setExecutionId(ExecutionAttemptID executionId) {
+		this.executionId = executionId;
+		return this;
+	}
+
+	public TaskDeploymentDescriptorBuilder setAllocationId(AllocationID allocationId) {
+		this.allocationId = allocationId;
+		return this;
+	}
+
+	public TaskDeploymentDescriptorBuilder setSubtaskIndex(int subtaskIndex) {
+		this.subtaskIndex = subtaskIndex;
+		return this;
+	}
+
+	public TaskDeploymentDescriptorBuilder setAttemptNumber(int attemptNumber) {
+		this.attemptNumber = attemptNumber;
+		return this;
+	}
+
+	public TaskDeploymentDescriptorBuilder setProducedPartitions(
+			Collection<ResultPartitionDeploymentDescriptor> producedPartitions) {
+		this.producedPartitions = producedPartitions;
+		return this;
+	}
+
+	public TaskDeploymentDescriptorBuilder setInputGates(Collection<InputGateDeploymentDescriptor> inputGates) {
+		this.inputGates = inputGates;
+		return this;
+	}
+
+	public TaskDeploymentDescriptorBuilder setTargetSlotNumber(int targetSlotNumber) {
+		this.targetSlotNumber = targetSlotNumber;
+		return this;
+	}
+
+	public TaskDeploymentDescriptorBuilder setTaskRestore(@Nullable JobManagerTaskRestore taskRestore) {
+		this.taskRestore = taskRestore;
+		return this;
+	}
+
+	public TaskDeploymentDescriptor build() {
+		return new TaskDeploymentDescriptor(
+			jobId,
+			serializedJobInformation,
+			serializedTaskInformation,
+			executionId,
+			allocationId,
+			subtaskIndex,
+			attemptNumber,
+			targetSlotNumber,
+			taskRestore,
+			producedPartitions,
+			inputGates);
+	}
+
+	public static TaskDeploymentDescriptorBuilder newBuilder(JobID jobId, Class<?> invokableClass) throws IOException {
+		return new TaskDeploymentDescriptorBuilder(jobId, invokableClass.getName());
+	}
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java
index bc6b296..be9fb53 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java
@@ -18,7 +18,6 @@
 
 package org.apache.flink.runtime.taskexecutor;
 
-import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.time.Time;
 import org.apache.flink.api.java.tuple.Tuple3;
@@ -40,15 +39,11 @@ import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.concurrent.FutureUtils;
 import org.apache.flink.runtime.concurrent.ScheduledExecutor;
 import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor;
-import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloaded;
+import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorBuilder;
 import org.apache.flink.runtime.entrypoint.ClusterInformation;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.execution.librarycache.ContextClassLoaderLibraryCacheManager;
 import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager;
-import org.apache.flink.runtime.executiongraph.DummyJobInformation;
-import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
-import org.apache.flink.runtime.executiongraph.JobInformation;
-import org.apache.flink.runtime.executiongraph.TaskInformation;
 import org.apache.flink.runtime.heartbeat.HeartbeatListener;
 import org.apache.flink.runtime.heartbeat.HeartbeatManager;
 import org.apache.flink.runtime.heartbeat.HeartbeatManagerImpl;
@@ -93,7 +88,6 @@ import org.apache.flink.runtime.taskexecutor.TaskSubmissionTestEnvironment.Build
 import org.apache.flink.runtime.taskexecutor.exceptions.RegistrationTimeoutException;
 import org.apache.flink.runtime.taskexecutor.exceptions.TaskManagerException;
 import org.apache.flink.runtime.taskexecutor.partition.PartitionTable;
-import org.apache.flink.runtime.taskexecutor.slot.SlotActions;
 import org.apache.flink.runtime.taskexecutor.slot.SlotNotFoundException;
 import org.apache.flink.runtime.taskexecutor.slot.SlotOffer;
 import org.apache.flink.runtime.taskexecutor.slot.TaskSlotTable;
@@ -113,7 +107,6 @@ import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.ExecutorUtils;
 import org.apache.flink.util.FlinkException;
 import org.apache.flink.util.NetUtils;
-import org.apache.flink.util.SerializedValue;
 import org.apache.flink.util.TestLogger;
 import org.apache.flink.util.function.FunctionUtils;
 
@@ -146,7 +139,6 @@ import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutionException;
-import java.util.concurrent.Future;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
@@ -658,37 +650,10 @@ public class TaskExecutorTest extends TestLogger {
 		final JobMasterId jobMasterId = JobMasterId.generate();
 		final JobVertexID jobVertexId = new JobVertexID();
 
-		JobInformation jobInformation = new JobInformation(
-				jobId,
-				testName.getMethodName(),
-				new SerializedValue<>(new ExecutionConfig()),
-				new Configuration(),
-				Collections.emptyList(),
-				Collections.emptyList());
-
-		TaskInformation taskInformation = new TaskInformation(
-				jobVertexId,
-				"test task",
-				1,
-				1,
-				TestInvokable.class.getName(),
-				new Configuration());
-
-		SerializedValue<JobInformation> serializedJobInformation = new SerializedValue<>(jobInformation);
-		SerializedValue<TaskInformation> serializedJobVertexInformation = new SerializedValue<>(taskInformation);
-
-		final TaskDeploymentDescriptor tdd = new TaskDeploymentDescriptor(
-				jobId,
-				new TaskDeploymentDescriptor.NonOffloaded<>(serializedJobInformation),
-				new TaskDeploymentDescriptor.NonOffloaded<>(serializedJobVertexInformation),
-				new ExecutionAttemptID(),
-				allocationId,
-				0,
-				0,
-				0,
-				null,
-				Collections.emptyList(),
-				Collections.emptyList());
+		final TaskDeploymentDescriptor tdd = TaskDeploymentDescriptorBuilder
+			.newBuilder(jobId, TestInvokable.class)
+			.setAllocationId(allocationId)
+			.build();
 
 		final LibraryCacheManager libraryCacheManager = mock(LibraryCacheManager.class);
 		when(libraryCacheManager.getClassLoader(any(JobID.class))).thenReturn(ClassLoader.getSystemClassLoader());
@@ -772,8 +737,10 @@ public class TaskExecutorTest extends TestLogger {
 	public void testTaskInterruptionAndTerminationOnShutdown() throws Exception {
 		final JobMasterId jobMasterId = JobMasterId.generate();
 		final AllocationID allocationId = new AllocationID();
-		final TaskDeploymentDescriptor taskDeploymentDescriptor =
-			createTaskDeploymentDescriptor(TestInterruptableInvokable.class, allocationId);
+		final TaskDeploymentDescriptor taskDeploymentDescriptor = TaskDeploymentDescriptorBuilder
+			.newBuilder(jobId, TestInterruptableInvokable.class)
+			.setAllocationId(allocationId)
+			.build();
 
 		final JobManagerTable jobManagerTable = createJobManagerTableWithOneJob(jobMasterId);
 		final TaskExecutor taskExecutor = createTaskExecutorWithJobManagerTable(jobManagerTable);
@@ -878,31 +845,6 @@ public class TaskExecutorTest extends TestLogger {
 		return jobManagerTable;
 	}
 
-	private TaskDeploymentDescriptor createTaskDeploymentDescriptor(
-		final Class<? extends AbstractInvokable> invokableClass,
-		final AllocationID allocationId) throws IOException {
-		final TaskInformation taskInformation = new TaskInformation(
-			new JobVertexID(),
-			"test task",
-			1,
-			1,
-			invokableClass.getName(),
-			new Configuration());
-
-		return new TaskDeploymentDescriptor(
-			jobId,
-			new NonOffloaded<>(new SerializedValue<>(new DummyJobInformation(jobId, testName.getMethodName()))),
-			new NonOffloaded<>(new SerializedValue<>(taskInformation)),
-			new ExecutionAttemptID(),
-			allocationId,
-			0,
-			0,
-			0,
-			null,
-			Collections.emptyList(),
-			Collections.emptyList());
-	}
-
 	/**
 	 * Tests that a TaskManager detects a job leader for which it has reserved slots. Upon detecting
 	 * the job leader, it will offer all reserved slots to the JobManager.
@@ -1137,39 +1079,10 @@ public class TaskExecutorTest extends TestLogger {
 			taskSlotTable.allocateSlot(0, jobId, allocationId1, Time.milliseconds(10000L));
 			taskSlotTable.allocateSlot(1, jobId, allocationId2, Time.milliseconds(10000L));
 
-			final JobVertexID jobVertexId = new JobVertexID();
-
-			JobInformation jobInformation = new JobInformation(
-				jobId,
-				testName.getMethodName(),
-				new SerializedValue<>(new ExecutionConfig()),
-				new Configuration(),
-				Collections.emptyList(),
-				Collections.emptyList());
-
-			TaskInformation taskInformation = new TaskInformation(
-				jobVertexId,
-				"test task",
-				1,
-				1,
-				NoOpInvokable.class.getName(),
-				new Configuration());
-
-			SerializedValue<JobInformation> serializedJobInformation = new SerializedValue<>(jobInformation);
-			SerializedValue<TaskInformation> serializedJobVertexInformation = new SerializedValue<>(taskInformation);
-
-			final TaskDeploymentDescriptor tdd = new TaskDeploymentDescriptor(
-				jobId,
-				new NonOffloaded<>(serializedJobInformation),
-				new NonOffloaded<>(serializedJobVertexInformation),
-				new ExecutionAttemptID(),
-				allocationId1,
-				0,
-				0,
-				0,
-				null,
-				Collections.emptyList(),
-				Collections.emptyList());
+			final TaskDeploymentDescriptor tdd = TaskDeploymentDescriptorBuilder
+				.newBuilder(jobId, NoOpInvokable.class)
+				.setAllocationId(allocationId1)
+				.build();
 
 			// we have to add the job after the TaskExecutor, because otherwise the service has not
 			// been properly started. This will also offer the slots to the job master
@@ -2213,24 +2126,4 @@ public class TaskExecutorTest extends TestLogger {
 			}
 		}
 	}
-
-	/**
-	 * {@link TaskSlotTable} which completes the given future when it is started.
-	 */
-	private static class TaskSlotTableWithStartFuture extends TaskSlotTable {
-		private final CompletableFuture<Void> taskSlotTableStarted;
-
-		private TaskSlotTableWithStartFuture(
-				CompletableFuture<Void> taskSlotTableStarted,
-				TimerService<AllocationID> timerService) {
-			super(Collections.singletonList(ResourceProfile.UNKNOWN), timerService);
-			this.taskSlotTableStarted = taskSlotTableStarted;
-		}
-
-		@Override
-		public void start(SlotActions initialSlotActions) {
-			super.start(initialSlotActions);
-			taskSlotTableStarted.complete(null);
-		}
-	}
 }


[flink] 01/02: [FLINK-11630] Triggers the termination of all running Tasks when shutting down TaskExecutor

Posted by tr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

trohrmann pushed a commit to branch release-1.9
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 65e6dbb5dc6d3fb021536363bc9da684cf1c306c
Author: blueszheng <ki...@163.com>
AuthorDate: Wed Feb 20 02:08:30 2019 +0800

    [FLINK-11630] Triggers the termination of all running Tasks when shutting down TaskExecutor
    
    This closes #9072.
    This closes #7757.
---
 .../flink/runtime/taskexecutor/TaskExecutor.java   |  65 ++++++-
 .../runtime/taskexecutor/TaskExecutorTest.java     | 209 ++++++++++++++++++++-
 2 files changed, 258 insertions(+), 16 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java
index 9b295dd..621ef68 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java
@@ -132,6 +132,8 @@ import java.util.Objects;
 import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.TimeoutException;
 import java.util.function.BiConsumer;
 import java.util.stream.Collectors;
@@ -182,6 +184,8 @@ public class TaskExecutor extends RpcEndpoint implements TaskExecutorGateway {
 	/** The kvState registration service in the task manager. */
 	private final KvStateService kvStateService;
 
+	private final TaskCompletionTracker taskCompletionTracker;
+
 	// --------- job manager connections -----------
 
 	private final Map<ResourceID, JobManagerConnection> jobManagerConnections;
@@ -273,6 +277,7 @@ public class TaskExecutor extends RpcEndpoint implements TaskExecutorGateway {
 		this.currentRegistrationTimeoutId = null;
 
 		this.stackTraceSampleService = new StackTraceSampleService(rpcService.getScheduledExecutor());
+		this.taskCompletionTracker = new TaskCompletionTracker();
 	}
 
 	@Override
@@ -333,31 +338,46 @@ public class TaskExecutor extends RpcEndpoint implements TaskExecutorGateway {
 	public CompletableFuture<Void> onStop() {
 		log.info("Stopping TaskExecutor {}.", getAddress());
 
-		Throwable throwable = null;
+		Throwable jobManagerDisconnectThrowable = null;
 
 		if (resourceManagerConnection != null) {
 			resourceManagerConnection.close();
 		}
 
+		FlinkException cause = new FlinkException("The TaskExecutor is shutting down.");
 		for (JobManagerConnection jobManagerConnection : jobManagerConnections.values()) {
 			try {
-				disassociateFromJobManager(jobManagerConnection, new FlinkException("The TaskExecutor is shutting down."));
+				disassociateFromJobManager(jobManagerConnection, cause);
 			} catch (Throwable t) {
-				throwable = ExceptionUtils.firstOrSuppressed(t, throwable);
+				jobManagerDisconnectThrowable = ExceptionUtils.firstOrSuppressed(t, jobManagerDisconnectThrowable);
 			}
 		}
 
-		try {
-			stopTaskExecutorServices();
-		} catch (Exception e) {
-			throwable = ExceptionUtils.firstOrSuppressed(e, throwable);
+		final Throwable throwableBeforeTasksCompletion = jobManagerDisconnectThrowable;
+
+		return FutureUtils
+			.runAfterwards(
+				taskCompletionTracker.failIncompleteTasksAndGetTerminationFuture(),
+				this::stopTaskExecutorServices)
+  		    .handle((ignored, throwable) -> {
+  		    	handleOnStopException(throwableBeforeTasksCompletion, throwable);
+  		    	return null;
+			});
+	}
+
+	private void handleOnStopException(Throwable throwableBeforeTasksCompletion, Throwable throwableAfterTasksCompletion) {
+		final Throwable throwable;
+
+		if (throwableBeforeTasksCompletion != null) {
+			throwable = ExceptionUtils.firstOrSuppressed(throwableBeforeTasksCompletion, throwableAfterTasksCompletion);
+		} else {
+			throwable = throwableAfterTasksCompletion;
 		}
 
 		if (throwable != null) {
-			return FutureUtils.completedExceptionally(new FlinkException("Error while shutting the TaskExecutor down.", throwable));
+			throw new CompletionException(new FlinkException("Error while shutting the TaskExecutor down.", throwable));
 		} else {
 			log.info("Stopped TaskExecutor {}.", getAddress());
-			return CompletableFuture.completedFuture(null);
 		}
 	}
 
@@ -596,6 +616,7 @@ public class TaskExecutor extends RpcEndpoint implements TaskExecutorGateway {
 
 			if (taskAdded) {
 				task.startTaskThread();
+				taskCompletionTracker.trackTaskCompletion(task);
 
 				setupResultPartitionBookkeeping(tdd, task.getTerminationFuture());
 				return CompletableFuture.completedFuture(Acknowledge.get());
@@ -1826,4 +1847,30 @@ public class TaskExecutor extends RpcEndpoint implements TaskExecutorGateway {
 			return taskSlotTable.createSlotReport(getResourceID());
 		}
 	}
+
+	private static class TaskCompletionTracker {
+		private final Map<ExecutionAttemptID, Task> incompleteTasks;
+
+		private TaskCompletionTracker() {
+			incompleteTasks = new ConcurrentHashMap<>(8);
+		}
+
+		void trackTaskCompletion(Task task) {
+			incompleteTasks.put(task.getExecutionId(), task);
+			task.getTerminationFuture().thenRun(() -> incompleteTasks.remove(task.getExecutionId()));
+		}
+
+		CompletableFuture<Void> failIncompleteTasksAndGetTerminationFuture() {
+			FlinkException cause = new FlinkException("The TaskExecutor is shutting down.");
+			return FutureUtils.waitForAll(
+				incompleteTasks
+					.values()
+					.stream()
+					.map(task -> {
+						task.failExternally(cause);
+						return task.getTerminationFuture();
+					})
+					.collect(Collectors.toList()));
+		}
+	}
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java
index c7b8e12..bc6b296 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java
@@ -40,9 +40,12 @@ import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.concurrent.FutureUtils;
 import org.apache.flink.runtime.concurrent.ScheduledExecutor;
 import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor;
+import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloaded;
 import org.apache.flink.runtime.entrypoint.ClusterInformation;
 import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.execution.librarycache.ContextClassLoaderLibraryCacheManager;
 import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager;
+import org.apache.flink.runtime.executiongraph.DummyJobInformation;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.JobInformation;
 import org.apache.flink.runtime.executiongraph.TaskInformation;
@@ -77,6 +80,7 @@ import org.apache.flink.runtime.messages.Acknowledge;
 import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.registration.RegistrationResponse;
+import org.apache.flink.runtime.registration.RegistrationResponse.Decline;
 import org.apache.flink.runtime.registration.RetryingRegistrationConfiguration;
 import org.apache.flink.runtime.resourcemanager.ResourceManagerGateway;
 import org.apache.flink.runtime.resourcemanager.ResourceManagerId;
@@ -85,9 +89,11 @@ import org.apache.flink.runtime.rpc.RpcService;
 import org.apache.flink.runtime.rpc.RpcUtils;
 import org.apache.flink.runtime.rpc.TestingRpcService;
 import org.apache.flink.runtime.state.TaskExecutorLocalStateStoresManager;
+import org.apache.flink.runtime.taskexecutor.TaskSubmissionTestEnvironment.Builder;
 import org.apache.flink.runtime.taskexecutor.exceptions.RegistrationTimeoutException;
 import org.apache.flink.runtime.taskexecutor.exceptions.TaskManagerException;
 import org.apache.flink.runtime.taskexecutor.partition.PartitionTable;
+import org.apache.flink.runtime.taskexecutor.slot.SlotActions;
 import org.apache.flink.runtime.taskexecutor.slot.SlotNotFoundException;
 import org.apache.flink.runtime.taskexecutor.slot.SlotOffer;
 import org.apache.flink.runtime.taskexecutor.slot.TaskSlotTable;
@@ -99,6 +105,7 @@ import org.apache.flink.runtime.taskmanager.Task;
 import org.apache.flink.runtime.taskmanager.TaskExecutionState;
 import org.apache.flink.runtime.taskmanager.TaskManagerActions;
 import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
+import org.apache.flink.runtime.taskmanager.TestCheckpointResponder;
 import org.apache.flink.runtime.testingUtils.TestingUtils;
 import org.apache.flink.runtime.testtasks.NoOpInvokable;
 import org.apache.flink.runtime.util.TestingFatalErrorHandler;
@@ -139,6 +146,7 @@ import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
@@ -760,6 +768,141 @@ public class TaskExecutorTest extends TestLogger {
 		}
 	}
 
+	@Test
+	public void testTaskInterruptionAndTerminationOnShutdown() throws Exception {
+		final JobMasterId jobMasterId = JobMasterId.generate();
+		final AllocationID allocationId = new AllocationID();
+		final TaskDeploymentDescriptor taskDeploymentDescriptor =
+			createTaskDeploymentDescriptor(TestInterruptableInvokable.class, allocationId);
+
+		final JobManagerTable jobManagerTable = createJobManagerTableWithOneJob(jobMasterId);
+		final TaskExecutor taskExecutor = createTaskExecutorWithJobManagerTable(jobManagerTable);
+
+		try {
+			taskExecutor.start();
+
+			final TaskExecutorGateway taskExecutorGateway = taskExecutor.getSelfGateway(TaskExecutorGateway.class);
+			final JobMasterGateway jobMasterGateway = jobManagerTable.get(jobId).getJobManagerGateway();
+			requestSlotFromTaskExecutor(taskExecutorGateway, jobMasterGateway, allocationId);
+
+			taskExecutorGateway.submitTask(taskDeploymentDescriptor, jobMasterId, timeout);
+
+			TestInterruptableInvokable.STARTED_FUTURE.get();
+		} finally {
+			taskExecutor.closeAsync();
+		}
+
+		// check task has been interrupted
+		TestInterruptableInvokable.INTERRUPTED_FUTURE.get();
+
+		// check task executor is waiting for the task completion and has not terminated yet
+		final CompletableFuture<Void> taskExecutorTerminationFuture = taskExecutor.getTerminationFuture();
+		assertThat(taskExecutorTerminationFuture.isDone(), is(false));
+
+		// check task executor has exited after task completion
+		TestInterruptableInvokable.DONE_FUTURE.complete(null);
+		taskExecutorTerminationFuture.get();
+	}
+
+	private void requestSlotFromTaskExecutor(
+			TaskExecutorGateway taskExecutorGateway,
+			JobMasterGateway jobMasterGateway,
+			AllocationID allocationId) throws ExecutionException, InterruptedException {
+		final CompletableFuture<Tuple3<ResourceID, InstanceID, SlotReport>> initialSlotReportFuture =
+			new CompletableFuture<>();
+		ResourceManagerId resourceManagerId = createAndRegisterResourceManager(initialSlotReportFuture);
+		initialSlotReportFuture.get();
+
+		taskExecutorGateway
+			.requestSlot(
+				new SlotID(ResourceID.generate(), 0),
+				jobId,
+				allocationId,
+				jobMasterGateway.getAddress(),
+				resourceManagerId,
+				timeout)
+			.get();
+
+		// now inform the task manager about the new job leader
+		jobManagerLeaderRetriever.notifyListener(
+			jobMasterGateway.getAddress(),
+			jobMasterGateway.getFencingToken().toUUID());
+	}
+
+	private ResourceManagerId createAndRegisterResourceManager(
+			CompletableFuture<Tuple3<ResourceID, InstanceID, SlotReport>> initialSlotReportFuture) {
+		final TestingResourceManagerGateway resourceManagerGateway = new TestingResourceManagerGateway();
+		resourceManagerGateway.setSendSlotReportFunction(resourceIDInstanceIDSlotReportTuple3 -> {
+			initialSlotReportFuture.complete(resourceIDInstanceIDSlotReportTuple3);
+			return CompletableFuture.completedFuture(Acknowledge.get());
+		});
+		rpc.registerGateway(resourceManagerGateway.getAddress(), resourceManagerGateway);
+
+		// tell the task manager about the rm leader
+		resourceManagerLeaderRetriever.notifyListener(
+			resourceManagerGateway.getAddress(),
+			resourceManagerGateway.getFencingToken().toUUID());
+
+		return resourceManagerGateway.getFencingToken();
+	}
+
+	private TaskExecutor createTaskExecutorWithJobManagerTable(JobManagerTable jobManagerTable) throws IOException {
+		final TaskExecutorLocalStateStoresManager localStateStoresManager = createTaskExecutorLocalStateStoresManager();
+		return createTaskExecutor(new TaskManagerServicesBuilder()
+			.setTaskSlotTable(new TaskSlotTable(Collections.singletonList(ResourceProfile.UNKNOWN), timerService))
+			.setJobManagerTable(jobManagerTable)
+			.setTaskStateManager(localStateStoresManager)
+			.build());
+	}
+
+	private JobManagerTable createJobManagerTableWithOneJob(JobMasterId jobMasterId) {
+		final TestingJobMasterGateway jobMasterGateway = new TestingJobMasterGatewayBuilder()
+			.setFencingTokenSupplier(() -> jobMasterId)
+			.setOfferSlotsFunction((resourceID, slotOffers) -> CompletableFuture.completedFuture(slotOffers))
+			.build();
+		rpc.registerGateway(jobMasterGateway.getAddress(), jobMasterGateway);
+
+		final JobManagerConnection jobManagerConnection = new JobManagerConnection(
+			jobId,
+			ResourceID.generate(),
+			jobMasterGateway,
+			new NoOpTaskManagerActions(),
+			new TestCheckpointResponder(),
+			new TestGlobalAggregateManager(),
+			ContextClassLoaderLibraryCacheManager.INSTANCE,
+			new NoOpResultPartitionConsumableNotifier(),
+			(j, i, r) -> CompletableFuture.completedFuture(null));
+
+		final JobManagerTable jobManagerTable = new JobManagerTable();
+		jobManagerTable.put(jobId, jobManagerConnection);
+		return jobManagerTable;
+	}
+
+	private TaskDeploymentDescriptor createTaskDeploymentDescriptor(
+		final Class<? extends AbstractInvokable> invokableClass,
+		final AllocationID allocationId) throws IOException {
+		final TaskInformation taskInformation = new TaskInformation(
+			new JobVertexID(),
+			"test task",
+			1,
+			1,
+			invokableClass.getName(),
+			new Configuration());
+
+		return new TaskDeploymentDescriptor(
+			jobId,
+			new NonOffloaded<>(new SerializedValue<>(new DummyJobInformation(jobId, testName.getMethodName()))),
+			new NonOffloaded<>(new SerializedValue<>(taskInformation)),
+			new ExecutionAttemptID(),
+			allocationId,
+			0,
+			0,
+			0,
+			null,
+			Collections.emptyList(),
+			Collections.emptyList());
+	}
+
 	/**
 	 * Tests that a TaskManager detects a job leader for which it has reserved slots. Upon detecting
 	 * the job leader, it will offer all reserved slots to the JobManager.
@@ -1017,8 +1160,8 @@ public class TaskExecutorTest extends TestLogger {
 
 			final TaskDeploymentDescriptor tdd = new TaskDeploymentDescriptor(
 				jobId,
-				new TaskDeploymentDescriptor.NonOffloaded<>(serializedJobInformation),
-				new TaskDeploymentDescriptor.NonOffloaded<>(serializedJobVertexInformation),
+				new NonOffloaded<>(serializedJobInformation),
+				new NonOffloaded<>(serializedJobVertexInformation),
 				new ExecutionAttemptID(),
 				allocationId1,
 				0,
@@ -1284,7 +1427,7 @@ public class TaskExecutorTest extends TestLogger {
 							new ClusterInformation("localhost", 1234)));
 					} else {
 						secondRegistration.trigger();
-						return CompletableFuture.completedFuture(new RegistrationResponse.Decline("Only the first registration should succeed."));
+						return CompletableFuture.completedFuture(new Decline("Only the first registration should succeed."));
 					}
 				}
 			);
@@ -1642,7 +1785,7 @@ public class TaskExecutorTest extends TestLogger {
 		config.setString(ConfigConstants.TASK_MANAGER_LOG_PATH_KEY, "/i/dont/exist");
 
 		try (TaskSubmissionTestEnvironment env =
-			new TaskSubmissionTestEnvironment.Builder(jobId)
+			new Builder(jobId)
 				.setConfiguration(config)
 				.setLocalCommunication(false)
 				.build()) {
@@ -1659,7 +1802,7 @@ public class TaskExecutorTest extends TestLogger {
 
 	@Test(timeout = 10000L)
 	public void testTerminationOnFatalError() throws Throwable {
-		try (TaskSubmissionTestEnvironment env = new TaskSubmissionTestEnvironment.Builder(jobId).build()) {
+		try (TaskSubmissionTestEnvironment env = new Builder(jobId).build()) {
 			String testExceptionMsg = "Test exception of fatal error.";
 
 			env.getTaskExecutor().onFatalError(new Exception(testExceptionMsg));
@@ -1901,12 +2044,12 @@ public class TaskExecutorTest extends TestLogger {
 		}
 
 		@Override
-		public void start(LeaderRetrievalListener listener) throws Exception {
+		public void start(LeaderRetrievalListener listener) {
+			// Hand the registered listener to whoever waits on startFuture; nothing here throws,
+			// hence the dropped "throws Exception" clause.
 			startFuture.complete(listener);
 		}
 
 		@Override
-		public void stop() throws Exception {
+		public void stop() {
+			// Signal stopFuture that stop() was called; nothing here throws, hence the dropped
+			// "throws Exception" clause.
 			stopFuture.complete(null);
 		}
 	}
@@ -2038,4 +2181,56 @@ public class TaskExecutorTest extends TestLogger {
 			return result;
 		}
 	}
+
+	/**
+	 * Test invokable that blocks until its thread is interrupted and afterwards keeps blocking
+	 * until {@code DONE_FUTURE} is completed.
+	 *
+	 * <p>Progress is reported through private static futures, not through caller-supplied ones:
+	 * {@code STARTED_FUTURE} completes as soon as {@link #invoke()} is entered, and
+	 * {@code INTERRUPTED_FUTURE} completes once the invoking thread is interrupted while waiting.
+	 * The futures are never reset, so the class can be used only once.
+	 */
+	public static class TestInterruptableInvokable extends AbstractInvokable {
+		private static final CompletableFuture<Void> INTERRUPTED_FUTURE = new CompletableFuture<>();
+		private static final CompletableFuture<Void> STARTED_FUTURE = new CompletableFuture<>();
+		private static final CompletableFuture<Void> DONE_FUTURE = new CompletableFuture<>();
+
+		public TestInterruptableInvokable(Environment environment) {
+			super(environment);
+		}
+
+		@Override
+		public void invoke() {
+			STARTED_FUTURE.complete(null);
+
+			// Phase 1: park until interrupted (within this class INTERRUPTED_FUTURE is only
+			// completed in the catch block below).
+			try {
+				INTERRUPTED_FUTURE.get();
+			} catch (InterruptedException e) {
+				// The interrupt status is deliberately NOT restored: with the flag set, the
+				// DONE_FUTURE.get() below would throw immediately instead of keeping the task
+				// alive until DONE_FUTURE is completed.
+				INTERRUPTED_FUTURE.complete(null);
+			} catch (ExecutionException e) {
+				ExceptionUtils.rethrow(e);
+			}
+
+			// Phase 2: keep running after the interruption until DONE_FUTURE is completed.
+			try {
+				DONE_FUTURE.get();
+			} catch (ExecutionException | InterruptedException e) {
+				ExceptionUtils.rethrow(e);
+			}
+		}
+	}
+
+	/**
+	 * {@link TaskSlotTable} that signals a caller-supplied future once {@link #start(SlotActions)}
+	 * has finished.
+	 */
+	private static class TaskSlotTableWithStartFuture extends TaskSlotTable {
+
+		/** Completed after the superclass start logic has returned. */
+		private final CompletableFuture<Void> startedSignal;
+
+		private TaskSlotTableWithStartFuture(
+				CompletableFuture<Void> taskSlotTableStarted,
+				TimerService<AllocationID> timerService) {
+			super(Collections.singletonList(ResourceProfile.UNKNOWN), timerService);
+			startedSignal = taskSlotTableStarted;
+		}
+
+		@Override
+		public void start(SlotActions initialSlotActions) {
+			super.start(initialSlotActions);
+			// Signal only after the table is actually started.
+			startedSignal.complete(null);
+		}
+	}
 }