You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2022/04/19 16:25:51 UTC

[airflow] branch main updated: Improve speed of `dag.partial_subset` by not deep-copying TaskGroup (#23088)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5fca11ef85 Improve speed of `dag.partial_subset` by not deep-copying TaskGroup (#23088)
5fca11ef85 is described below

commit 5fca11ef856e5f1a451cce726a06b537ea979649
Author: Ash Berlin-Taylor <as...@apache.org>
AuthorDate: Tue Apr 19 17:25:36 2022 +0100

    Improve speed of `dag.partial_subset` by not deep-copying TaskGroup (#23088)
    
    Deep-copying the TaskGroup resulted in the _entire_ DAG being copied over
    and over again. For a DAG with 500 tasks, this change reduces the runtime
    of this function from 60s(!) to just over 1s.
---
 airflow/models/dag.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 418bd6e2ce..4013f40bdc 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -26,6 +26,7 @@ import re
 import sys
 import traceback
 import warnings
+import weakref
 from datetime import datetime, timedelta
 from inspect import signature
 from typing import (
@@ -1989,11 +1990,13 @@ class DAG(LoggingMixin):
                 also_include.extend(upstream)
 
         # Compiling the unique list of tasks that made the cut
-        # Make sure to not recursively deepcopy the dag while copying the task
-        dag.task_dict = {
-            t.task_id: copy.deepcopy(t, {id(t.dag): dag})  # type: ignore
-            for t in matched_tasks + also_include
-        }
+        # Make sure to not recursively deepcopy the dag or task_group while copying the task.
+        # task_group is reset later
+        def _deepcopy_task(t) -> "Operator":
+            memo.setdefault(id(t.task_group), None)
+            return copy.deepcopy(t, memo)
+
+        dag.task_dict = {t.task_id: _deepcopy_task(t) for t in matched_tasks + also_include}
 
         def filter_task_group(group, parent_group):
             """Exclude tasks not included in the subdag from the given TaskGroup."""
@@ -2006,7 +2009,8 @@ class DAG(LoggingMixin):
             for child in group.children.values():
                 if isinstance(child, AbstractOperator):
                     if child.task_id in dag.task_dict:
-                        copied.children[child.task_id] = dag.task_dict[child.task_id]
+                        task = copied.children[child.task_id] = dag.task_dict[child.task_id]
+                        task.task_group = weakref.proxy(copied)
                     else:
                         copied.used_group_ids.discard(child.task_id)
                 else: