You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2022/09/06 15:19:31 UTC
[airflow] branch main updated: Rewrite recursion into iteration (#26175)
This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8538a72d0f Rewrite recursion into iteration (#26175)
8538a72d0f is described below
commit 8538a72d0f90dd1917ef68554c65d8b96403beca
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Tue Sep 6 16:19:22 2022 +0100
Rewrite recursion into iteration (#26175)
This helps avoid a RecursionError when viewing the graph
view of a DAG with many tasks.
---
airflow/www/views.py | 13 +++++++++----
tests/www/views/test_views_tasks.py | 21 +++++++++++++++++++++
2 files changed, 30 insertions(+), 4 deletions(-)
diff --git a/airflow/www/views.py b/airflow/www/views.py
index d55d85e997..a86e8bd340 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -597,11 +597,16 @@ def dag_edges(dag):
edges = set()
def get_downstream(task):
- for child in task.downstream_list:
- edge = (task.task_id, child.task_id)
- if edge not in edges:
+ tasks_to_trace = [(task, child) for child in task.downstream_list]  # pending (parent, child) pairs
+ while tasks_to_trace:
+ tasks_to_trace_next: Set[tuple] = set()
+ for parent, child in tasks_to_trace:
+ edge = (parent.task_id, child.task_id)  # the child's real parent, not the root `task`
+ if edge in edges:
+ continue  # edge already recorded and its subtree already queued
+ tasks_to_trace_next.update((child, grandchild) for grandchild in child.downstream_list)
edges.add(edge)
- get_downstream(child)
+ tasks_to_trace = tasks_to_trace_next
for root in dag.roots:
get_downstream(root)
diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py
index 5f6254f1fb..d15845e161 100644
--- a/tests/www/views/test_views_tasks.py
+++ b/tests/www/views/test_views_tasks.py
@@ -970,3 +970,24 @@ def test_task_fail_duration(app, admin_client, dag_maker, session):
assert resp.status_code == 200
assert sorted(item["key"] for item in cumulative_chart) == ["fail", "success"]
assert sorted(item["key"] for item in line_chart) == ["fail", "success"]
+
+
+def test_graph_view_doesnt_fail_on_recursion_error(app, dag_maker, admin_client):
+ """Test that the graph view of a 1000-task linear chain renders without a RecursionError."""
+ from airflow.models.baseoperator import chain
+
+ with dag_maker('test_fails_with_recursion') as dag:
+
+ tasks = [
+ BashOperator(
+ task_id=f"task_{i}",
+ bash_command="echo test",
+ )
+ for i in range(1, 1000 + 1)  # as deep as CPython's default recursion limit (1000)
+ ]
+ chain(*tasks)
+ with unittest.mock.patch.object(app, 'dag_bag') as mocked_dag_bag:
+ mocked_dag_bag.get_dag.return_value = dag
+ url = f'/dags/{dag.dag_id}/graph'
+ resp = admin_client.get(url, follow_redirects=True)
+ assert resp.status_code == 200