Posted to commits@airflow.apache.org by jh...@apache.org on 2021/08/09 22:52:30 UTC

[airflow] 32/39: BugFix: Correctly handle custom `deps` and `task_group` during DAG Serialization (#16734)

This is an automated email from the ASF dual-hosted git repository.

jhtimmins pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 12c46c972dc84be0dc673c6ff1ced731afcf5630
Author: luoyuliuyin <lu...@gmail.com>
AuthorDate: Tue Jul 6 20:44:32 2021 +0800

    BugFix: Correctly handle custom `deps` and `task_group` during DAG Serialization (#16734)
    
    We check whether a DAG has changed by comparing its dag_hash, so deps and task_group must be handled deterministically during DAG serialization to keep the generated dag_hash stable.
    
    closes https://github.com/apache/airflow/issues/16690
    
    (cherry picked from commit 0632ecf6f56214c78deea2a4b54ea0daebb4e95d)
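    
    For context, the failure mode is reproducible without Airflow: iterating
    a set of strings yields an order that can differ between Python processes
    (string hashing is randomized per process), so a hash built from the
    dumped payload can differ too. A minimal sketch of the mechanism, using
    illustrative field names rather than Airflow's actual payload:
    
        import hashlib
        import json
    
        def payload_hash(payload: dict) -> str:
            # Hash of the sorted-keys JSON dump, analogous to dag_hash.
            # sort_keys only sorts dict keys, not the contents of lists.
            return hashlib.md5(
                json.dumps(payload, sort_keys=True).encode("utf-8")
            ).hexdigest()
    
        deps = {"trigger_rule_dep.TriggerRuleDep", "prev_dagrun_dep.PrevDagrunDep"}
    
        # list(deps) follows set iteration order, which can vary between
        # processes, so an unchanged DAG may hash differently:
        unstable = payload_hash({"deps": list(deps)})
    
        # sorted(deps) is deterministic, so the hash is reproducible:
        stable = payload_hash({"deps": sorted(deps)})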
---
 airflow/serialization/serialized_objects.py   | 16 +++--
 tests/serialization/test_dag_serialization.py | 99 +++++++++++++++++++++++++++
 2 files changed, 110 insertions(+), 5 deletions(-)

diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index d2f456d..bdbaea8 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -424,7 +424,10 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
                     )
 
                 deps.append(f'{module_name}.{klass.__name__}')
-            serialize_op['deps'] = deps
+            # deps must be sorted here: op.deps is a set, whose iteration order
+            # is unstable, so repeated serializations could differ and the
+            # dag_hash from json.dumps(self.data, sort_keys=True) would change.
+            serialize_op['deps'] = sorted(deps)
 
         # Store all template_fields as they are if they are JSON-serializable
         # If not, store them as strings
@@ -796,6 +799,9 @@ class SerializedTaskGroup(TaskGroup, BaseSerialization):
         if not task_group:
             return None
 
+        # task_group.*_ids must be sorted here: they are sets, so converting
+        # them to lists yields an unstable order, and the dag_hash built from
+        # json.dumps(self.data, sort_keys=True) would change between runs.
         serialize_group = {
             "_group_id": task_group._group_id,
             "prefix_group_id": task_group.prefix_group_id,
@@ -808,10 +814,10 @@ class SerializedTaskGroup(TaskGroup, BaseSerialization):
                 else (DAT.TASK_GROUP, SerializedTaskGroup.serialize_task_group(child))
                 for label, child in task_group.children.items()
             },
-            "upstream_group_ids": cls._serialize(list(task_group.upstream_group_ids)),
-            "downstream_group_ids": cls._serialize(list(task_group.downstream_group_ids)),
-            "upstream_task_ids": cls._serialize(list(task_group.upstream_task_ids)),
-            "downstream_task_ids": cls._serialize(list(task_group.downstream_task_ids)),
+            "upstream_group_ids": cls._serialize(sorted(task_group.upstream_group_ids)),
+            "downstream_group_ids": cls._serialize(sorted(task_group.downstream_group_ids)),
+            "upstream_task_ids": cls._serialize(sorted(task_group.upstream_task_ids)),
+            "downstream_task_ids": cls._serialize(sorted(task_group.downstream_task_ids)),
         }
 
         return serialize_group
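
The practical impact: the scheduler skips rewriting a serialized DAG only when
the freshly computed dag_hash matches the stored one, so a nondeterministic
payload makes an unchanged DAG look modified on every check. A simplified
sketch of that comparison; the real logic lives in SerializedDagModel, and
compute_dag_hash/needs_update/stored_hash below are illustrative names:

    import hashlib
    import json

    def compute_dag_hash(data: dict) -> str:
        # Same shape as the json.dumps(self.data, sort_keys=True) call
        # referenced in the comments above.
        return hashlib.md5(
            json.dumps(data, sort_keys=True).encode("utf-8")
        ).hexdigest()

    def needs_update(new_data: dict, stored_hash: str) -> bool:
        # Before this fix, set-derived lists in new_data could reorder
        # between runs, returning True for a DAG that had not changed.
        return compute_dag_hash(new_data) != stored_hash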
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index f24e862..05edfda 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -963,6 +963,105 @@ class TestStringifiedDAGs(unittest.TestCase):
 
         check_task_group(serialized_dag.task_group)
 
+    def test_deps_sorted(self):
+        """
+        Tests serialize_operator, making sure the serialized deps are sorted
+        """
+        from airflow.operators.dummy import DummyOperator
+        from airflow.sensors.external_task import ExternalTaskSensor
+
+        execution_date = datetime(2020, 1, 1)
+        with DAG(dag_id="test_deps_sorted", start_date=execution_date) as dag:
+            task1 = ExternalTaskSensor(
+                task_id="task1",
+                external_dag_id="external_dag_id",
+                mode="reschedule",
+            )
+            task2 = DummyOperator(task_id="task2")
+            task1 >> task2
+
+        serialize_op = SerializedBaseOperator.serialize_operator(dag.task_dict["task1"])
+        deps = serialize_op["deps"]
+        assert deps == [
+            'airflow.ti_deps.deps.not_in_retry_period_dep.NotInRetryPeriodDep',
+            'airflow.ti_deps.deps.not_previously_skipped_dep.NotPreviouslySkippedDep',
+            'airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep',
+            'airflow.ti_deps.deps.ready_to_reschedule.ReadyToRescheduleDep',
+            'airflow.ti_deps.deps.trigger_rule_dep.TriggerRuleDep',
+        ]
+
+    def test_task_group_sorted(self):
+        """
+        Tests serialize_task_group, making sure the serialized ID lists are sorted
+        """
+        from airflow.operators.dummy import DummyOperator
+        from airflow.serialization.serialized_objects import SerializedTaskGroup
+        from airflow.utils.task_group import TaskGroup
+
+        """
+                    start
+                    ╱  ╲
+                  ╱      ╲
+        task_group_up1  task_group_up2
+            (task_up1)  (task_up2)
+                 ╲       ╱
+              task_group_middle
+                (task_middle)
+                  ╱      ╲
+        task_group_down1 task_group_down2
+           (task_down1) (task_down2)
+                 ╲        ╱
+                   ╲    ╱
+                    end
+        """
+        execution_date = datetime(2020, 1, 1)
+        with DAG(dag_id="test_task_group_sorted", start_date=execution_date) as dag:
+            start = DummyOperator(task_id="start")
+
+            with TaskGroup("task_group_up1") as task_group_up1:
+                _ = DummyOperator(task_id="task_up1")
+
+            with TaskGroup("task_group_up2") as task_group_up2:
+                _ = DummyOperator(task_id="task_up2")
+
+            with TaskGroup("task_group_middle") as task_group_middle:
+                _ = DummyOperator(task_id="task_middle")
+
+            with TaskGroup("task_group_down1") as task_group_down1:
+                _ = DummyOperator(task_id="task_down1")
+
+            with TaskGroup("task_group_down2") as task_group_down2:
+                _ = DummyOperator(task_id="task_down2")
+
+            end = DummyOperator(task_id='end')
+
+            start >> task_group_up1
+            start >> task_group_up2
+            task_group_up1 >> task_group_middle
+            task_group_up2 >> task_group_middle
+            task_group_middle >> task_group_down1
+            task_group_middle >> task_group_down2
+            task_group_down1 >> end
+            task_group_down2 >> end
+
+        task_group_middle_dict = SerializedTaskGroup.serialize_task_group(
+            dag.task_group.children["task_group_middle"]
+        )
+        upstream_group_ids = task_group_middle_dict["upstream_group_ids"]
+        assert upstream_group_ids == ['task_group_up1', 'task_group_up2']
+
+        upstream_task_ids = task_group_middle_dict["upstream_task_ids"]
+        assert upstream_task_ids == ['task_group_up1.task_up1', 'task_group_up2.task_up2']
+
+        downstream_group_ids = task_group_middle_dict["downstream_group_ids"]
+        assert downstream_group_ids == ['task_group_down1', 'task_group_down2']
+
+        task_group_down1_dict = SerializedTaskGroup.serialize_task_group(
+            dag.task_group.children["task_group_down1"]
+        )
+        downstream_task_ids = task_group_down1_dict["downstream_task_ids"]
+        assert downstream_task_ids == ['end']
+
     def test_edge_info_serialization(self):
         """
         Tests edge_info serialization/deserialization.
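
A quick way to sanity-check the fix outside the test suite is to serialize the
same DAG twice and compare the dumps. Note that set iteration order is
repeatable within a single process, so this check is necessary rather than
sufficient; the misjudgment in #16690 surfaced across scheduler runs, where
per-process hash randomization reorders sets. A hedged sketch using the public
SerializedDAG helper:

    import json

    from airflow.serialization.serialized_objects import SerializedDAG

    def assert_serialization_stable(dag) -> None:
        # Two serializations of an unchanged DAG should be byte-identical
        # once set-valued fields are sorted during serialization.
        first = json.dumps(SerializedDAG.to_dict(dag), sort_keys=True)
        second = json.dumps(SerializedDAG.to_dict(dag), sort_keys=True)
        assert first == second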