You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by pn...@apache.org on 2023/01/27 11:00:58 UTC

[flink] 02/03: [FLINK-15550][runtime] The static latches in TaskTest were replaced by latches from invokable objects.

This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 83d8e29f6d11354eafcfd5428535a65efffc3a80
Author: Anton Kalashnikov <ak...@gmail.com>
AuthorDate: Thu Jan 26 17:26:59 2023 +0100

    [FLINK-15550][runtime] The static latches in TaskTest were replaced by latches from invokable objects.
---
 .../apache/flink/runtime/taskmanager/TaskTest.java | 192 +++++++++++++--------
 1 file changed, 123 insertions(+), 69 deletions(-)

diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java
index fd925ec60b2..8e728bb87d7 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java
@@ -39,6 +39,7 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.io.network.partition.consumer.RemoteChannelStateChecker;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
+import org.apache.flink.runtime.jobgraph.tasks.TaskInvokable;
 import org.apache.flink.runtime.jobmanager.PartitionProducerDisposedException;
 import org.apache.flink.runtime.operators.testutils.ExpectedTestException;
 import org.apache.flink.runtime.shuffle.PartitionDescriptor;
@@ -76,6 +77,7 @@ import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicInteger;
 
+import static org.apache.flink.runtime.testutils.CommonTestUtils.waitUntilCondition;
 import static org.hamcrest.CoreMatchers.containsString;
 import static org.hamcrest.CoreMatchers.instanceOf;
 import static org.hamcrest.CoreMatchers.is;
@@ -100,9 +102,6 @@ import static org.mockito.Mockito.when;
 public class TaskTest extends TestLogger {
     private static final String RESTORE_EXCEPTION_MSG = "TestExceptionInRestore";
 
-    private static OneShotLatch awaitLatch;
-    private static OneShotLatch triggerLatch;
-
     private ShuffleEnvironment<?, ?> shuffleEnvironment;
 
     @ClassRule
@@ -115,9 +114,6 @@ public class TaskTest extends TestLogger {
 
     @Before
     public void setup() {
-        awaitLatch = new OneShotLatch();
-        triggerLatch = new OneShotLatch();
-
         shuffleEnvironment = new NettyShuffleEnvironmentBuilder().build();
         wasCleanedUp = false;
     }
@@ -154,7 +150,7 @@ public class TaskTest extends TestLogger {
                         .setInvokable(InvokableBlockingInRestore.class)
                         .build(Executors.directExecutor());
         task.startTaskThread();
-        awaitLatch.await();
+        awaitInvokableLatch(task);
         task.cancelExecution();
         task.getExecutingThread().join();
         assertTrue(wasCleanedUp);
@@ -445,7 +441,7 @@ public class TaskTest extends TestLogger {
         task.startTaskThread();
 
         // wait till the task is in restore
-        awaitLatch.await();
+        awaitInvokableLatch(task);
 
         task.cancelExecution();
         assertTrue(
@@ -475,7 +471,7 @@ public class TaskTest extends TestLogger {
         task.startTaskThread();
 
         // wait till the task is in invoke
-        awaitLatch.await();
+        awaitInvokableLatch(task);
 
         task.cancelExecution();
         assertTrue(
@@ -506,7 +502,7 @@ public class TaskTest extends TestLogger {
         task.startTaskThread();
 
         // wait till the task is in invoke
-        awaitLatch.await();
+        awaitInvokableLatch(task);
 
         task.failExternally(new Exception(RESTORE_EXCEPTION_MSG));
 
@@ -534,7 +530,7 @@ public class TaskTest extends TestLogger {
         task.startTaskThread();
 
         // wait till the task is in invoke
-        awaitLatch.await();
+        awaitInvokableLatch(task);
 
         task.failExternally(new Exception("test"));
 
@@ -587,13 +583,13 @@ public class TaskTest extends TestLogger {
         task.startTaskThread();
 
         // wait till the task is in invoke
-        awaitLatch.await();
+        awaitInvokableLatch(task);
 
         task.cancelExecution();
         assertEquals(ExecutionState.CANCELING, task.getExecutionState());
 
         // this causes an exception
-        triggerLatch.trigger();
+        triggerInvokableLatch(task);
 
         task.getExecutingThread().join();
 
@@ -620,13 +616,13 @@ public class TaskTest extends TestLogger {
         task.startTaskThread();
 
         // wait till the task is in invoke
-        awaitLatch.await();
+        awaitInvokableLatch(task);
 
         task.failExternally(new Exception("external"));
         assertEquals(ExecutionState.FAILED, task.getExecutionState());
 
         // this causes an exception
-        triggerLatch.trigger();
+        triggerInvokableLatch(task);
 
         task.getExecutingThread().join();
 
@@ -647,11 +643,12 @@ public class TaskTest extends TestLogger {
                         .setInvokable(InvokableWithCancelTaskExceptionInInvoke.class)
                         .build(Executors.directExecutor());
 
-        // Cause CancelTaskException.
-        triggerLatch.trigger();
+        task.startTaskThread();
 
-        task.run();
+        // Cause CancelTaskException.
+        triggerInvokableLatch(task);
 
+        task.getExecutingThread().join();
         assertEquals(ExecutionState.CANCELED, task.getExecutionState());
     }
 
@@ -665,14 +662,14 @@ public class TaskTest extends TestLogger {
         task.startTaskThread();
 
         // Wait till the task is in invoke.
-        awaitLatch.await();
+        awaitInvokableLatch(task);
 
         task.failExternally(new Exception("external"));
         assertEquals(ExecutionState.FAILED, task.getExecutionState());
 
         // Either we cause the CancelTaskException or the TaskCanceler
         // by interrupting the invokable.
-        triggerLatch.trigger();
+        triggerInvokableLatch(task);
 
         task.getExecutingThread().join();
 
@@ -834,7 +831,7 @@ public class TaskTest extends TestLogger {
 
             try {
                 task.startTaskThread();
-                awaitLatch.await();
+                awaitInvokableLatch(task);
 
                 CompletableFuture<ExecutionState> promise = new CompletableFuture<>();
                 when(partitionChecker.requestPartitionProducerState(
@@ -877,7 +874,7 @@ public class TaskTest extends TestLogger {
 
             try {
                 task.startTaskThread();
-                awaitLatch.await();
+                awaitInvokableLatch(task);
 
                 CompletableFuture<ExecutionState> promise = new CompletableFuture<>();
                 when(partitionChecker.requestPartitionProducerState(
@@ -927,7 +924,7 @@ public class TaskTest extends TestLogger {
 
         task.startTaskThread();
 
-        awaitLatch.await();
+        awaitInvokableLatch(task);
 
         task.cancelExecution();
         task.getExecutingThread().join();
@@ -955,7 +952,7 @@ public class TaskTest extends TestLogger {
 
         task.startTaskThread();
 
-        awaitLatch.await();
+        awaitInvokableLatch(task);
 
         task.cancelExecution();
         task.getExecutingThread().join();
@@ -987,7 +984,7 @@ public class TaskTest extends TestLogger {
         try {
             task.startTaskThread();
 
-            awaitLatch.await();
+            awaitInvokableLatch(task);
 
             task.cancelExecution();
 
@@ -996,7 +993,7 @@ public class TaskTest extends TestLogger {
             assertThat(fatalError, is(notNullValue()));
         } finally {
             // Interrupt again to clean up Thread
-            triggerLatch.trigger();
+            triggerInvokableLatch(task);
             task.getExecutingThread().interrupt();
             task.getExecutingThread().join();
         }
@@ -1016,31 +1013,33 @@ public class TaskTest extends TestLogger {
         config.setLong(TaskManagerOptions.TASK_CANCELLATION_INTERVAL, 5);
         config.setLong(TaskManagerOptions.TASK_CANCELLATION_TIMEOUT, 50);
 
-        final Task task =
-                spy(
-                        createTaskBuilder()
-                                .setInvokable(InvokableBlockingWithTrigger.class)
-                                .setTaskManagerConfig(config)
-                                .setTaskManagerActions(taskManagerActions)
-                                .build(Executors.directExecutor()));
+        // We need to keep a reference to the original object because all changes made in
+        // `startTaskThread` apply to it rather than to the spy object.
+        Task task =
+                createTaskBuilder()
+                        .setInvokable(InvokableBlockingWithTrigger.class)
+                        .setTaskManagerConfig(config)
+                        .setTaskManagerActions(taskManagerActions)
+                        .build(Executors.directExecutor());
+        final Task spyTask = spy(task);
 
         final Class<OutOfMemoryError> fatalErrorType = OutOfMemoryError.class;
         doThrow(fatalErrorType)
-                .when(task)
+                .when(spyTask)
                 .cancelOrFailAndCancelInvokableInternal(eq(ExecutionState.CANCELING), eq(null));
 
         try {
-            task.startTaskThread();
+            spyTask.startTaskThread();
 
-            awaitLatch.await();
+            awaitInvokableLatch(task);
 
-            task.cancelExecution();
+            spyTask.cancelExecution();
 
             // wait for the notification of notifyFatalError
             final Throwable fatalError = fatalErrorFuture.join();
             assertThat(fatalError, instanceOf(fatalErrorType));
         } finally {
-            triggerLatch.trigger();
+            triggerInvokableLatch(task);
         }
     }
 
@@ -1070,7 +1069,7 @@ public class TaskTest extends TestLogger {
 
         task.startTaskThread();
 
-        awaitLatch.await();
+        awaitInvokableLatch(task);
 
         assertEquals(
                 executionConfig.getTaskCancellationInterval(), task.getTaskCancellationInterval());
@@ -1093,11 +1092,11 @@ public class TaskTest extends TestLogger {
         task.startTaskThread();
 
         // wait till the task is in invoke
-        awaitLatch.await();
+        awaitInvokableLatch(task);
 
         assertFalse(task.getTerminationFuture().isDone());
 
-        triggerLatch.trigger();
+        triggerInvokableLatch(task);
 
         task.getExecutingThread().join();
 
@@ -1162,7 +1161,7 @@ public class TaskTest extends TestLogger {
 
         task.startTaskThread();
         try {
-            awaitLatch.await();
+            awaitInvokableLatch(task);
             assertEquals(ExecutionState.RUNNING, task.getExecutionState());
 
             assertCheckpointDeclined(
@@ -1181,7 +1180,7 @@ public class TaskTest extends TestLogger {
                     InvokableDeclingingCheckpoints.TRIGGERING_FAILED_CHECKPOINT_ID,
                     CheckpointFailureReason.TASK_FAILURE);
         } finally {
-            triggerLatch.trigger();
+            triggerInvokableLatch(task);
             task.getExecutingThread().join();
         }
         assertEquals(ExecutionState.FINISHED, task.getTerminationFuture().getNow(null));
@@ -1211,6 +1210,32 @@ public class TaskTest extends TestLogger {
         testCheckpointResponder.clear();
     }
 
+    private TaskInvokable waitForInvokable(Task task) throws Exception { // waits until non-null
+        waitUntilCondition(() -> task.getInvokable() != null, 10L); // 10L: interval? TODO confirm
+
+        return task.getInvokable(); // second read, performed after the wait returned
+    }
+
+    private void awaitInvokableLatch(Task task) throws Exception { // waits on this task's latch
+        TaskInvokable taskInvokable = waitForInvokable(task);
+        if (!(taskInvokable instanceof AwaitLatchInvokable)) { // wrong test invokable type
+            throw new Exception(
+                    "Invokable doesn't implement class - " + AwaitLatchInvokable.class.getName());
+        }
+
+        ((AwaitLatchInvokable) taskInvokable).await(); // delegates to awaitLatch.await()
+    }
+
+    private void triggerInvokableLatch(Task task) throws Exception { // fires this task's latch
+        TaskInvokable taskInvokable = waitForInvokable(task);
+        if (!(taskInvokable instanceof TriggerLatchInvokable)) { // wrong test invokable type
+            throw new Exception(
+                    "Invokable doesn't implement class - " + TriggerLatchInvokable.class.getName());
+        }
+
+        ((TriggerLatchInvokable) taskInvokable).trigger(); // delegates to triggerLatch.trigger()
+    }
+
     // ------------------------------------------------------------------------
     //  customized TaskManagerActions
     // ------------------------------------------------------------------------
@@ -1343,7 +1368,7 @@ public class TaskTest extends TestLogger {
         public void cancel() {}
     }
 
-    private static class InvokableBlockingWithTrigger extends AbstractInvokable {
+    private static class InvokableBlockingWithTrigger extends TriggerLatchInvokable {
         public InvokableBlockingWithTrigger(Environment environment) {
             super(environment);
         }
@@ -1385,7 +1410,7 @@ public class TaskTest extends TestLogger {
         }
     }
 
-    private static final class InvokableBlockingInInvoke extends AbstractInvokable {
+    private static final class InvokableBlockingInInvoke extends AwaitLatchInvokable {
         public InvokableBlockingInInvoke(Environment environment) {
             super(environment);
         }
@@ -1403,7 +1428,7 @@ public class TaskTest extends TestLogger {
         }
     }
 
-    private static final class InvokableBlockingInRestore extends AbstractInvokable {
+    private static final class InvokableBlockingInRestore extends AwaitLatchInvokable {
         public InvokableBlockingInRestore(Environment environment) {
             super(environment);
         }
@@ -1431,7 +1456,7 @@ public class TaskTest extends TestLogger {
     }
 
     /** {@link AbstractInvokable} which throws {@link RuntimeException} on invoke. */
-    public static final class InvokableWithExceptionOnTrigger extends AbstractInvokable {
+    public static final class InvokableWithExceptionOnTrigger extends TriggerLatchInvokable {
         public InvokableWithExceptionOnTrigger(Environment environment) {
             super(environment);
         }
@@ -1444,23 +1469,9 @@ public class TaskTest extends TestLogger {
         }
     }
 
-    private static void awaitTriggerLatch() {
-        awaitLatch.trigger();
-
-        // make sure that the interrupt call does not
-        // grab us out of the lock early
-        while (true) {
-            try {
-                triggerLatch.await();
-                break;
-            } catch (InterruptedException e) {
-                // fall through the loop
-            }
-        }
-    }
-
     /** {@link AbstractInvokable} which throws {@link CancelTaskException} on invoke. */
-    public static final class InvokableWithCancelTaskExceptionInInvoke extends AbstractInvokable {
+    public static final class InvokableWithCancelTaskExceptionInInvoke
+            extends TriggerLatchInvokable {
         public InvokableWithCancelTaskExceptionInInvoke(Environment environment) {
             super(environment);
         }
@@ -1474,7 +1485,7 @@ public class TaskTest extends TestLogger {
     }
 
     /** {@link AbstractInvokable} which blocks in cancel. */
-    public static final class InvokableBlockingInCancel extends AbstractInvokable {
+    public static final class InvokableBlockingInCancel extends TriggerLatchInvokable {
         public InvokableBlockingInCancel(Environment environment) {
             super(environment);
         }
@@ -1506,7 +1517,7 @@ public class TaskTest extends TestLogger {
 
     /** {@link AbstractInvokable} which blocks in cancel and is interruptible. */
     public static final class InvokableInterruptibleSharedLockInInvokeAndCancel
-            extends AbstractInvokable {
+            extends TriggerLatchInvokable {
         private final Object lock = new Object();
 
         public InvokableInterruptibleSharedLockInInvokeAndCancel(Environment environment) {
@@ -1515,9 +1526,11 @@ public class TaskTest extends TestLogger {
 
         @Override
         public void invoke() throws Exception {
-            synchronized (lock) {
-                awaitLatch.trigger();
-                wait();
+            while (!triggerLatch.isTriggered()) {
+                synchronized (lock) {
+                    awaitLatch.trigger();
+                    lock.wait();
+                }
             }
         }
 
@@ -1531,7 +1544,7 @@ public class TaskTest extends TestLogger {
     }
 
     /** {@link AbstractInvokable} which blocks in cancel and is not interruptible. */
-    public static final class InvokableUnInterruptibleBlockingInvoke extends AbstractInvokable {
+    public static final class InvokableUnInterruptibleBlockingInvoke extends TriggerLatchInvokable {
         public InvokableUnInterruptibleBlockingInvoke(Environment environment) {
             super(environment);
         }
@@ -1564,4 +1577,45 @@ public class TaskTest extends TestLogger {
             super(cause);
         }
     }
+
+    private abstract static class AwaitLatchInvokable extends AbstractInvokable { // test-only
+
+        final OneShotLatch awaitLatch = new OneShotLatch(); // one latch per invokable instance
+
+        public AwaitLatchInvokable(Environment environment) {
+            super(environment);
+        }
+
+        void await() throws InterruptedException { // used by awaitInvokableLatch(Task)
+            awaitLatch.await(); // returns once a subclass has triggered awaitLatch
+        }
+    }
+
+    private abstract static class TriggerLatchInvokable extends AwaitLatchInvokable {
+
+        final OneShotLatch triggerLatch = new OneShotLatch(); // one latch per invokable instance
+
+        void trigger() { // used by triggerInvokableLatch(Task) to release awaitTriggerLatch()
+            triggerLatch.trigger();
+        }
+
+        void awaitTriggerLatch() { // signals awaitLatch, then blocks on triggerLatch
+            awaitLatch.trigger(); // unblocks callers of await()
+
+            // Keep waiting across interrupts so that an interrupt cannot release this
+            // thread before triggerLatch has actually been triggered.
+            while (true) {
+                try {
+                    triggerLatch.await();
+                    break;
+                } catch (InterruptedException e) {
+                    // interrupt is swallowed on purpose; loop and await the latch again
+                }
+            }
+        }
+    }
 }