You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2022/04/12 08:38:36 UTC
[airflow] branch main updated: Call mapped_dependants only on the original task (#22904)
This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 30ac99773c Call mapped_dependants only on the original task (#22904)
30ac99773c is described below
commit 30ac99773c8577718c87703a310ffc454316cfce
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Tue Apr 12 16:38:17 2022 +0800
Call mapped_dependants only on the original task (#22904)
* Add literal expands in test DAGs
* Call mapped_dependants only on the original task
We've already made this change in the scheduler, but need to match it in
the BackfillJob.
---
airflow/jobs/backfill_job.py | 11 +++++------
airflow/models/taskmixin.py | 3 ++-
tests/dags/test_mapped_classic.py | 3 +++
tests/dags/test_mapped_taskflow.py | 1 +
4 files changed, 11 insertions(+), 7 deletions(-)
diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py
index 0334a4f20c..840db35692 100644
--- a/airflow/jobs/backfill_job.py
+++ b/airflow/jobs/backfill_job.py
@@ -102,7 +102,7 @@ class BackfillJob(BaseJob):
def __init__(
self,
- dag,
+ dag: DAG,
start_date=None,
end_date=None,
mark_success=False,
@@ -238,8 +238,6 @@ class BackfillJob(BaseJob):
:param running: dict of key, task to verify
:return: An iterable of expanded TaskInstance per MappedTask
"""
- from airflow.models.mappedoperator import MappedOperator
-
executor = self.executor
# TODO: query all instead of refresh from db
@@ -266,8 +264,9 @@ class BackfillJob(BaseJob):
ti.handle_failure_with_callback(error=msg)
continue
if ti.state not in self.STATES_COUNT_AS_RUNNING:
- for node in ti.task.mapped_dependants():
- assert isinstance(node, MappedOperator)
+ # Don't use ti.task; if this task is mapped, that attribute
+ # would hold the unmapped task. We need the original task here.
+ for node in self.dag.get_task(ti.task_id, include_subdags=True).mapped_dependants():
yield node, ti.run_id, node.expand_mapped_task(ti.run_id, session=session)
@provide_session
@@ -702,7 +701,7 @@ class BackfillJob(BaseJob):
return err
- def _get_dag_with_subdags(self):
+ def _get_dag_with_subdags(self) -> List[DAG]:
return [self.dag] + self.dag.subdags
@provide_session
diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index 73221623d8..c5d6165e8d 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -28,6 +28,7 @@ if TYPE_CHECKING:
from logging import Logger
from airflow.models.dag import DAG
+ from airflow.models.mappedoperator import MappedOperator
from airflow.utils.edgemodifier import EdgeModifier
from airflow.utils.task_group import TaskGroup
@@ -290,7 +291,7 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
"""This is used by SerializedTaskGroup to serialize a task group's content."""
raise NotImplementedError()
- def mapped_dependants(self) -> Iterator["DAGNode"]:
+ def mapped_dependants(self) -> Iterator["MappedOperator"]:
"""Return any mapped nodes that are direct dependencies of the current task
For now, this walks the entire DAG to find mapped nodes that have this
diff --git a/tests/dags/test_mapped_classic.py b/tests/dags/test_mapped_classic.py
index 3880cc74fc..cbf3a8a5b8 100644
--- a/tests/dags/test_mapped_classic.py
+++ b/tests/dags/test_mapped_classic.py
@@ -32,3 +32,6 @@ def consumer(value):
with DAG(dag_id='test_mapped_classic', start_date=days_ago(2)) as dag:
PythonOperator.partial(task_id='consumer', python_callable=consumer).expand(op_args=make_arg_lists())
+ PythonOperator.partial(task_id='consumer_literal', python_callable=consumer).expand(
+ op_args=[[1], [2], [3]],
+ )
diff --git a/tests/dags/test_mapped_taskflow.py b/tests/dags/test_mapped_taskflow.py
index 34f6ae3d72..e4e796c3e4 100644
--- a/tests/dags/test_mapped_taskflow.py
+++ b/tests/dags/test_mapped_taskflow.py
@@ -29,3 +29,4 @@ with DAG(dag_id='test_mapped_taskflow', start_date=days_ago(2)) as dag:
print(repr(value))
consumer.expand(value=make_list())
+ consumer.expand(value=[1, 2, 3])