You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2022/04/14 14:53:18 UTC
[airflow] branch main updated: Fix TaskFail queries in views after run_id migration (#23008)
This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 70049f19e4 Fix TaskFail queries in views after run_id migration (#23008)
70049f19e4 is described below
commit 70049f19e4ac82ea922d7e59871a3b4ebae068f1
Author: Ash Berlin-Taylor <as...@apache.org>
AuthorDate: Thu Apr 14 15:53:09 2022 +0100
Fix TaskFail queries in views after run_id migration (#23008)
Two problems here:
1. TaskFail no longer has an execution_date property -- switch to run_id
2. We weren't joining to DagRun correctly, meaning we'd end up with a
cross-product effect (something weird, anyway)
Co-authored-by: Karthikeyan Singaravelan <ti...@gmail.com>
---
airflow/models/taskfail.py | 12 ++++++++++
airflow/www/views.py | 16 +++++++------
tests/www/views/test_views_tasks.py | 47 ++++++++++++++++++++++++++++++++++++-
3 files changed, 67 insertions(+), 8 deletions(-)
diff --git a/airflow/models/taskfail.py b/airflow/models/taskfail.py
index 4266179a34..a4bd102b56 100644
--- a/airflow/models/taskfail.py
+++ b/airflow/models/taskfail.py
@@ -18,6 +18,7 @@
"""Taskfail tracks the failed run durations of each task instance"""
from sqlalchemy import Column, ForeignKeyConstraint, Integer
+from sqlalchemy.orm import relationship
from airflow.models.base import Base, StringID
from airflow.utils.sqlalchemy import UtcDateTime
@@ -51,6 +52,17 @@ class TaskFail(Base):
),
)
+ # We don't need a DB level FK here, as we already have that to TI (which has one to DR) but by defining
+ # the relationship we can more easily find the execution date for these rows
+ dag_run = relationship(
+ "DagRun",
+ primaryjoin="""and_(
+ TaskFail.dag_id == foreign(DagRun.dag_id),
+ TaskFail.run_id == foreign(DagRun.run_id),
+ )""",
+ viewonly=True,
+ )
+
def __init__(self, task, run_id, start_date, end_date, map_index):
self.dag_id = task.dag_id
self.task_id = task.task_id
diff --git a/airflow/www/views.py b/airflow/www/views.py
index c722853dfe..f0ca663611 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -2885,21 +2885,22 @@ class Airflow(AirflowBaseView):
min_date = timezone.utc_epoch()
ti_fails = (
session.query(TaskFail)
+ .join(TaskFail.dag_run)
.filter(
TaskFail.dag_id == dag.dag_id,
DagRun.execution_date >= min_date,
DagRun.execution_date <= base_date,
- TaskFail.task_id.in_([t.task_id for t in dag.tasks]),
)
- .all()
)
+ if dag.partial:
+ ti_fails = ti_fails.filter(TaskFail.task_id.in_([t.task_id for t in dag.tasks]))
fails_totals = defaultdict(int)
for failed_task_instance in ti_fails:
dict_key = (
failed_task_instance.dag_id,
failed_task_instance.task_id,
- failed_task_instance.execution_date,
+ failed_task_instance.run_id,
)
if failed_task_instance.duration:
fails_totals[dict_key] += failed_task_instance.duration
@@ -2909,7 +2910,7 @@ class Airflow(AirflowBaseView):
date_time = wwwutils.epoch(task_instance.execution_date)
x_points[task_instance.task_id].append(date_time)
y_points[task_instance.task_id].append(float(task_instance.duration))
- fails_dict_key = (task_instance.dag_id, task_instance.task_id, task_instance.execution_date)
+ fails_dict_key = (task_instance.dag_id, task_instance.task_id, task_instance.run_id)
fails_total = fails_totals[fails_dict_key]
cumulative_y[task_instance.task_id].append(float(task_instance.duration + fails_total))
@@ -3215,17 +3216,18 @@ class Airflow(AirflowBaseView):
tis = (
session.query(TaskInstance)
- .join(TaskInstance.dag_run)
.filter(
- DagRun.execution_date == dttm,
TaskInstance.dag_id == dag_id,
+ TaskInstance.run_id == dag_run_id,
TaskInstance.start_date.isnot(None),
TaskInstance.state.isnot(None),
)
.order_by(TaskInstance.start_date)
)
- ti_fails = session.query(TaskFail).filter(DagRun.execution_date == dttm, TaskFail.dag_id == dag_id)
+ ti_fails = session.query(TaskFail).filter_by(run_id=dag_run_id, dag_id=dag_id)
+ if dag.partial:
+ ti_fails = ti_fails.filter(TaskFail.task_id.in_([t.task_id for t in dag.tasks]))
tasks = []
for ti in tis:
diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py
index 397faca3cd..fce94fd5e4 100644
--- a/tests/www/views/test_views_tasks.py
+++ b/tests/www/views/test_views_tasks.py
@@ -25,9 +25,11 @@ import pytest
from freezegun import freeze_time
from airflow import settings
+from airflow.exceptions import AirflowException
from airflow.executors.celery_executor import CeleryExecutor
-from airflow.models import DAG, DagBag, DagModel, TaskInstance, TaskReschedule
+from airflow.models import DAG, DagBag, DagModel, TaskFail, TaskInstance, TaskReschedule
from airflow.models.dagcode import DagCode
+from airflow.operators.bash import BashOperator
from airflow.security import permissions
from airflow.ti_deps.dependencies_states import QUEUEABLE_STATES, RUNNABLE_STATES
from airflow.utils import dates, timezone
@@ -934,3 +936,46 @@ def test_action_muldelete_task_instance(session, admin_client, task_search_tuple
== 0
)
assert session.query(TaskReschedule).count() == 0
+
+
+def test_task_fail_duration(app, admin_client, dag_maker, session):
+ """Task duration page with a TaskFail entry should render without error."""
+ with dag_maker() as dag:
+ op1 = BashOperator(task_id='fail', bash_command='exit 1')
+ op2 = BashOperator(task_id='success', bash_command='exit 0')
+
+ with pytest.raises(AirflowException):
+ op1.run()
+ op2.run()
+
+ op1_fails = (
+ session.query(TaskFail)
+ .filter(
+ TaskFail.task_id == 'fail',
+ TaskFail.dag_id == dag.dag_id,
+ )
+ .all()
+ )
+
+ op2_fails = (
+ session.query(TaskFail)
+ .filter(
+ TaskFail.task_id == 'success',
+ TaskFail.dag_id == dag.dag_id,
+ )
+ .all()
+ )
+
+ assert len(op1_fails) == 1
+ assert len(op2_fails) == 0
+
+ with unittest.mock.patch.object(app, 'dag_bag') as mocked_dag_bag:
+ mocked_dag_bag.get_dag.return_value = dag
+ resp = admin_client.get(f"dags/{dag.dag_id}/duration", follow_redirects=True)
+ html = resp.get_data().decode()
+ cumulative_chart = json.loads(re.search("data_cumlinechart=(.*);", html).group(1))
+ line_chart = json.loads(re.search("data_linechart=(.*);", html).group(1))
+
+ assert resp.status_code == 200
+ assert sorted(item["key"] for item in cumulative_chart) == ["fail", "success"]
+ assert sorted(item["key"] for item in line_chart) == ["fail", "success"]