You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tr...@apache.org on 2018/08/03 14:36:52 UTC

[flink] 01/06: [FLINK-10033] [runtime] Task releases reference to AbstractInvokable

This is an automated email from the ASF dual-hosted git repository.

trohrmann pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit df22e91b52ca112090d652b7b36f2c647362009b
Author: Stephan Ewen <se...@apache.org>
AuthorDate: Thu Aug 2 20:57:50 2018 +0200

    [FLINK-10033] [runtime] Task releases reference to AbstractInvokable
    
    To guard against memory leaks, the Task releases the reference to its AbstractInvokable
    when it shuts down or cancels.
    
    This closes #6480.
---
 .../org/apache/flink/runtime/taskmanager/Task.java | 51 +++++++++++++++-------
 .../apache/flink/runtime/taskmanager/TaskTest.java |  5 +++
 2 files changed, 40 insertions(+), 16 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
index 60b2ed8..92ae167 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
@@ -78,6 +78,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
 
 import java.io.IOException;
 import java.lang.reflect.Constructor;
@@ -243,7 +244,9 @@ public class Task implements Runnable, TaskActions, CheckpointListener {
 	/** atomic flag that makes sure the invokable is canceled exactly once upon error. */
 	private final AtomicBoolean invokableHasBeenCanceled;
 
-	/** The invokable of this task, if initialized. */
+	/** The invokable of this task, if initialized. All accesses must copy the reference and
+	 * check for null, as this field is cleared as part of the disposal logic. */
+	@Nullable
 	private volatile AbstractInvokable invokable;
 
 	/** The current execution state of the task. */
@@ -473,6 +476,12 @@ public class Task implements Runnable, TaskActions, CheckpointListener {
 		return taskCancellationTimeout;
 	}
 
+	@Nullable
+	@VisibleForTesting
+	AbstractInvokable getInvokable() {
+		return invokable;
+	}
+
 	// ------------------------------------------------------------------------
 	//  Task Execution
 	// ------------------------------------------------------------------------
@@ -762,7 +771,7 @@ public class Task implements Runnable, TaskActions, CheckpointListener {
 					if (current == ExecutionState.RUNNING || current == ExecutionState.DEPLOYING) {
 						if (t instanceof CancelTaskException) {
 							if (transitionState(current, ExecutionState.CANCELED)) {
-								cancelInvokable();
+								cancelInvokable(invokable);
 
 								notifyObservers(ExecutionState.CANCELED, null);
 								break;
@@ -773,7 +782,7 @@ public class Task implements Runnable, TaskActions, CheckpointListener {
 								// proper failure of the task. record the exception as the root cause
 								String errorMessage = String.format("Execution of %s (%s) failed.", taskNameWithSubtask, executionId);
 								failureCause = t;
-								cancelInvokable();
+								cancelInvokable(invokable);
 
 								notifyObservers(ExecutionState.FAILED, new Exception(errorMessage, t));
 								break;
@@ -808,6 +817,10 @@ public class Task implements Runnable, TaskActions, CheckpointListener {
 			try {
 				LOG.info("Freeing task resources for {} ({}).", taskNameWithSubtask, executionId);
 
+				// clear the reference to the invokable. this helps guard against holding references
+				// to the invokable and its structures in cases where this Task object is still referenced
+				this.invokable = null;
+
 				// stop the async dispatcher.
 				// copy dispatcher reference to stack, against concurrent release
 				ExecutorService dispatcher = this.asyncCallDispatcher;
@@ -924,18 +937,20 @@ public class Task implements Runnable, TaskActions, CheckpointListener {
 	 * @throws IllegalStateException if the {@link Task} is not yet running
 	 */
 	public void stopExecution() {
+		// copy reference to stack, to guard against concurrent setting to null
+		final AbstractInvokable invokable = this.invokable;
+
 		if (invokable != null) {
-			LOG.info("Attempting to stop task {} ({}).", taskNameWithSubtask, executionId);
 			if (invokable instanceof StoppableTask) {
-				Runnable runnable = new Runnable() {
-					@Override
-					public void run() {
-						try {
-							((StoppableTask) invokable).stop();
-						} catch (RuntimeException e) {
-							LOG.error("Stopping task {} ({}) failed.", taskNameWithSubtask, executionId, e);
-							taskManagerActions.failTask(executionId, e);
-						}
+				LOG.info("Attempting to stop task {} ({}).", taskNameWithSubtask, executionId);
+				final StoppableTask stoppable = (StoppableTask) invokable;
+
+				Runnable runnable = () -> {
+					try {
+						stoppable.stop();
+					} catch (Throwable t) {
+						LOG.error("Stopping task {} ({}) failed.", taskNameWithSubtask, executionId, t);
+						taskManagerActions.failTask(executionId, t);
 					}
 				};
 				executeAsyncCallRunnable(runnable, String.format("Stopping source task %s (%s).", taskNameWithSubtask, executionId));
@@ -945,7 +960,7 @@ public class Task implements Runnable, TaskActions, CheckpointListener {
 		} else {
 			throw new IllegalStateException(
 				String.format(
-					"Cannot stop task %s (%s) because it is not yet running.",
+					"Cannot stop task %s (%s) because it is not running.",
 					taskNameWithSubtask,
 					executionId));
 		}
@@ -1010,6 +1025,10 @@ public class Task implements Runnable, TaskActions, CheckpointListener {
 				if (transitionState(ExecutionState.RUNNING, targetState, cause)) {
 					// we are canceling / failing out of the running state
 					// we need to cancel the invokable
+
+					// copy reference to guard against concurrent null-ing out the reference
+					final AbstractInvokable invokable = this.invokable;
+
 					if (invokable != null && invokableHasBeenCanceled.compareAndSet(false, true)) {
 						this.failureCause = cause;
 						notifyObservers(
@@ -1363,9 +1382,9 @@ public class Task implements Runnable, TaskActions, CheckpointListener {
 	//  Utilities
 	// ------------------------------------------------------------------------
 
-	private void cancelInvokable() {
+	private void cancelInvokable(AbstractInvokable invokable) {
 		// in case of an exception during execution, we still call "cancel()" on the task
-		if (invokable != null && this.invokable != null && invokableHasBeenCanceled.compareAndSet(false, true)) {
+		if (invokable != null && invokableHasBeenCanceled.compareAndSet(false, true)) {
 			try {
 				invokable.cancel();
 			}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java
index 1829e97..3dfcfb3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java
@@ -174,6 +174,7 @@ public class TaskTest extends TestLogger {
 			assertEquals(ExecutionState.FINISHED, task.getExecutionState());
 			assertFalse(task.isCanceledOrFailed());
 			assertNull(task.getFailureCause());
+			assertNull(task.getInvokable());
 
 			// verify listener messages
 			validateListenerMessage(ExecutionState.RUNNING, task, false);
@@ -202,6 +203,8 @@ public class TaskTest extends TestLogger {
 			// verify final state
 			assertEquals(ExecutionState.CANCELED, task.getExecutionState());
 			validateUnregisterTask(task.getExecutionId());
+
+			assertNull(task.getInvokable());
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -258,6 +261,8 @@ public class TaskTest extends TestLogger {
 
 			// make sure that the TaskManager received an message to unregister the task
 			validateUnregisterTask(task.getExecutionId());
+
+			assertNull(task.getInvokable());
 		}
 		catch (Exception e) {
 			e.printStackTrace();