You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2022/04/19 16:25:51 UTC
[airflow] branch main updated: Improve speed of `dag.partial_subset` by not deep-copying TaskGroup (#23088)
This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5fca11ef85 Improve speed of `dag.partial_subset` by not deep-copying TaskGroup (#23088)
5fca11ef85 is described below
commit 5fca11ef856e5f1a451cce726a06b537ea979649
Author: Ash Berlin-Taylor <as...@apache.org>
AuthorDate: Tue Apr 19 17:25:36 2022 +0100
Improve speed of `dag.partial_subset` by not deep-copying TaskGroup (#23088)
Deep-copying each task's TaskGroup resulted in the _entire_ DAG being copied over and over again.
For a DAG with 500 tasks, this change brings the time taken by this function down from
60s(!) to just over 1s.
---
airflow/models/dag.py | 16 ++++++++++------
1 file changed, 10 insertions(+), 6 deletions(-)
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 418bd6e2ce..4013f40bdc 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -26,6 +26,7 @@ import re
import sys
import traceback
import warnings
+import weakref
from datetime import datetime, timedelta
from inspect import signature
from typing import (
@@ -1989,11 +1990,13 @@ class DAG(LoggingMixin):
also_include.extend(upstream)
# Compiling the unique list of tasks that made the cut
- # Make sure to not recursively deepcopy the dag while copying the task
- dag.task_dict = {
- t.task_id: copy.deepcopy(t, {id(t.dag): dag}) # type: ignore
- for t in matched_tasks + also_include
- }
+ # Make sure to not recursively deepcopy the dag or task_group while copying the task.
+ # task_group is reset later
+ def _deepcopy_task(t) -> "Operator":
+ memo.setdefault(id(t.task_group), None)
+ return copy.deepcopy(t, memo)
+
+ dag.task_dict = {t.task_id: _deepcopy_task(t) for t in matched_tasks + also_include}
def filter_task_group(group, parent_group):
"""Exclude tasks not included in the subdag from the given TaskGroup."""
@@ -2006,7 +2009,8 @@ class DAG(LoggingMixin):
for child in group.children.values():
if isinstance(child, AbstractOperator):
if child.task_id in dag.task_dict:
- copied.children[child.task_id] = dag.task_dict[child.task_id]
+ task = copied.children[child.task_id] = dag.task_dict[child.task_id]
+ task.task_group = weakref.proxy(copied)
else:
copied.used_group_ids.discard(child.task_id)
else: