You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2022/04/08 20:12:28 UTC

[airflow] branch main updated: Support dag serialization with custom ti_deps rules (#22698)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new fb59a76f2d Support dag serialization with custom ti_deps rules (#22698)
fb59a76f2d is described below

commit fb59a76f2d7e3544ad5fc109db678bf5ce5b8f01
Author: QP Hou <qp...@scribd.com>
AuthorDate: Fri Apr 8 13:12:18 2022 -0700

    Support dag serialization with custom ti_deps rules (#22698)
---
 airflow/plugins_manager.py                    | 25 +++++++++++
 airflow/serialization/serialized_objects.py   | 64 +++++++++++++++++----------
 tests/cli/commands/test_plugins_command.py    |  1 +
 tests/plugins/test_plugin.py                  |  6 +++
 tests/serialization/test_dag_serialization.py | 59 ++++++++++++++++++++++++
 5 files changed, 132 insertions(+), 23 deletions(-)

diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py
index 87709d5286..82e295fa19 100644
--- a/airflow/plugins_manager.py
+++ b/airflow/plugins_manager.py
@@ -63,6 +63,7 @@ flask_appbuilder_menu_links: Optional[List[Any]] = None
 global_operator_extra_links: Optional[List[Any]] = None
 operator_extra_links: Optional[List[Any]] = None
 registered_operator_link_classes: Optional[Dict[str, Type]] = None
+registered_ti_dep_classes: Optional[Dict[str, Type]] = None
 timetable_classes: Optional[Dict[str, Type["Timetable"]]] = None
 """Mapping of class names to class of OperatorLinks registered by plugins.
 
@@ -78,6 +79,7 @@ PLUGINS_ATTRIBUTES_TO_DUMP = {
     "appbuilder_menu_items",
     "global_operator_extra_links",
     "operator_extra_links",
+    "ti_deps",
     "timetables",
     "source",
     "listeners",
@@ -154,6 +156,8 @@ class AirflowPlugin:
     # buttons.
     operator_extra_links: List[Any] = []
 
+    ti_deps: List[Any] = []
+
     # A list of timetable classes that can be used for DAG scheduling.
     timetables: List[Type["Timetable"]] = []
 
@@ -350,6 +354,27 @@ def initialize_web_ui_plugins():
             )
 
 
+def initialize_ti_deps_plugins():
+    """Register custom task instance dependency (ti_dep) classes provided by loaded plugins."""
+    global registered_ti_dep_classes
+    if registered_ti_dep_classes is not None:
+        return
+
+    ensure_plugins_loaded()
+
+    if plugins is None:
+        raise AirflowPluginException("Can't load plugins.")
+
+    log.debug("Initialize custom taskinstance deps plugins")
+
+    registered_ti_dep_classes = {}
+
+    for plugin in plugins:
+        registered_ti_dep_classes.update(
+            {as_importable_string(ti_dep.__class__): ti_dep.__class__ for ti_dep in plugin.ti_deps}
+        )
+
+
 def initialize_extra_operators_links_plugins():
     """Creates modules for loaded extension from extra operators links plugins"""
     global global_operator_extra_links
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index dc3a22552c..7c895b5ee9 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -637,25 +637,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
             )
 
         if include_deps:
-            # Are the deps different to "stock", if so serialize the class names!
-            # For Airflow 2.0 expediency we _only_ allow built in Dep classes.
-            # Fix this for 2.0.x or 2.1
-            deps = []
-            for dep in op.deps:
-                klass = type(dep)
-                module_name = klass.__module__
-                if not module_name.startswith("airflow.ti_deps.deps."):
-                    assert op.dag  # for type checking
-                    raise SerializationError(
-                        f"Cannot serialize {(op.dag.dag_id + '.' + op.task_id)!r} with `deps` from non-core "
-                        f"module {module_name!r}"
-                    )
-
-                deps.append(f'{module_name}.{klass.__name__}')
-            # deps needs to be sorted here, because op.deps is a set, which is unstable when traversing,
-            # and the same call may get different results.
-            # When calling json.dumps(self.data, sort_keys=True) to generate dag_hash, misjudgment will occur
-            serialize_op['deps'] = sorted(deps)
+            serialize_op['deps'] = cls._serialize_deps(op.deps)
 
         # Store all template_fields as they are if there are JSON Serializable
         # If not, store them as strings
@@ -670,6 +652,32 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
 
         return serialize_op
 
+    @classmethod
+    def _serialize_deps(cls, op_deps: Iterable["BaseTIDep"]) -> List[str]:
+        from airflow import plugins_manager
+
+        plugins_manager.initialize_ti_deps_plugins()
+        if plugins_manager.registered_ti_dep_classes is None:
+            raise AirflowException("Can not load plugins")
+
+        deps = []
+        for dep in op_deps:
+            klass = type(dep)
+            module_name = klass.__module__
+            qualname = f'{module_name}.{klass.__name__}'
+            if (
+                not qualname.startswith("airflow.ti_deps.deps.")
+                and qualname not in plugins_manager.registered_ti_dep_classes
+            ):
+                raise SerializationError(
+                    f"Custom dep class {qualname} not serialized, please register it through plugins."
+                )
+            deps.append(qualname)
+        # deps must be sorted here: op_deps is a set, so its iteration order is not stable and
+        # the same call may yield a different order each time. That would make the dag_hash generated
+        # from json.dumps(self.data, sort_keys=True) unstable.
+        return sorted(deps)
+
     @classmethod
     def populate_operator(cls, op: Operator, encoded_op: Dict[str, Any]) -> None:
         if "label" not in encoded_op:
@@ -820,11 +828,21 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
 
     @classmethod
     def _deserialize_deps(cls, deps: List[str]) -> Set["BaseTIDep"]:
+        from airflow import plugins_manager
+
+        plugins_manager.initialize_ti_deps_plugins()
+        if plugins_manager.registered_ti_dep_classes is None:
+            raise AirflowException("Can not load plugins")
+
         instances = set()
         for qualname in set(deps):
-            if not qualname.startswith("airflow.ti_deps.deps."):
-                log.error("Dep class %r not registered", qualname)
-                continue
+            if (
+                not qualname.startswith("airflow.ti_deps.deps.")
+                and qualname not in plugins_manager.registered_ti_dep_classes
+            ):
+                raise SerializationError(
+                    f"Custom dep class {qualname} not deserialized, please register it through plugins."
+                )
 
             try:
                 instances.add(import_string(qualname)())
@@ -835,7 +853,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
     @classmethod
     def _deserialize_operator_extra_links(cls, encoded_op_links: list) -> Dict[str, BaseOperatorLink]:
         """
-        Deserialize Operator Links if the Classes  are registered in Airflow Plugins.
+        Deserialize Operator Links if the Classes are registered in Airflow Plugins.
         Error is raised if the OperatorLink is not found in Plugins too.
 
         :param encoded_op_links: Serialized Operator Link
diff --git a/tests/cli/commands/test_plugins_command.py b/tests/cli/commands/test_plugins_command.py
index 1bffc4ff5a..7903a053ea 100644
--- a/tests/cli/commands/test_plugins_command.py
+++ b/tests/cli/commands/test_plugins_command.py
@@ -95,6 +95,7 @@ class TestPluginsCommand(unittest.TestCase):
                         'label': 'The Apache Software Foundation',
                     },
                 ],
+                'ti_deps': ['<TIDep(CustomTestTriggerRule)>'],
             }
         ]
         get_listener_manager().clear()
diff --git a/tests/plugins/test_plugin.py b/tests/plugins/test_plugin.py
index 7abdd22c99..dc6fdc2692 100644
--- a/tests/plugins/test_plugin.py
+++ b/tests/plugins/test_plugin.py
@@ -28,6 +28,7 @@ from airflow.models.baseoperator import BaseOperator
 # This is the class you derive to create a plugin
 from airflow.plugins_manager import AirflowPlugin
 from airflow.sensors.base import BaseSensorOperator
+from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.timetables.interval import CronDataIntervalTimetable
 from tests.listeners import empty_listener
 from tests.test_utils.mock_operators import (
@@ -106,6 +107,10 @@ class CustomCronDataIntervalTimetable(CronDataIntervalTimetable):
     pass
 
 
+class CustomTestTriggerRule(BaseTIDep):
+    pass
+
+
 # Defining the plugin class
 class AirflowTestPlugin(AirflowPlugin):
     name = "test_plugin"
@@ -124,6 +129,7 @@ class AirflowTestPlugin(AirflowPlugin):
     operator_extra_links = [GoogleLink(), AirflowLink2(), CustomOpLink(), CustomBaseIndexOpLink(1)]
     timetables = [CustomCronDataIntervalTimetable]
     listeners = [empty_listener]
+    ti_deps = [CustomTestTriggerRule()]
 
 
 class MockPluginA(AirflowPlugin):
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 5fb378f041..e2661eb68f 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -45,6 +45,7 @@ from airflow.operators.bash import BashOperator
 from airflow.security import permissions
 from airflow.serialization.json_schema import load_dag_schema_dict
 from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
+from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.timetables.simple import NullTimetable, OnceTimetable
 from airflow.utils import timezone
 from airflow.utils.context import Context
@@ -1236,6 +1237,64 @@ class TestStringifiedDAGs:
             'airflow.ti_deps.deps.trigger_rule_dep.TriggerRuleDep',
         ]
 
+    def test_error_on_unregistered_ti_dep_serialization(self):
+        # a ti dep that is not registered through the plugin system must fail serialization
+        class DummyTriggerRule(BaseTIDep):
+            pass
+
+        class DummyTask(BaseOperator):
+            deps = frozenset(list(BaseOperator.deps) + [DummyTriggerRule()])
+
+        execution_date = datetime(2020, 1, 1)
+        with DAG(dag_id="test_error_on_unregistered_ti_dep_serialization", start_date=execution_date) as dag:
+            DummyTask(task_id="task1")
+
+        with pytest.raises(SerializationError):
+            SerializedBaseOperator.serialize_operator(dag.task_dict["task1"])
+
+    def test_error_on_unregistered_ti_dep_deserialization(self):
+        from airflow.operators.dummy import DummyOperator
+
+        with DAG("test_error_on_unregistered_ti_dep_deserialization", start_date=datetime(2019, 8, 1)) as dag:
+            DummyOperator(task_id="task1")
+        serialize_op = SerializedBaseOperator.serialize_operator(dag.task_dict["task1"])
+        serialize_op['deps'] = [
+            'airflow.ti_deps.deps.not_in_retry_period_dep.NotInRetryPeriodDep',
+            # manually injected non-core ti dep that is not registered must raise SerializationError
+            'test_plugin.NotATriggerRule',
+        ]
+        with pytest.raises(SerializationError):
+            SerializedBaseOperator.deserialize_operator(serialize_op)
+
+    def test_serialize_and_deserialize_custom_ti_deps(self):
+        from test_plugin import CustomTestTriggerRule
+
+        class DummyTask(BaseOperator):
+            deps = frozenset(list(BaseOperator.deps) + [CustomTestTriggerRule()])
+
+        execution_date = datetime(2020, 1, 1)
+        with DAG(dag_id="test_serialize_custom_ti_deps", start_date=execution_date) as dag:
+            DummyTask(task_id="task1")
+
+        serialize_op = SerializedBaseOperator.serialize_operator(dag.task_dict["task1"])
+
+        assert serialize_op["deps"] == [
+            'airflow.ti_deps.deps.not_in_retry_period_dep.NotInRetryPeriodDep',
+            'airflow.ti_deps.deps.not_previously_skipped_dep.NotPreviouslySkippedDep',
+            'airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep',
+            'airflow.ti_deps.deps.trigger_rule_dep.TriggerRuleDep',
+            'test_plugin.CustomTestTriggerRule',
+        ]
+
+        op = SerializedBaseOperator.deserialize_operator(serialize_op)
+        assert sorted(str(dep) for dep in op.deps) == [
+            '<TIDep(CustomTestTriggerRule)>',
+            '<TIDep(Not In Retry Period)>',
+            '<TIDep(Not Previously Skipped)>',
+            '<TIDep(Previous Dagrun State)>',
+            '<TIDep(Trigger Rule)>',
+        ]
+
     def test_task_group_sorted(self):
         """
         Tests serialize_task_group, make sure the list is in order