You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ds...@apache.org on 2022/04/14 01:59:47 UTC

[airflow] branch main updated: Add dangling rows check for TaskInstance references (#22924)

This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 26e09dffe5 Add dangling rows check for TaskInstance references (#22924)
26e09dffe5 is described below

commit 26e09dffe5e7571e183bed30e9f2bd178406837d
Author: Daniel Standish <15...@users.noreply.github.com>
AuthorDate: Wed Apr 13 18:59:31 2022 -0700

    Add dangling rows check for TaskInstance references (#22924)
    
    We are adding some foreign keys in 2.3.0, so we want to make it more likely that the migration succeeds by detecting FK violations and moving the offending records out of the table before creating the FK. We already had a check for "missing" dag runs; this adds a check for missing TaskInstances. In most cases we replace the "missing dag run" check with a "missing TI" check, since from 2.2.0 onward a TI implies the existence of a DR anyway.
---
 airflow/utils/db.py | 141 ++++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 104 insertions(+), 37 deletions(-)

diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 8b3b5ebf21..aefac14c90 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -22,12 +22,14 @@ import os
 import sys
 import time
 import warnings
+from dataclasses import dataclass
 from tempfile import gettempdir
 from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple
 
-from sqlalchemy import Table, column, exc, func, inspect, literal, or_, table, text
+from sqlalchemy import Table, and_, column, exc, func, inspect, literal, or_, table, text
 from sqlalchemy.orm.session import Session
 
+import airflow
 from airflow import settings
 from airflow.compat.sqlalchemy import has_table
 from airflow.configuration import conf
@@ -654,8 +656,7 @@ def initdb(session: Session = NEW_SESSION):
 def _get_alembic_config():
     from alembic.config import Config
 
-    current_dir = os.path.dirname(os.path.abspath(__file__))
-    package_dir = os.path.normpath(os.path.join(current_dir, '..'))
+    package_dir = os.path.dirname(airflow.__file__)
     directory = os.path.join(package_dir, 'migrations')
     config = Config(os.path.join(package_dir, 'alembic.ini'))
     config.set_main_option('script_location', directory.replace('%', '%%'))
@@ -965,7 +966,9 @@ def check_run_id_null(session: Session) -> Iterable[str]:
     )
     invalid_dagrun_count = session.query(dagrun_table.c.id).filter(invalid_dagrun_filter).count()
     if invalid_dagrun_count > 0:
-        dagrun_dangling_table_name = _format_airflow_moved_table_name(dagrun_table.name, "2.2")
+        dagrun_dangling_table_name = _format_airflow_moved_table_name(
+            source_table=dagrun_table.name, version="2.2"
+        )
         if dagrun_dangling_table_name in inspect(session.get_bind()).get_table_names():
             yield _format_dangling_error(
                 source_table=dagrun_table.name,
@@ -1045,7 +1048,61 @@ def _move_dangling_data_to_new_table(
     session.execute(delete)
 
 
-def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str]:
+def _dag_run_exists(session, source_table, dag_run):
+    """
+    Query selecting the literal 1 from ``dag_run`` for runs matching the source row's
+    ``dag_id`` and ``execution_date``; callers wrap it in ``.exists()``.
+    """
+    matches_source_row = and_(
+        source_table.c.dag_id == dag_run.c.dag_id,
+        source_table.c.execution_date == dag_run.c.execution_date,
+    )
+    query = session.query(text('1')).select_from(dag_run)
+    return query.filter(matches_source_row)
+
+
+def _task_instance_exists(session, source_table, dag_run, task_instance):
+    """
+    Given a source table, we generate a subquery that will return 1 for every row that
+    has a valid task instance (and associated dagrun).
+
+    This is used to identify rows that need to be removed from tables prior to adding a TI fk.
+    Since this check is applied prior to running the migrations, we have to use different
+    query logic depending on which revision the database is at. In both cases the
+    TI -> DR join matches on ``dag_run`` columns (including ``dag_id``) so that a run of
+    another DAG, or any unrelated dag_run row, cannot make a source row look valid.
+    """
+    if 'run_id' not in task_instance.c:
+        # db is < 2.2.0: TI and DR are both keyed by (dag_id, execution_date)
+        source_to_ti_join_cond = and_(
+            source_table.c.dag_id == task_instance.c.dag_id,
+            source_table.c.task_id == task_instance.c.task_id,
+            source_table.c.execution_date == task_instance.c.execution_date,
+        )
+        ti_to_dr_join_cond = and_(
+            dag_run.c.dag_id == task_instance.c.dag_id,
+            dag_run.c.execution_date == task_instance.c.execution_date,
+        )
+    else:
+        # db is 2.2.0 <= version < 2.3.0: TI references DR by (dag_id, run_id)
+        source_to_ti_join_cond = and_(
+            source_table.c.dag_id == task_instance.c.dag_id,
+            source_table.c.task_id == task_instance.c.task_id,
+            source_table.c.execution_date == dag_run.c.execution_date,
+        )
+        ti_to_dr_join_cond = and_(
+            dag_run.c.dag_id == task_instance.c.dag_id,
+            dag_run.c.run_id == task_instance.c.run_id,
+        )
+    exists_subquery = (
+        session.query(text('1'))
+        .select_from(task_instance.join(dag_run, onclause=ti_to_dr_join_cond))
+        .filter(source_to_ti_join_cond)
+    )
+    return exists_subquery
+
+
+def check_bad_references(session: Session) -> Iterable[str]:
     """
     Starting in Airflow 2.2, we began a process of replacing `execution_date` with `run_id`
     in many tables.
@@ -1053,20 +1110,40 @@ def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str
     When we find such "dangling" rows we back them up in a special table and delete them
     from the main table.
     """
-    from sqlalchemy import and_
-
     from airflow.models.renderedtifields import RenderedTaskInstanceFields
 
-    models_to_dagrun: List[Tuple[Base, str]] = [
-        (mod, ver)
-        for ver, models in {
-            '2.2': [TaskInstance, TaskReschedule],
-            '2.3': [RenderedTaskInstanceFields, TaskFail, XCom],
-        }.items()
-        for mod in models
-    ]
+    @dataclass
+    class BadReferenceConfig:
+        """
+        :param exists_func: function that returns subquery which determines whether bad rows exist
+        :param join_tables: table objects referenced in subquery
+        :param ref_table: information-only identifier for categorizing the missing ref
+        """
+
+        exists_func: Callable
+        join_tables: List[str]
+        ref_table: str
+
+    missing_dag_run_config = BadReferenceConfig(
+        exists_func=_dag_run_exists,
+        join_tables=['dag_run'],
+        ref_table='dag_run',
+    )
+
+    missing_ti_config = BadReferenceConfig(
+        exists_func=_task_instance_exists,
+        join_tables=['dag_run', 'task_instance'],
+        ref_table='task_instance',
+    )
 
-    metadata = reflect_tables([*[x[0] for x in models_to_dagrun], DagRun], session)
+    models_list: List[Tuple[Base, str, BadReferenceConfig]] = [
+        (TaskInstance, '2.2', missing_dag_run_config),
+        (TaskReschedule, '2.2', missing_ti_config),
+        (RenderedTaskInstanceFields, '2.3', missing_ti_config),
+        (TaskFail, '2.3', missing_ti_config),
+        (XCom, '2.3', missing_ti_config),
+    ]
+    metadata = reflect_tables([*[x[0] for x in models_list], DagRun, TaskInstance], session)
 
     if (
         metadata.tables.get(DagRun.__tablename__) is None
@@ -1075,16 +1152,13 @@ def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str
         # Key table doesn't exist -- likely empty DB.
         return
 
-    # We can't use the model here since it may differ from the db state due to
-    # this function is run prior to migration. Use the reflected table instead.
-    dagrun_table = metadata.tables[DagRun.__tablename__]
-
     existing_table_names = set(inspect(session.get_bind()).get_table_names())
     errored = False
 
-    for model, change_version in models_to_dagrun:
+    for model, change_version, bad_ref_cfg in models_list:
         # We can't use the model here since it may differ from the db state due to
         # this function is run prior to migration. Use the reflected table instead.
+        exists_func_kwargs = {x: metadata.tables[x] for x in bad_ref_cfg.join_tables}
         source_table = metadata.tables.get(model.__tablename__)  # type: ignore
         if source_table is None:
             continue
@@ -1093,29 +1167,22 @@ def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str
         if "run_id" in source_table.columns:
             continue
 
-        # find rows in source table which don't have a matching dag run
-        source_to_dag_run_join_cond = and_(
-            source_table.c.dag_id == dagrun_table.c.dag_id,
-            source_table.c.execution_date == dagrun_table.c.execution_date,
-        )
-        exists_subquery = (
-            session.query(text('1')).select_from(dagrun_table).filter(source_to_dag_run_join_cond)
-        )
-        invalid_rows_query = session.query(*[x.label(x.name) for x in source_table.c]).filter(
-            ~exists_subquery.exists()
-        )
-
+        bad_rows_subquery = bad_ref_cfg.exists_func(session, source_table, **exists_func_kwargs)
+        select_list = [x.label(x.name) for x in source_table.c]
+        invalid_rows_query = session.query(*select_list).filter(~bad_rows_subquery.exists())
         invalid_row_count = invalid_rows_query.count()
         if invalid_row_count <= 0:
             continue
 
-        dangling_table_name = _format_airflow_moved_table_name(source_table.name, change_version)
+        dangling_table_name = _format_airflow_moved_table_name(
+            source_table=source_table.name, version=change_version
+        )
         if dangling_table_name in existing_table_names:
             yield _format_dangling_error(
                 source_table=source_table.name,
                 target_table=dangling_table_name,
                 invalid_count=invalid_row_count,
-                reason=f"without a corresponding {dagrun_table.name} row",
+                reason=f"without a corresponding {bad_ref_cfg.ref_table} row",
             )
             errored = True
             continue
@@ -1123,7 +1190,7 @@ def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str
             session,
             source_table,
             invalid_rows_query,
-            exists_subquery,
+            bad_rows_subquery,
             dangling_table_name,
         )
 
@@ -1144,7 +1211,7 @@ def _check_migration_errors(session: Session = NEW_SESSION) -> Iterable[str]:
         check_conn_id_duplicates,
         check_conn_type_null,
         check_run_id_null,
-        check_task_tables_without_matching_dagruns,
+        check_bad_references,
     )
     for check_fn in check_functions:
         yield from check_fn(session=session)