You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ds...@apache.org on 2022/08/05 01:14:40 UTC

[airflow] branch main updated: Fix dag dependencies detection (#25521)

This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 22fb1dff08 Fix dag dependencies detection (#25521)
22fb1dff08 is described below

commit 22fb1dff08828f80f9a60b3786121b897f601f83
Author: Daniel Standish <15...@users.noreply.github.com>
AuthorDate: Thu Aug 4 18:14:26 2022 -0700

    Fix dag dependencies detection (#25521)
    
    Fix issue where we were creating duplicate dependency elements in the dag_dependencies node of the serialized DAG.
    
    Also added some test coverage for the "dependency detector".
---
 airflow/serialization/serialized_objects.py   |  54 +++++-----
 tests/serialization/test_dag_serialization.py | 148 +++++++++++++++++++++++++-
 2 files changed, 175 insertions(+), 27 deletions(-)

diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 3ace47e2b0..ddbfcce868 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -863,20 +863,30 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
     @classmethod
     def detect_dependencies(cls, op: Operator) -> Set['DagDependency']:
         """Detects between DAG dependencies for the operator."""
+
+        def get_custom_dep() -> List[DagDependency]:
+            """
+            If custom dependency detector is configured, use it.
+
+            TODO: Remove this logic in 3.0.
+            """
+            custom_dependency_detector_cls = conf.getimport('scheduler', 'dependency_detector', fallback=None)
+            if not (
+                custom_dependency_detector_cls is None or custom_dependency_detector_cls is DependencyDetector
+            ):
+                warnings.warn(
+                    "Use of a custom dependency detector is deprecated. "
+                    "Support will be removed in a future release.",
+                    DeprecationWarning,
+                )
+                dep = custom_dependency_detector_cls().detect_task_dependencies(op)
+                if type(dep) is DagDependency:
+                    return [dep]
+            return []
+
         dependency_detector = DependencyDetector()
-        custom_dependency_detector = conf.getimport('scheduler', 'dependency_detector', fallback=None)
-        deps = set()
-        if not (custom_dependency_detector is None or type(dependency_detector) is DependencyDetector):
-            warnings.warn(
-                "Use of a custom dependency detector is deprecated. "
-                "Support will be removed in a future release.",
-                DeprecationWarning,
-            )
-            dep = custom_dependency_detector.detect_task_dependencies(op)
-            if type(dep) is DagDependency:
-                deps.add(dep)
-        deps.update(dependency_detector.detect_task_dependencies(op))
-        deps.update(dependency_detector.detect_dag_dependencies(op.dag))
+        deps = set(dependency_detector.detect_task_dependencies(op))
+        deps.update(get_custom_dep())  # todo: remove in 3.0
         return deps
 
     @classmethod
@@ -1048,14 +1058,13 @@ class SerializedDAG(DAG, BaseSerialization):
                 del serialized_dag["timetable"]
 
             serialized_dag["tasks"] = [cls._serialize(task) for _, task in dag.task_dict.items()]
-            dag_deps = [
-                t.__dict__
+            dag_deps = {
+                dep
                 for task in dag.task_dict.values()
-                for t in SerializedBaseOperator.detect_dependencies(task)
-                if t is not None
-            ]
-
-            serialized_dag["dag_dependencies"] = dag_deps
+                for dep in SerializedBaseOperator.detect_dependencies(task)
+            }
+            dag_deps.update(DependencyDetector().detect_dag_dependencies(dag))
+            serialized_dag["dag_dependencies"] = [x.__dict__ for x in dag_deps]
             serialized_dag['_task_group'] = SerializedTaskGroup.serialize_task_group(dag.task_group)
 
             # Edge info in the JSON exactly matches our internal structure
@@ -1240,7 +1249,7 @@ class SerializedTaskGroup(TaskGroup, BaseSerialization):
         return group
 
 
-@dataclass
+@dataclass(frozen=True)
 class DagDependency:
     """Dataclass for representing dependencies between DAGs.
     These are calculated during serialization and attached to serialized DAGs.
@@ -1261,9 +1270,6 @@ class DagDependency:
             val += f":{self.dependency_id}"
         return val
 
-    def __hash__(self):
-        return hash((self.source, self.target, self.dependency_type, self.dependency_id))
-
 
 def _has_kubernetes() -> bool:
     global HAS_KUBERNETES
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 7158851291..42b9ca6d4c 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -27,6 +27,7 @@ import os
 import pickle
 from datetime import datetime, timedelta
 from glob import glob
+from typing import Optional
 from unittest import mock
 
 import pendulum
@@ -37,7 +38,7 @@ from kubernetes.client import models as k8s
 from airflow.exceptions import SerializationError
 from airflow.hooks.base import BaseHook
 from airflow.kubernetes.pod_generator import PodGenerator
-from airflow.models import DAG, Connection, DagBag
+from airflow.models import DAG, Connection, DagBag, Dataset, Operator
 from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
 from airflow.models.mappedoperator import MappedOperator
 from airflow.models.param import Param, ParamsDict
@@ -45,16 +46,53 @@ from airflow.models.xcom import XCOM_RETURN_KEY, XCom
 from airflow.operators.bash import BashOperator
 from airflow.security import permissions
 from airflow.serialization.json_schema import load_dag_schema_dict
-from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
+from airflow.serialization.serialized_objects import (
+    DagDependency,
+    DependencyDetector,
+    SerializedBaseOperator,
+    SerializedDAG,
+)
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.timetables.simple import NullTimetable, OnceTimetable
 from airflow.utils import timezone
 from airflow.utils.context import Context
 from airflow.utils.operator_resources import Resources
 from airflow.utils.task_group import TaskGroup
+from tests.test_utils.config import conf_vars
 from tests.test_utils.mock_operators import CustomOperator, GoogleLink, MockOperator
 from tests.test_utils.timetables import CustomSerializationTimetable, cron_timetable, delta_timetable
 
+
+class CustomDepOperator(BashOperator):
+    """
+    Used for testing custom dependency detector.
+
+    TODO: remove in Airflow 3.0
+    """
+
+
+class CustomDependencyDetector(DependencyDetector):
+    """
+    Prior to deprecation of custom dependency detector, the return type was Optional[DagDependency].
+    This class verifies that custom dependency detector classes which assume that return type will still
+    work until support for them is removed in 3.0.
+
+    TODO: remove in Airflow 3.0
+    """
+
+    @staticmethod
+    def detect_task_dependencies(task: Operator) -> Optional[DagDependency]:  # type: ignore
+        if isinstance(task, CustomDepOperator):
+            return DagDependency(
+                source=task.dag_id,
+                target='nothing',
+                dependency_type='abc',
+                dependency_id=task.task_id,
+            )
+        else:
+            return DependencyDetector().detect_task_dependencies(task)  # type: ignore
+
+
 executor_config_pod = k8s.V1Pod(
     metadata=k8s.V1ObjectMeta(name="my-name"),
     spec=k8s.V1PodSpec(
@@ -1032,7 +1070,7 @@ class TestStringifiedDAGs:
         """
         dag_schema: dict = load_dag_schema_dict()["definitions"]["dag"]["properties"]
 
-        # The parameters we add manually in Serialization needs to be ignored
+        # The parameters we add manually in Serialization need to be ignored
         ignored_keys: set = {
             "is_subdag",
             "tasks",
@@ -1329,6 +1367,110 @@ class TestStringifiedDAGs:
                 }
             ]
 
+    @conf_vars(
+        {
+            (
+                'scheduler',
+                'dependency_detector',
+            ): 'tests.serialization.test_dag_serialization.CustomDependencyDetector'
+        }
+    )
+    def test_custom_dep_detector(self):
+        """
+        Prior to deprecation of custom dependency detector, the return type was Optional[DagDependency].
+        This test verifies that custom dependency detector classes which assume that return type will still
+        work until support for them is removed in 3.0.
+
+        TODO: remove in Airflow 3.0
+        """
+        from airflow.sensors.external_task import ExternalTaskSensor
+
+        execution_date = datetime(2020, 1, 1)
+        with DAG(dag_id="test", start_date=execution_date) as dag:
+            ExternalTaskSensor(
+                task_id="task1",
+                external_dag_id="external_dag_id",
+                mode="reschedule",
+            )
+            CustomDepOperator(task_id='hello', bash_command='hi')
+            dag = SerializedDAG.to_dict(dag)
+            assert sorted(dag['dag']['dag_dependencies'], key=lambda x: tuple(x.values())) == sorted(
+                [
+                    {
+                        'source': 'external_dag_id',
+                        'target': 'test',
+                        'dependency_type': 'sensor',
+                        'dependency_id': 'task1',
+                    },
+                    {
+                        'source': 'test',
+                        'target': 'nothing',
+                        'dependency_type': 'abc',
+                        'dependency_id': 'hello',
+                    },
+                ],
+                key=lambda x: tuple(x.values()),
+            )
+
+    def test_dag_deps_datasets(self):
+        """
+        Check that dag_dependencies node is populated correctly for a DAG with datasets.
+        """
+        from airflow.sensors.external_task import ExternalTaskSensor
+
+        d1 = Dataset('d1')
+        d2 = Dataset('d2')
+        d3 = Dataset('d3')
+        d4 = Dataset('d4')
+        execution_date = datetime(2020, 1, 1)
+        with DAG(dag_id="test", start_date=execution_date, schedule_on=[d1]) as dag:
+            ExternalTaskSensor(
+                task_id="task1",
+                external_dag_id="external_dag_id",
+                mode="reschedule",
+            )
+            BashOperator(task_id='dataset_writer', bash_command="echo hello", outlets=[d2, d3])
+            BashOperator(task_id='other_dataset_writer', bash_command="echo hello", outlets=[d4])
+
+        dag = SerializedDAG.to_dict(dag)
+        actual = sorted(dag['dag']['dag_dependencies'], key=lambda x: tuple(x.values()))
+        expected = sorted(
+            [
+                {
+                    'source': 'test',
+                    'target': 'dataset',
+                    'dependency_type': 'dataset',
+                    'dependency_id': 'd4',
+                },
+                {
+                    'source': 'external_dag_id',
+                    'target': 'test',
+                    'dependency_type': 'sensor',
+                    'dependency_id': 'task1',
+                },
+                {
+                    'source': 'test',
+                    'target': 'dataset',
+                    'dependency_type': 'dataset',
+                    'dependency_id': 'd3',
+                },
+                {
+                    'source': 'test',
+                    'target': 'dataset',
+                    'dependency_type': 'dataset',
+                    'dependency_id': 'd2',
+                },
+                {
+                    'source': 'dataset',
+                    'target': 'test',
+                    'dependency_type': 'dataset',
+                    'dependency_id': 'd1',
+                },
+            ],
+            key=lambda x: tuple(x.values()),
+        )
+        assert actual == expected
+
     def test_derived_dag_deps_operator(self):
         """
         Tests DAG dependency detection for operators, including derived classes