You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text copy).
Posted to commits@airflow.apache.org by ur...@apache.org on 2023/08/04 04:40:51 UTC
[airflow] branch main updated: Ensure DAG-level references are filled on unmap (#33083)
This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new bcfadcf6e4 Ensure DAG-level references are filled on unmap (#33083)
bcfadcf6e4 is described below
commit bcfadcf6e4b2de587959594f54a9e8fef96c4a2b
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Fri Aug 4 12:40:43 2023 +0800
Ensure DAG-level references are filled on unmap (#33083)
Co-authored-by: Jed Cunningham <66...@users.noreply.github.com>
---
airflow/models/mappedoperator.py | 2 +
airflow/serialization/serialized_objects.py | 59 ++++++++++++++++++--------
tests/serialization/test_serialized_objects.py | 21 +++++++++
3 files changed, 64 insertions(+), 18 deletions(-)
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 0cf8852ea2..82dcc82aa0 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -659,6 +659,8 @@ class MappedOperator(AbstractOperator):
op = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True)
SerializedBaseOperator.populate_operator(op, self.operator_class)
+ if self.dag is not None: # For Mypy; we only serialize tasks in a DAG, so this check is always satisfied.
+ SerializedBaseOperator.set_task_dag_references(op, self.dag) # Fill op.dag and other DAG-level references on the unmapped operator.
return op
def _get_specified_expand_input(self) -> ExpandInput:
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 5efa3b3da5..d89f2e22d4 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -735,6 +735,13 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
All operators are casted to SerializedBaseOperator after deserialization.
Class specific attributes used by UI are move to object attributes.
+
+ Creating a SerializedBaseOperator is a three-step process:
+
+ 1. Instantiate a :class:`SerializedBaseOperator` object.
+ 2. Populate attributes with :meth:`SerializedBaseOperator.populate_operator`.
+ 3. When the task's containing DAG is available, fix references to the DAG
+ with :meth:`SerializedBaseOperator.set_task_dag_references`.
"""
_decorated_fields = {"executor_config"}
@@ -875,6 +882,13 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
@classmethod
def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None:
+ """Populate operator attributes with serialized values.
+
+ This covers simple attributes that don't reference other things in the
+ DAG. Setting references (such as ``op.dag`` and task dependencies) is
+ done in ``set_task_dag_references`` instead, which is called after the
+ DAG is hydrated.
+ """
if "label" not in encoded_op:
# Handle deserialization of old data before the introduction of TaskGroup
encoded_op["label"] = encoded_op["task_id"]
@@ -982,6 +996,32 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
# Used to determine if an Operator is inherited from EmptyOperator
setattr(op, "_is_empty", bool(encoded_op.get("_is_empty", False)))
+ @staticmethod
+ def set_task_dag_references(task: Operator, dag: DAG) -> None:
+ """Handle DAG references on an operator.
+
+ The operator should have been mostly populated earlier by calling
+ ``populate_operator``. This function further fixes object references
+ that were not possible before the task's containing DAG is hydrated.
+ """
+ task.dag = dag
+
+ for date_attr in ("start_date", "end_date"):
+ if getattr(task, date_attr, None) is None:
+ setattr(task, date_attr, getattr(dag, date_attr, None))
+
+ if task.subdag is not None:
+ task.subdag.parent_dag = dag
+
+ # Dereference expand_input and op_kwargs_expand_input.
+ for k in ("expand_input", "op_kwargs_expand_input"):
+ if isinstance(kwargs_ref := getattr(task, k, None), _ExpandInputRef):
+ setattr(task, k, kwargs_ref.deref(dag))
+
+ for task_id in task.downstream_task_ids:
+ # Bypass set_upstream etc here - it does more than we want
+ dag.task_dict[task_id].upstream_task_ids.add(task.task_id)
+
@classmethod
def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator:
"""Deserializes an operator from a JSON object."""
@@ -1328,24 +1368,7 @@ class SerializedDAG(DAG, BaseSerialization):
setattr(dag, k, None)
for task in dag.task_dict.values():
- task.dag = dag
-
- for date_attr in ["start_date", "end_date"]:
- if getattr(task, date_attr) is None:
- setattr(task, date_attr, getattr(dag, date_attr))
-
- if task.subdag is not None:
- setattr(task.subdag, "parent_dag", dag)
-
- # Dereference expand_input and op_kwargs_expand_input.
- for k in ("expand_input", "op_kwargs_expand_input"):
- kwargs_ref = getattr(task, k, None)
- if isinstance(kwargs_ref, _ExpandInputRef):
- setattr(task, k, kwargs_ref.deref(dag))
-
- for task_id in task.downstream_task_ids:
- # Bypass set_upstream etc here - it does more than we want
- dag.task_dict[task_id].upstream_task_ids.add(task.task_id)
+ SerializedBaseOperator.set_task_dag_references(task, dag)
return dag
diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py
index 1eb4214783..17f5187579 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -96,3 +96,24 @@ def test_use_pydantic_models():
deserialized = BaseSerialization.deserialize(serialized, use_pydantic_models=True) # does not raise
assert isinstance(deserialized[0][0], TaskInstancePydantic)
+
+
+def test_serialized_mapped_operator_unmap(dag_maker):
+ from airflow.serialization.serialized_objects import SerializedDAG
+ from tests.test_utils.mock_operators import MockOperator
+
+ with dag_maker(dag_id="dag") as dag:
+ MockOperator(task_id="task1", arg1="x")
+ MockOperator.partial(task_id="task2").expand(arg1=["a", "b"])
+
+ serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) # Round-trip the DAG through serialization.
+ assert serialized_dag.dag_id == "dag"
+
+ serialized_task1 = serialized_dag.get_task("task1")
+ assert serialized_task1.dag is serialized_dag
+
+ serialized_task2 = serialized_dag.get_task("task2")
+ assert serialized_task2.dag is serialized_dag
+
+ serialized_unmapped_task = serialized_task2.unmap(None) # Unmapping must keep the reference to the deserialized DAG.
+ assert serialized_unmapped_task.dag is serialized_dag