Posted to commits@airflow.apache.org by as...@apache.org on 2022/04/14 14:53:18 UTC

[airflow] branch main updated: Fix TaskFail queries in views after run_id migration (#23008)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 70049f19e4 Fix TaskFail queries in views after run_id migration (#23008)
70049f19e4 is described below

commit 70049f19e4ac82ea922d7e59871a3b4ebae068f1
Author: Ash Berlin-Taylor <as...@apache.org>
AuthorDate: Thu Apr 14 15:53:09 2022 +0100

    Fix TaskFail queries in views after run_id migration (#23008)
    
    Two problems here:
    
    1. TaskFail no longer has an execution_date property -- switch to run_id
    2. We weren't joining to DagRun correctly, so the query produced a cross
       join (a Cartesian product of TaskFail and DagRun rows)
    
    Co-authored-by: Karthikeyan Singaravelan <ti...@gmail.com>
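
For readers unfamiliar with the failure mode in point 2: in SQLAlchemy, filtering on a table that was never joined silently adds it to the FROM clause, producing a cross join. A minimal, self-contained sketch of the pattern (illustrative toy models, not Airflow's real ones):

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Fail(Base):  # stand-in for TaskFail
        __tablename__ = "fail"
        id = Column(Integer, primary_key=True)
        dag_id = Column(String)
        run_id = Column(String)

    class Run(Base):  # stand-in for DagRun
        __tablename__ = "run"
        id = Column(Integer, primary_key=True)
        dag_id = Column(String)
        run_id = Column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        # Buggy shape: Run is filtered but never joined, so SQLAlchemy adds it
        # to the FROM clause and every Fail row pairs with every Run row.
        buggy = session.query(Fail).filter(Run.dag_id == "example")
        # Fixed shape: an explicit join constrains the pairing.
        fixed = session.query(Fail).join(
            Run, (Fail.dag_id == Run.dag_id) & (Fail.run_id == Run.run_id)
        ).filter(Run.dag_id == "example")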
---
 airflow/models/taskfail.py          | 12 ++++++++++
 airflow/www/views.py                | 16 +++++++------
 tests/www/views/test_views_tasks.py | 47 ++++++++++++++++++++++++++++++++++++-
 3 files changed, 67 insertions(+), 8 deletions(-)

diff --git a/airflow/models/taskfail.py b/airflow/models/taskfail.py
index 4266179a34..a4bd102b56 100644
--- a/airflow/models/taskfail.py
+++ b/airflow/models/taskfail.py
@@ -18,6 +18,7 @@
 """Taskfail tracks the failed run durations of each task instance"""
 
 from sqlalchemy import Column, ForeignKeyConstraint, Integer
+from sqlalchemy.orm import relationship
 
 from airflow.models.base import Base, StringID
 from airflow.utils.sqlalchemy import UtcDateTime
@@ -51,6 +52,17 @@ class TaskFail(Base):
         ),
     )
 
+    # We don't need a DB level FK here, as we already have that to TI (which has one to DR) but by defining
+    # the relationship we can more easily find the execution date for these rows
+    dag_run = relationship(
+        "DagRun",
+        primaryjoin="""and_(
+            TaskFail.dag_id == foreign(DagRun.dag_id),
+            TaskFail.run_id == foreign(DagRun.run_id),
+        )""",
+        viewonly=True,
+    )
+
     def __init__(self, task, run_id, start_date, end_date, map_index):
         self.dag_id = task.dag_id
         self.task_id = task.task_id
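
With the viewonly relationship above in place, the execution date of a failure is one attribute access away. A hedged usage sketch (the session and filter values are illustrative, not from the commit):

    # Hypothetical usage: the relationship resolves the matching DagRun via
    # the primaryjoin, so no manual join is needed at the call site.
    tf = session.query(TaskFail).filter_by(dag_id="example_dag").first()
    if tf is not None:
        print(tf.dag_run.execution_date)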
diff --git a/airflow/www/views.py b/airflow/www/views.py
index c722853dfe..f0ca663611 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -2885,21 +2885,22 @@ class Airflow(AirflowBaseView):
             min_date = timezone.utc_epoch()
         ti_fails = (
             session.query(TaskFail)
+            .join(TaskFail.dag_run)
             .filter(
                 TaskFail.dag_id == dag.dag_id,
                 DagRun.execution_date >= min_date,
                 DagRun.execution_date <= base_date,
-                TaskFail.task_id.in_([t.task_id for t in dag.tasks]),
             )
-            .all()
         )
+        if dag.partial:
+            ti_fails = ti_fails.filter(TaskFail.task_id.in_([t.task_id for t in dag.tasks]))
 
         fails_totals = defaultdict(int)
         for failed_task_instance in ti_fails:
             dict_key = (
                 failed_task_instance.dag_id,
                 failed_task_instance.task_id,
-                failed_task_instance.execution_date,
+                failed_task_instance.run_id,
             )
             if failed_task_instance.duration:
                 fails_totals[dict_key] += failed_task_instance.duration
@@ -2909,7 +2910,7 @@ class Airflow(AirflowBaseView):
                 date_time = wwwutils.epoch(task_instance.execution_date)
                 x_points[task_instance.task_id].append(date_time)
                 y_points[task_instance.task_id].append(float(task_instance.duration))
-                fails_dict_key = (task_instance.dag_id, task_instance.task_id, task_instance.execution_date)
+                fails_dict_key = (task_instance.dag_id, task_instance.task_id, task_instance.run_id)
                 fails_total = fails_totals[fails_dict_key]
                 cumulative_y[task_instance.task_id].append(float(task_instance.duration + fails_total))
 
@@ -3215,17 +3216,18 @@ class Airflow(AirflowBaseView):
 
         tis = (
             session.query(TaskInstance)
-            .join(TaskInstance.dag_run)
             .filter(
-                DagRun.execution_date == dttm,
                 TaskInstance.dag_id == dag_id,
+                TaskInstance.run_id == dag_run_id,
                 TaskInstance.start_date.isnot(None),
                 TaskInstance.state.isnot(None),
             )
             .order_by(TaskInstance.start_date)
         )
 
-        ti_fails = session.query(TaskFail).filter(DagRun.execution_date == dttm, TaskFail.dag_id == dag_id)
+        ti_fails = session.query(TaskFail).filter_by(run_id=dag_run_id, dag_id=dag_id)
+        if dag.partial:
+            ti_fails = ti_fails.filter(TaskFail.task_id.in_([t.task_id for t in dag.tasks]))
 
         tasks = []
         for ti in tis:
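
The filter_by call in the second hunk is SQLAlchemy shorthand for equality filters against the query's primary entity; it is equivalent to the spelled-out form below (a sketch):

    # Equivalent to session.query(TaskFail).filter_by(run_id=dag_run_id, dag_id=dag_id)
    ti_fails = session.query(TaskFail).filter(
        TaskFail.run_id == dag_run_id,
        TaskFail.dag_id == dag_id,
    )

The dag.partial guard matters because a partially loaded DAG exposes only a subset of its task_ids; restricting with the IN clause keeps failures for tasks outside that subset out of the results, while complete DAGs can skip the extra clause since filtering by dag_id already covers every task.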
diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py
index 397faca3cd..fce94fd5e4 100644
--- a/tests/www/views/test_views_tasks.py
+++ b/tests/www/views/test_views_tasks.py
@@ -25,9 +25,11 @@ import pytest
 from freezegun import freeze_time
 
 from airflow import settings
+from airflow.exceptions import AirflowException
 from airflow.executors.celery_executor import CeleryExecutor
-from airflow.models import DAG, DagBag, DagModel, TaskInstance, TaskReschedule
+from airflow.models import DAG, DagBag, DagModel, TaskFail, TaskInstance, TaskReschedule
 from airflow.models.dagcode import DagCode
+from airflow.operators.bash import BashOperator
 from airflow.security import permissions
 from airflow.ti_deps.dependencies_states import QUEUEABLE_STATES, RUNNABLE_STATES
 from airflow.utils import dates, timezone
@@ -934,3 +936,46 @@ def test_action_muldelete_task_instance(session, admin_client, task_search_tuple
             == 0
         )
     assert session.query(TaskReschedule).count() == 0
+
+
+def test_task_fail_duration(app, admin_client, dag_maker, session):
+    """Task duration page with a TaskFail entry should render without error."""
+    with dag_maker() as dag:
+        op1 = BashOperator(task_id='fail', bash_command='exit 1')
+        op2 = BashOperator(task_id='success', bash_command='exit 0')
+
+    with pytest.raises(AirflowException):
+        op1.run()
+    op2.run()
+
+    op1_fails = (
+        session.query(TaskFail)
+        .filter(
+            TaskFail.task_id == 'fail',
+            TaskFail.dag_id == dag.dag_id,
+        )
+        .all()
+    )
+
+    op2_fails = (
+        session.query(TaskFail)
+        .filter(
+            TaskFail.task_id == 'success',
+            TaskFail.dag_id == dag.dag_id,
+        )
+        .all()
+    )
+
+    assert len(op1_fails) == 1
+    assert len(op2_fails) == 0
+
+    with unittest.mock.patch.object(app, 'dag_bag') as mocked_dag_bag:
+        mocked_dag_bag.get_dag.return_value = dag
+        resp = admin_client.get(f"dags/{dag.dag_id}/duration", follow_redirects=True)
+        html = resp.get_data().decode()
+        cumulative_chart = json.loads(re.search("data_cumlinechart=(.*);", html).group(1))
+        line_chart = json.loads(re.search("data_linechart=(.*);", html).group(1))
+
+        assert resp.status_code == 200
+        assert sorted(item["key"] for item in cumulative_chart) == ["fail", "success"]
+        assert sorted(item["key"] for item in line_chart) == ["fail", "success"]