You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2021/12/08 16:54:22 UTC

[GitHub] [airflow] josh-fell commented on a change in pull request #20044: Add ShortCircuitOperator configurability for respecting downstream trigger rules

josh-fell commented on a change in pull request #20044:
URL: https://github.com/apache/airflow/pull/20044#discussion_r765048078



##########
File path: tests/operators/test_python.py
##########
@@ -580,128 +582,164 @@ def test_raise_exception_on_invalid_task_id(self):
 class TestShortCircuitOperator(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        super().setUpClass()
-
         with create_session() as session:
             session.query(DagRun).delete()
             session.query(TI).delete()
 
-    def tearDown(self):
-        super().tearDown()
+    def setUp(self):
+        self.dag = DAG(
+            "short_circuit_op_test",
+            start_date=DEFAULT_DATE,
+            schedule_interval=INTERVAL,
+        )
 
+        with self.dag:
+            self.op1 = DummyOperator(task_id="op1")
+            self.op2 = DummyOperator(task_id="op2")
+            self.op1.set_downstream(self.op2)
+
+    def tearDown(self):
         with create_session() as session:
             session.query(DagRun).delete()
             session.query(TI).delete()
 
-    def test_with_dag_run(self):
-        value = False
-        dag = DAG(
-            'shortcircuit_operator_test_with_dag_run',
-            default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
-            schedule_interval=INTERVAL,
+    def _assert_expected_task_states(self, dagrun, expected_states):
+        """Helper function that asserts `TaskInstances` of a given `task_id` are in a given state."""
+
+        tis = dagrun.get_task_instances()
+        for ti in tis:
+            try:
+                expected_state = expected_states[ti.task_id]
+            except KeyError:
+                raise ValueError(f"Invalid task id {ti.task_id} found!")
+            else:
+                assert ti.state == expected_state
+
+    all_downstream_skipped_states = {
+        "short_circuit": State.SUCCESS,
+        "op1": State.SKIPPED,
+        "op2": State.SKIPPED,
+    }
+    all_success_states = {"short_circuit": State.SUCCESS, "op1": State.SUCCESS, "op2": State.SUCCESS}
+
+    @parameterized.expand(
+        [
+            # Skip downstream tasks, do not respect trigger rules, default trigger rule on all downstream
+            # tasks
+            (False, True, TriggerRule.ALL_SUCCESS, all_downstream_skipped_states),
+            # Skip downstream tasks, do not respect trigger rules, non-default trigger rule on a downstream
+            # task
+            (False, True, TriggerRule.ALL_DONE, all_downstream_skipped_states),
+            # Skip downstream tasks, respect trigger rules, default trigger rule on all downstream tasks
+            (
+                False,
+                False,
+                TriggerRule.ALL_SUCCESS,
+                {"short_circuit": State.SUCCESS, "op1": State.SKIPPED, "op2": State.NONE},
+            ),
+            # Skip downstream tasks, respect trigger rules, non-default trigger rule on a downstream task
+            (
+                False,
+                False,
+                TriggerRule.ALL_DONE,
+                {"short_circuit": State.SUCCESS, "op1": State.SKIPPED, "op2": State.SUCCESS},
+            ),
+            # Do not skip downstream tasks, do not respect trigger rules, default trigger rule on all
+            # downstream tasks
+            (True, True, TriggerRule.ALL_SUCCESS, all_success_states),
+            # Do not skip downstream tasks, do not respect trigger rules, non-default trigger rule on a
+            # downstream task
+            (True, True, TriggerRule.ALL_DONE, all_success_states),
+            # Do not skip downstream tasks, respect trigger rules, default trigger rule on all downstream
+            # tasks
+            (True, False, TriggerRule.ALL_SUCCESS, all_success_states),
+            # Do not skip downstream tasks, respect trigger rules, non-default trigger rule on a downstream
+            # task
+            (True, False, TriggerRule.ALL_DONE, all_success_states),
+        ],
+    )
+    def test_short_circuiting(
+        self, callable_return, test_ignore_downstream_trigger_rules, test_trigger_rule, expected_task_states
+    ):
+        """
+        Checking the behavior of the ShortCircuitOperator in several scenarios enabling/disabling the skipping
+        of downstream tasks, both short-circuiting modes, and various trigger rules of downstream tasks.
+        """
+
+        self.short_circuit = ShortCircuitOperator(
+            task_id="short_circuit",
+            python_callable=lambda: callable_return,
+            ignore_downstream_trigger_rules=test_ignore_downstream_trigger_rules,
+            dag=self.dag,
         )
-        short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: value)
-        branch_1 = DummyOperator(task_id='branch_1', dag=dag)
-        branch_1.set_upstream(short_op)
-        branch_2 = DummyOperator(task_id='branch_2', dag=dag)
-        branch_2.set_upstream(branch_1)
-        upstream = DummyOperator(task_id='upstream', dag=dag)
-        upstream.set_downstream(short_op)
-        dag.clear()
-
-        logging.error("Tasks %s", dag.tasks)
-        dr = dag.create_dagrun(
+        self.short_circuit.set_downstream(self.op1)
+        self.op2.trigger_rule = test_trigger_rule
+        self.dag.clear()
+
+        dagrun = self.dag.create_dagrun(
             run_type=DagRunType.MANUAL,
             start_date=timezone.utcnow(),
             execution_date=DEFAULT_DATE,
             state=State.RUNNING,
         )
 
-        upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-        short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-        tis = dr.get_task_instances()
-        assert len(tis) == 4
-        for ti in tis:
-            if ti.task_id == 'make_choice':
-                assert ti.state == State.SUCCESS
-            elif ti.task_id == 'upstream':
-                assert ti.state == State.SUCCESS
-            elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
-                assert ti.state == State.SKIPPED
-            else:
-                raise ValueError(f'Invalid task id {ti.task_id} found!')
+        self.short_circuit.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        self.op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        self.op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
-        value = True
-        dag.clear()
-        dr.verify_integrity()
-        upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-        short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        assert self.short_circuit.ignore_downstream_trigger_rules == test_ignore_downstream_trigger_rules
+        assert self.short_circuit.trigger_rule == TriggerRule.ALL_SUCCESS
+        assert self.op1.trigger_rule == TriggerRule.ALL_SUCCESS
+        assert self.op2.trigger_rule == test_trigger_rule
 
-        tis = dr.get_task_instances()
-        assert len(tis) == 4
-        for ti in tis:
-            if ti.task_id == 'make_choice':
-                assert ti.state == State.SUCCESS
-            elif ti.task_id == 'upstream':
-                assert ti.state == State.SUCCESS
-            elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
-                assert ti.state == State.NONE
-            else:
-                raise ValueError(f'Invalid task id {ti.task_id} found!')
+        self._assert_expected_task_states(dagrun, expected_task_states)
 
     def test_clear_skipped_downstream_task(self):
         """
         After a downstream task is skipped by ShortCircuitOperator, clearing the skipped task
         should not cause it to be executed.

Review comment:
       The confusion might be that the `op2` task's trigger rule is being set. For this test the trigger rules are intended to be irrelevant. The reset is a carry-over from other tests and it's required. I'll update. Thanks for catching this!




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org