You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ds...@apache.org on 2022/08/05 01:14:40 UTC
[airflow] branch main updated: Fix dag dependencies detection (#25521)
This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 22fb1dff08 Fix dag dependencies detection (#25521)
22fb1dff08 is described below
commit 22fb1dff08828f80f9a60b3786121b897f601f83
Author: Daniel Standish <15...@users.noreply.github.com>
AuthorDate: Thu Aug 4 18:14:26 2022 -0700
Fix dag dependencies detection (#25521)
Fix issue where we were creating duplicate deps elements in the dag_dependencies node in serialized dag.
Also added some test coverage for the "dependency detector".
---
airflow/serialization/serialized_objects.py | 54 +++++-----
tests/serialization/test_dag_serialization.py | 148 +++++++++++++++++++++++++-
2 files changed, 175 insertions(+), 27 deletions(-)
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 3ace47e2b0..ddbfcce868 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -863,20 +863,30 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
@classmethod
def detect_dependencies(cls, op: Operator) -> Set['DagDependency']:
"""Detects between DAG dependencies for the operator."""
+
+ def get_custom_dep() -> List[DagDependency]:
+ """
+ If custom dependency detector is configured, use it.
+
+ TODO: Remove this logic in 3.0.
+ """
+ custom_dependency_detector_cls = conf.getimport('scheduler', 'dependency_detector', fallback=None)
+ if not (
+ custom_dependency_detector_cls is None or custom_dependency_detector_cls is DependencyDetector
+ ):
+ warnings.warn(
+ "Use of a custom dependency detector is deprecated. "
+ "Support will be removed in a future release.",
+ DeprecationWarning,
+ )
+ dep = custom_dependency_detector_cls().detect_task_dependencies(op)
+ if type(dep) is DagDependency:
+ return [dep]
+ return []
+
dependency_detector = DependencyDetector()
- custom_dependency_detector = conf.getimport('scheduler', 'dependency_detector', fallback=None)
- deps = set()
- if not (custom_dependency_detector is None or type(dependency_detector) is DependencyDetector):
- warnings.warn(
- "Use of a custom dependency detector is deprecated. "
- "Support will be removed in a future release.",
- DeprecationWarning,
- )
- dep = custom_dependency_detector.detect_task_dependencies(op)
- if type(dep) is DagDependency:
- deps.add(dep)
- deps.update(dependency_detector.detect_task_dependencies(op))
- deps.update(dependency_detector.detect_dag_dependencies(op.dag))
+ deps = set(dependency_detector.detect_task_dependencies(op))
+ deps.update(get_custom_dep()) # todo: remove in 3.0
return deps
@classmethod
@@ -1048,14 +1058,13 @@ class SerializedDAG(DAG, BaseSerialization):
del serialized_dag["timetable"]
serialized_dag["tasks"] = [cls._serialize(task) for _, task in dag.task_dict.items()]
- dag_deps = [
- t.__dict__
+ dag_deps = {
+ dep
for task in dag.task_dict.values()
- for t in SerializedBaseOperator.detect_dependencies(task)
- if t is not None
- ]
-
- serialized_dag["dag_dependencies"] = dag_deps
+ for dep in SerializedBaseOperator.detect_dependencies(task)
+ }
+ dag_deps.update(DependencyDetector().detect_dag_dependencies(dag))
+ serialized_dag["dag_dependencies"] = [x.__dict__ for x in dag_deps]
serialized_dag['_task_group'] = SerializedTaskGroup.serialize_task_group(dag.task_group)
# Edge info in the JSON exactly matches our internal structure
@@ -1240,7 +1249,7 @@ class SerializedTaskGroup(TaskGroup, BaseSerialization):
return group
-@dataclass
+@dataclass(frozen=True)
class DagDependency:
"""Dataclass for representing dependencies between DAGs.
These are calculated during serialization and attached to serialized DAGs.
@@ -1261,9 +1270,6 @@ class DagDependency:
val += f":{self.dependency_id}"
return val
- def __hash__(self):
- return hash((self.source, self.target, self.dependency_type, self.dependency_id))
-
def _has_kubernetes() -> bool:
global HAS_KUBERNETES
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 7158851291..42b9ca6d4c 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -27,6 +27,7 @@ import os
import pickle
from datetime import datetime, timedelta
from glob import glob
+from typing import Optional
from unittest import mock
import pendulum
@@ -37,7 +38,7 @@ from kubernetes.client import models as k8s
from airflow.exceptions import SerializationError
from airflow.hooks.base import BaseHook
from airflow.kubernetes.pod_generator import PodGenerator
-from airflow.models import DAG, Connection, DagBag
+from airflow.models import DAG, Connection, DagBag, Dataset, Operator
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import Param, ParamsDict
@@ -45,16 +46,53 @@ from airflow.models.xcom import XCOM_RETURN_KEY, XCom
from airflow.operators.bash import BashOperator
from airflow.security import permissions
from airflow.serialization.json_schema import load_dag_schema_dict
-from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
+from airflow.serialization.serialized_objects import (
+ DagDependency,
+ DependencyDetector,
+ SerializedBaseOperator,
+ SerializedDAG,
+)
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.timetables.simple import NullTimetable, OnceTimetable
from airflow.utils import timezone
from airflow.utils.context import Context
from airflow.utils.operator_resources import Resources
from airflow.utils.task_group import TaskGroup
+from tests.test_utils.config import conf_vars
from tests.test_utils.mock_operators import CustomOperator, GoogleLink, MockOperator
from tests.test_utils.timetables import CustomSerializationTimetable, cron_timetable, delta_timetable
+
+class CustomDepOperator(BashOperator):
+ """
+ Used for testing custom dependency detector.
+
+ TODO: remove in Airflow 3.0
+ """
+
+
+class CustomDependencyDetector(DependencyDetector):
+ """
+ Prior to deprecation of custom dependency detector, the return type as Optional[DagDependency].
+ This class verifies that custom dependency detector classes which assume that return type will still
+ work until support for them is removed in 3.0.
+
+ TODO: remove in Airflow 3.0
+ """
+
+ @staticmethod
+ def detect_task_dependencies(task: Operator) -> Optional[DagDependency]: # type: ignore
+ if isinstance(task, CustomDepOperator):
+ return DagDependency(
+ source=task.dag_id,
+ target='nothing',
+ dependency_type='abc',
+ dependency_id=task.task_id,
+ )
+ else:
+ return DependencyDetector().detect_task_dependencies(task) # type: ignore
+
+
executor_config_pod = k8s.V1Pod(
metadata=k8s.V1ObjectMeta(name="my-name"),
spec=k8s.V1PodSpec(
@@ -1032,7 +1070,7 @@ class TestStringifiedDAGs:
"""
dag_schema: dict = load_dag_schema_dict()["definitions"]["dag"]["properties"]
- # The parameters we add manually in Serialization needs to be ignored
+ # The parameters we add manually in Serialization need to be ignored
ignored_keys: set = {
"is_subdag",
"tasks",
@@ -1329,6 +1367,110 @@ class TestStringifiedDAGs:
}
]
+ @conf_vars(
+ {
+ (
+ 'scheduler',
+ 'dependency_detector',
+ ): 'tests.serialization.test_dag_serialization.CustomDependencyDetector'
+ }
+ )
+ def test_custom_dep_detector(self):
+ """
+ Prior to deprecation of custom dependency detector, the return type was Optional[DagDependency].
+ This class verifies that custom dependency detector classes which assume that return type will still
+ work until support for them is removed in 3.0.
+
+ TODO: remove in Airflow 3.0
+ """
+ from airflow.sensors.external_task import ExternalTaskSensor
+
+ execution_date = datetime(2020, 1, 1)
+ with DAG(dag_id="test", start_date=execution_date) as dag:
+ ExternalTaskSensor(
+ task_id="task1",
+ external_dag_id="external_dag_id",
+ mode="reschedule",
+ )
+ CustomDepOperator(task_id='hello', bash_command='hi')
+ dag = SerializedDAG.to_dict(dag)
+ assert sorted(dag['dag']['dag_dependencies'], key=lambda x: tuple(x.values())) == sorted(
+ [
+ {
+ 'source': 'external_dag_id',
+ 'target': 'test',
+ 'dependency_type': 'sensor',
+ 'dependency_id': 'task1',
+ },
+ {
+ 'source': 'test',
+ 'target': 'nothing',
+ 'dependency_type': 'abc',
+ 'dependency_id': 'hello',
+ },
+ ],
+ key=lambda x: tuple(x.values()),
+ )
+
+ def test_dag_deps_datasets(self):
+ """
+ Check that dag_dependencies node is populated correctly for a DAG with datasets.
+ """
+ from airflow.sensors.external_task import ExternalTaskSensor
+
+ d1 = Dataset('d1')
+ d2 = Dataset('d2')
+ d3 = Dataset('d3')
+ d4 = Dataset('d4')
+ execution_date = datetime(2020, 1, 1)
+ with DAG(dag_id="test", start_date=execution_date, schedule_on=[d1]) as dag:
+ ExternalTaskSensor(
+ task_id="task1",
+ external_dag_id="external_dag_id",
+ mode="reschedule",
+ )
+ BashOperator(task_id='dataset_writer', bash_command="echo hello", outlets=[d2, d3])
+ BashOperator(task_id='other_dataset_writer', bash_command="echo hello", outlets=[d4])
+
+ dag = SerializedDAG.to_dict(dag)
+ actual = sorted(dag['dag']['dag_dependencies'], key=lambda x: tuple(x.values()))
+ expected = sorted(
+ [
+ {
+ 'source': 'test',
+ 'target': 'dataset',
+ 'dependency_type': 'dataset',
+ 'dependency_id': 'd4',
+ },
+ {
+ 'source': 'external_dag_id',
+ 'target': 'test',
+ 'dependency_type': 'sensor',
+ 'dependency_id': 'task1',
+ },
+ {
+ 'source': 'test',
+ 'target': 'dataset',
+ 'dependency_type': 'dataset',
+ 'dependency_id': 'd3',
+ },
+ {
+ 'source': 'test',
+ 'target': 'dataset',
+ 'dependency_type': 'dataset',
+ 'dependency_id': 'd2',
+ },
+ {
+ 'source': 'dataset',
+ 'target': 'test',
+ 'dependency_type': 'dataset',
+ 'dependency_id': 'd1',
+ },
+ ],
+ key=lambda x: tuple(x.values()),
+ )
+ assert actual == expected
+
def test_derived_dag_deps_operator(self):
"""
Tests DAG dependency detection for operators, including derived classes