You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2022/09/06 15:19:31 UTC

[airflow] branch main updated: Rewrite recursion into iteration (#26175)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8538a72d0f Rewrite recursion into iteration (#26175)
8538a72d0f is described below

commit 8538a72d0f90dd1917ef68554c65d8b96403beca
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Tue Sep 6 16:19:22 2022 +0100

    Rewrite recursion into iteration (#26175)
    
    This helps to avoid a RecursionError when viewing the graph
    view of a DAG with many tasks.
---
 airflow/www/views.py                | 13 +++++++++----
 tests/www/views/test_views_tasks.py | 21 +++++++++++++++++++++
 2 files changed, 30 insertions(+), 4 deletions(-)

diff --git a/airflow/www/views.py b/airflow/www/views.py
index d55d85e997..a86e8bd340 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -597,11 +597,16 @@ def dag_edges(dag):
     edges = set()
 
     def get_downstream(task):
-        for child in task.downstream_list:
-            edge = (task.task_id, child.task_id)
-            if edge not in edges:
+        tasks_to_trace = task.downstream_list
+        while tasks_to_trace:
+            tasks_to_trace_next: Set[str] = set()
+            for child in tasks_to_trace:
+                edge = (task.task_id, child.task_id)
+                if edge in edges:
+                    continue
+                tasks_to_trace_next.update(child.downstream_list)
                 edges.add(edge)
-                get_downstream(child)
+            tasks_to_trace = tasks_to_trace_next
 
     for root in dag.roots:
         get_downstream(root)
diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py
index 5f6254f1fb..d15845e161 100644
--- a/tests/www/views/test_views_tasks.py
+++ b/tests/www/views/test_views_tasks.py
@@ -970,3 +970,24 @@ def test_task_fail_duration(app, admin_client, dag_maker, session):
         assert resp.status_code == 200
         assert sorted(item["key"] for item in cumulative_chart) == ["fail", "success"]
         assert sorted(item["key"] for item in line_chart) == ["fail", "success"]
+
+
+def test_graph_view_doesnt_fail_on_recursion_error(app, dag_maker, admin_client):
+    """Test that the graph view doesn't fail on a recursion error."""
+    from airflow.utils.helpers import chain
+
+    with dag_maker('test_fails_with_recursion') as dag:
+
+        tasks = [
+            BashOperator(
+                task_id=f"task_{i}",
+                bash_command="echo test",
+            )
+            for i in range(1, 1000 + 1)
+        ]
+        chain(*tasks)
+    with unittest.mock.patch.object(app, 'dag_bag') as mocked_dag_bag:
+        mocked_dag_bag.get_dag.return_value = dag
+        url = f'/dags/{dag.dag_id}/graph'
+        resp = admin_client.get(url, follow_redirects=True)
+        assert resp.status_code == 200