You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2021/06/15 19:22:05 UTC

[airflow] branch main updated: Handle missing/null serialized DAG dependencies (#16393)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0fa4d83  Handle missing/null serialized DAG dependencies (#16393)
0fa4d83 is described below

commit 0fa4d833f72a77f30bafa7c32f12b27c0ace4381
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Wed Jun 16 03:21:37 2021 +0800

    Handle missing/null serialized DAG dependencies (#16393)
    
    When a serialized DAG is missing a "dag_dependencies" field (possible
    when upgrading from a pre-2.1.0 version), PostgreSQL returns NULL when
    the field is accessed with a JSON function. That NULL value caused the
    subsequent code to fail, so it needs to be handled explicitly.
    
    Fix #16356
---
 airflow/models/serialized_dag.py    | 18 ++++++------------
 tests/models/test_serialized_dag.py | 18 ++++++++++++++++++
 2 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py
index 71a3de4..2457e43 100644
--- a/airflow/models/serialized_dag.py
+++ b/airflow/models/serialized_dag.py
@@ -313,18 +313,12 @@ class SerializedDagModel(Base):
         :param session: ORM Session
         :type session: Session
         """
-        dependencies = {}
-
         if session.bind.dialect.name in ["sqlite", "mysql"]:
-            for row in session.query(cls.dag_id, func.json_extract(cls.data, "$.dag.dag_dependencies")).all():
-                dependencies[row[0]] = [DagDependency(**d) for d in json.loads(row[1])]
+            query = session.query(cls.dag_id, func.json_extract(cls.data, "$.dag.dag_dependencies"))
+            iterator = ((dag_id, json.loads(deps_data) if deps_data else []) for dag_id, deps_data in query)
         elif session.bind.dialect.name == "mssql":
-            for row in session.query(cls.dag_id, func.json_query(cls.data, "$.dag.dag_dependencies")).all():
-                dependencies[row[0]] = [DagDependency(**d) for d in json.loads(row[1])]
+            query = session.query(cls.dag_id, func.json_query(cls.data, "$.dag.dag_dependencies"))
+            iterator = ((dag_id, json.loads(deps_data) if deps_data else []) for dag_id, deps_data in query)
         else:
-            for row in session.query(
-                cls.dag_id, func.json_extract_path(cls.data, "dag", "dag_dependencies")
-            ).all():
-                dependencies[row[0]] = [DagDependency(**d) for d in row[1]]
-
-        return dependencies
+            iterator = session.query(cls.dag_id, func.json_extract_path(cls.data, "dag", "dag_dependencies"))
+        return {dag_id: [DagDependency(**d) for d in (deps_data or [])] for dag_id, deps_data in iterator}
diff --git a/tests/models/test_serialized_dag.py b/tests/models/test_serialized_dag.py
index db8282f..3e68ddc 100644
--- a/tests/models/test_serialized_dag.py
+++ b/tests/models/test_serialized_dag.py
@@ -20,6 +20,8 @@
 
 import unittest
 
+from parameterized import parameterized
+
 from airflow import DAG, example_dags as example_dags_module
 from airflow.models import DagBag
 from airflow.models.dagcode import DagCode
@@ -149,3 +151,19 @@ class SerializedDagModelTest(unittest.TestCase):
         ]
         with assert_queries_count(10):
             SDM.bulk_sync_to_db(dags)
+
+    @parameterized.expand([({"dag_dependencies": None},), ({},)])
+    def test_get_dag_dependencies_default_to_empty(self, dag_dependencies_fields):
+        """Test a pre-2.1.0 serialized DAG can deserialize DAG dependencies."""
+        example_dags = make_example_dags(example_dags_module)
+
+        with create_session() as session:
+            sdms = [SDM(dag) for dag in example_dags.values()]
+            # Simulate pre-2.1.0 format.
+            for sdm in sdms:
+                del sdm.data["dag"]["dag_dependencies"]
+                sdm.data["dag"].update(dag_dependencies_fields)
+            session.bulk_save_objects(sdms)
+
+        expected_dependencies = {dag_id: [] for dag_id in example_dags}
+        assert SDM.get_dag_dependencies() == expected_dependencies