Posted to commits@airflow.apache.org by je...@apache.org on 2021/11/05 23:24:57 UTC

[airflow] 05/06: Fix moving of dangling TaskInstance rows for SQL Server (#19425)

This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit d478f939b48bb5df5207fe182ad5f77a87051234
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Fri Nov 5 15:30:29 2021 +0000

    Fix moving of dangling TaskInstance rows for SQL Server (#19425)
    
    SQL Server uses a different syntax for creating a table from a select
    than the other DBs we support.
    
    And to make the "where_clause" reusable across all DBs (SQL Server
    doesn't support `WHERE (col1, col2) IN ...`), the delete has been
    rewritten too.
    
    (cherry picked from commit 3c45c12ed12a014978c0f1c9ca2fc3f32416b144)
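
For illustration, the syntax difference being worked around looks roughly like
this (a sketch only; "_dangling_rows" is a placeholder, not the actual target
table name the migration generates):

    -- SQL Server: SELECT ... INTO creates the target table from the select
    SELECT source.* INTO _dangling_rows
    FROM task_instance AS source
    LEFT JOIN dag_run AS dr
      ON source.dag_id = dr.dag_id AND source.execution_date = dr.execution_date
    WHERE dr.id IS NULL;

    -- Postgres, MySQL and SQLite: CREATE TABLE ... AS SELECT
    CREATE TABLE _dangling_rows AS
    SELECT source.*
    FROM task_instance AS source
    LEFT JOIN dag_run AS dr
      ON source.dag_id = dr.dag_id AND source.execution_date = dr.execution_date
    WHERE dr.id IS NULL;

The delete side is likewise rewritten as a join against the new table on the
source table's primary-key columns, using each dialect's own join-delete form,
instead of `WHERE (col1, col2) IN (...)`, which SQL Server does not support.
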
---
 airflow/utils/db.py | 78 ++++++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 60 insertions(+), 18 deletions(-)

diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 2da0722..2fcab15 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -709,10 +709,55 @@ def _format_dangling_error(source_table, target_table, invalid_count, reason):
     )
 
 
-def _move_dangling_run_data_to_new_table(session, source_table, target_table):
+def _move_dangling_run_data_to_new_table(session, source_table: "Table", target_table_name: str):
     where_clause = "where dag_id is null or run_id is null or execution_date is null"
-    session.execute(text(f"create table {target_table} as select * from {source_table} {where_clause}"))
-    session.execute(text(f"delete from {source_table} {where_clause}"))
+    _move_dangling_table(session, source_table, target_table_name, where_clause)
+
+
+def _move_dangling_table(session, source_table: "Table", target_table_name: str, where_clause: str):
+    dialect_name = session.get_bind().dialect.name
+
+    delete_where = " AND ".join(
+        f"{source_table.name}.{c.name} = d.{c.name}" for c in source_table.primary_key.columns
+    )
+    if dialect_name == "mssql":
+        session.execute(
+            text(f"select source.* into {target_table_name} from {source_table} as source {where_clause}")
+        )
+        session.execute(
+            text(
+                f"delete from {source_table} from {source_table} join {target_table_name} AS d ON "
+                + delete_where
+            )
+        )
+    else:
+        # Postgres, MySQL and SQLite all have the same CREATE TABLE a AS SELECT ... syntax
+        session.execute(
+            text(
+                f"create table {target_table_name} as select source.* from {source_table} as source "
+                + where_clause
+            )
+        )
+
+        # But different join-delete syntax.
+        if dialect_name == "mysql":
+            session.execute(
+                text(
+                    f"delete {source_table} from {source_table} join {target_table_name} as d on "
+                    + delete_where
+                )
+            )
+        elif dialect_name == "sqlite":
+            session.execute(
+                text(
+                    f"delete from {source_table} where ROWID in (select {source_table}.ROWID from "
+                    f"{source_table} as source join {target_table_name} as d on {delete_where})"
+                )
+            )
+        else:
+            session.execute(
+                text(f"delete from {source_table} using {target_table_name} as d where {delete_where}")
+            )
 
 
 def check_run_id_null(session) -> Iterable[str]:
@@ -720,7 +765,7 @@ def check_run_id_null(session) -> Iterable[str]:
 
     metadata = sqlalchemy.schema.MetaData(session.bind)
     try:
-        metadata.reflect(only=[DagRun.__tablename__])
+        metadata.reflect(only=[DagRun.__tablename__], extend_existing=True, resolve_fks=False)
     except exc.InvalidRequestError:
         # Table doesn't exist -- empty db
         return
@@ -745,21 +790,16 @@ def check_run_id_null(session) -> Iterable[str]:
                 reason="with a NULL dag_id, run_id, or execution_date",
             )
             return
-        _move_dangling_run_data_to_new_table(session, dagrun_table.name, dagrun_dangling_table_name)
+        _move_dangling_run_data_to_new_table(session, dagrun_table, dagrun_dangling_table_name)
 
 
-def _move_dangling_task_data_to_new_table(session, source_table, target_table):
-    where_clause = f"""
-        where (task_id, dag_id, execution_date) IN (
-            select source.task_id, source.dag_id, source.execution_date
-            from {source_table} as source
-            left join dag_run as dr
-            on (source.dag_id = dr.dag_id and source.execution_date = dr.execution_date)
-            where dr.id is null
-        )
+def _move_dangling_task_data_to_new_table(session, source_table: "Table", target_table_name: str):
+    where_clause = """
+        left join dag_run as dr
+        on (source.dag_id = dr.dag_id and source.execution_date = dr.execution_date)
+        where dr.id is null
     """
-    session.execute(text(f"create table {target_table} as select * from {source_table} {where_clause}"))
-    session.execute(text(f"delete from {source_table} {where_clause}"))
+    _move_dangling_table(session, source_table, target_table_name, where_clause)
 
 
 def check_task_tables_without_matching_dagruns(session) -> Iterable[str]:
@@ -770,7 +810,7 @@ def check_task_tables_without_matching_dagruns(session) -> Iterable[str]:
     models_to_dagrun = [TaskInstance, TaskReschedule]
     for model in models_to_dagrun + [DagRun]:
         try:
-            metadata.reflect(only=[model.__tablename__])
+            metadata.reflect(only=[model.__tablename__], extend_existing=True, resolve_fks=False)
         except exc.InvalidRequestError:
             # Table doesn't exist, but try the other ones incase the user is upgrading from an _old_ DB
             # version
@@ -821,7 +861,7 @@ def check_task_tables_without_matching_dagruns(session) -> Iterable[str]:
             )
             errored = True
             continue
-        _move_dangling_task_data_to_new_table(session, source_table.name, dangling_table_name)
+        _move_dangling_task_data_to_new_table(session, source_table, dangling_table_name)
 
     if errored:
         session.rollback()
@@ -842,6 +882,8 @@ def _check_migration_errors(session=None) -> Iterable[str]:
         check_task_tables_without_matching_dagruns,
     ):
         yield from check_fn(session)
+        # Ensure there is no "active" transaction. Seems odd, but without this MSSQL can hang
+        session.commit()
 
 
 @provide_session