You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2022/04/12 08:38:36 UTC

[airflow] branch main updated: Call mapped_dependants only on the original task (#22904)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 30ac99773c Call mapped_dependants only on the original task (#22904)
30ac99773c is described below

commit 30ac99773c8577718c87703a310ffc454316cfce
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Tue Apr 12 16:38:17 2022 +0800

    Call mapped_dependants only on the original task (#22904)
    
    * Add literal expands in test DAGs
    
    * Call mapped_dependants only on the original task
    
    We already made this change in the scheduler, but need to match it in
    the BackfillJob.
---
 airflow/jobs/backfill_job.py       | 11 +++++------
 airflow/models/taskmixin.py        |  3 ++-
 tests/dags/test_mapped_classic.py  |  3 +++
 tests/dags/test_mapped_taskflow.py |  1 +
 4 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py
index 0334a4f20c..840db35692 100644
--- a/airflow/jobs/backfill_job.py
+++ b/airflow/jobs/backfill_job.py
@@ -102,7 +102,7 @@ class BackfillJob(BaseJob):
 
     def __init__(
         self,
-        dag,
+        dag: DAG,
         start_date=None,
         end_date=None,
         mark_success=False,
@@ -238,8 +238,6 @@ class BackfillJob(BaseJob):
         :param running: dict of key, task to verify
         :return: An iterable of expanded TaskInstance per MappedTask
         """
-        from airflow.models.mappedoperator import MappedOperator
-
         executor = self.executor
 
         # TODO: query all instead of refresh from db
@@ -266,8 +264,9 @@ class BackfillJob(BaseJob):
                 ti.handle_failure_with_callback(error=msg)
                 continue
             if ti.state not in self.STATES_COUNT_AS_RUNNING:
-                for node in ti.task.mapped_dependants():
-                    assert isinstance(node, MappedOperator)
+                # Don't use ti.task; if this task is mapped, that attribute
+                # would hold the unmapped task. We need the original task here.
+                for node in self.dag.get_task(ti.task_id, include_subdags=True).mapped_dependants():
                     yield node, ti.run_id, node.expand_mapped_task(ti.run_id, session=session)
 
     @provide_session
@@ -702,7 +701,7 @@ class BackfillJob(BaseJob):
 
         return err
 
-    def _get_dag_with_subdags(self):
+    def _get_dag_with_subdags(self) -> List[DAG]:
         return [self.dag] + self.dag.subdags
 
     @provide_session
diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index 73221623d8..c5d6165e8d 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -28,6 +28,7 @@ if TYPE_CHECKING:
     from logging import Logger
 
     from airflow.models.dag import DAG
+    from airflow.models.mappedoperator import MappedOperator
     from airflow.utils.edgemodifier import EdgeModifier
     from airflow.utils.task_group import TaskGroup
 
@@ -290,7 +291,7 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
         """This is used by SerializedTaskGroup to serialize a task group's content."""
         raise NotImplementedError()
 
-    def mapped_dependants(self) -> Iterator["DAGNode"]:
+    def mapped_dependants(self) -> Iterator["MappedOperator"]:
         """Return any mapped nodes that are direct dependencies of the current task
 
         For now, this walks the entire DAG to find mapped nodes that has this
diff --git a/tests/dags/test_mapped_classic.py b/tests/dags/test_mapped_classic.py
index 3880cc74fc..cbf3a8a5b8 100644
--- a/tests/dags/test_mapped_classic.py
+++ b/tests/dags/test_mapped_classic.py
@@ -32,3 +32,6 @@ def consumer(value):
 
 with DAG(dag_id='test_mapped_classic', start_date=days_ago(2)) as dag:
     PythonOperator.partial(task_id='consumer', python_callable=consumer).expand(op_args=make_arg_lists())
+    PythonOperator.partial(task_id='consumer_literal', python_callable=consumer).expand(
+        op_args=[[1], [2], [3]],
+    )
diff --git a/tests/dags/test_mapped_taskflow.py b/tests/dags/test_mapped_taskflow.py
index 34f6ae3d72..e4e796c3e4 100644
--- a/tests/dags/test_mapped_taskflow.py
+++ b/tests/dags/test_mapped_taskflow.py
@@ -29,3 +29,4 @@ with DAG(dag_id='test_mapped_taskflow', start_date=days_ago(2)) as dag:
         print(repr(value))
 
     consumer.expand(value=make_list())
+    consumer.expand(value=[1, 2, 3])