You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ur...@apache.org on 2023/08/04 04:40:51 UTC

[airflow] branch main updated: Ensure DAG-level references are filled on unmap (#33083)

This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new bcfadcf6e4 Ensure DAG-level references are filled on unmap (#33083)
bcfadcf6e4 is described below

commit bcfadcf6e4b2de587959594f54a9e8fef96c4a2b
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Fri Aug 4 12:40:43 2023 +0800

    Ensure DAG-level references are filled on unmap (#33083)
    
    Co-authored-by: Jed Cunningham <66...@users.noreply.github.com>
---
 airflow/models/mappedoperator.py               |  2 +
 airflow/serialization/serialized_objects.py    | 59 ++++++++++++++++++--------
 tests/serialization/test_serialized_objects.py | 21 +++++++++
 3 files changed, 64 insertions(+), 18 deletions(-)

diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 0cf8852ea2..82dcc82aa0 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -659,6 +659,8 @@ class MappedOperator(AbstractOperator):
 
         op = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True)
         SerializedBaseOperator.populate_operator(op, self.operator_class)
+        if self.dag is not None:  # For Mypy; we only serialize tasks in a DAG so the check always satisfies.
+            SerializedBaseOperator.set_task_dag_references(op, self.dag)
         return op
 
     def _get_specified_expand_input(self) -> ExpandInput:
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 5efa3b3da5..d89f2e22d4 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -735,6 +735,13 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
 
     All operators are casted to SerializedBaseOperator after deserialization.
     Class specific attributes used by UI are move to object attributes.
+
+    Creating a SerializedBaseOperator is a three-step process:
+
+    1. Instantiate a :class:`SerializedBaseOperator` object.
+    2. Populate attributes with :func:`SerializedBaseOperator.populate_operator`.
+    3. When the task's containing DAG is available, fix references to the DAG
+       with :func:`SerializedBaseOperator.set_task_dag_references`.
     """
 
     _decorated_fields = {"executor_config"}
@@ -875,6 +882,13 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
 
     @classmethod
     def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None:
+        """Populate operator attributes with serialized values.
+
+        This covers simple attributes that don't reference other things in the
+        DAG. Setting references (such as ``op.dag`` and task dependencies) is
+        done in ``set_task_dag_references`` instead, which is called after the
+        DAG is hydrated.
+        """
         if "label" not in encoded_op:
             # Handle deserialization of old data before the introduction of TaskGroup
             encoded_op["label"] = encoded_op["task_id"]
@@ -982,6 +996,32 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
         # Used to determine if an Operator is inherited from EmptyOperator
         setattr(op, "_is_empty", bool(encoded_op.get("_is_empty", False)))
 
+    @staticmethod
+    def set_task_dag_references(task: Operator, dag: DAG) -> None:
+        """Handle DAG references on an operator.
+
+        The operator should have been mostly populated earlier by calling
+        ``populate_operator``. This function further fixes object references
+        that could not be resolved before the task's containing DAG was hydrated.
+        """
+        task.dag = dag
+
+        for date_attr in ("start_date", "end_date"):
+            if getattr(task, date_attr, None) is None:
+                setattr(task, date_attr, getattr(dag, date_attr, None))
+
+        if task.subdag is not None:
+            task.subdag.parent_dag = dag
+
+        # Dereference expand_input and op_kwargs_expand_input.
+        for k in ("expand_input", "op_kwargs_expand_input"):
+            if isinstance(kwargs_ref := getattr(task, k, None), _ExpandInputRef):
+                setattr(task, k, kwargs_ref.deref(dag))
+
+        for task_id in task.downstream_task_ids:
+            # Bypass set_upstream etc here - it does more than we want
+            dag.task_dict[task_id].upstream_task_ids.add(task.task_id)
+
     @classmethod
     def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator:
         """Deserializes an operator from a JSON object."""
@@ -1328,24 +1368,7 @@ class SerializedDAG(DAG, BaseSerialization):
             setattr(dag, k, None)
 
         for task in dag.task_dict.values():
-            task.dag = dag
-
-            for date_attr in ["start_date", "end_date"]:
-                if getattr(task, date_attr) is None:
-                    setattr(task, date_attr, getattr(dag, date_attr))
-
-            if task.subdag is not None:
-                setattr(task.subdag, "parent_dag", dag)
-
-            # Dereference expand_input and op_kwargs_expand_input.
-            for k in ("expand_input", "op_kwargs_expand_input"):
-                kwargs_ref = getattr(task, k, None)
-                if isinstance(kwargs_ref, _ExpandInputRef):
-                    setattr(task, k, kwargs_ref.deref(dag))
-
-            for task_id in task.downstream_task_ids:
-                # Bypass set_upstream etc here - it does more than we want
-                dag.task_dict[task_id].upstream_task_ids.add(task.task_id)
+            SerializedBaseOperator.set_task_dag_references(task, dag)
 
         return dag
 
diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py
index 1eb4214783..17f5187579 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -96,3 +96,24 @@ def test_use_pydantic_models():
     deserialized = BaseSerialization.deserialize(serialized, use_pydantic_models=True)  # does not raise
 
     assert isinstance(deserialized[0][0], TaskInstancePydantic)
+
+
+def test_serialized_mapped_operator_unmap(dag_maker):
+    from airflow.serialization.serialized_objects import SerializedDAG
+    from tests.test_utils.mock_operators import MockOperator
+
+    with dag_maker(dag_id="dag") as dag:
+        MockOperator(task_id="task1", arg1="x")
+        MockOperator.partial(task_id="task2").expand(arg1=["a", "b"])
+
+    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+    assert serialized_dag.dag_id == "dag"
+
+    serialized_task1 = serialized_dag.get_task("task1")
+    assert serialized_task1.dag is serialized_dag
+
+    serialized_task2 = serialized_dag.get_task("task2")
+    assert serialized_task2.dag is serialized_dag
+
+    serialized_unmapped_task = serialized_task2.unmap(None)
+    assert serialized_unmapped_task.dag is serialized_dag