You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2022/01/24 12:43:46 UTC

[airflow] branch main updated: Refactor dangling row check to use SQLA queries (#19808)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new cecd4c8  Refactor dangling row check to use SQLA queries (#19808)
cecd4c8 is described below

commit cecd4c8059e04c5be0cfec67ebb576d08d83f7b9
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Mon Jan 24 12:43:09 2022 +0000

    Refactor dangling row check to use SQLA queries (#19808)
    
    This is a preparatory refactor to have the move dangling rows
    pre-upgrade check make better use of the SQLA Queries -- this is needed
    because in a future PR we will add a check for dangling XCom rows, and
    that will need to conditionally join against DagRun to get
    execution_date (depending on if it is run pre- or post-2.2).
    
    This has been tested with Postgres 9.6, SQLite, MSSQL 2017 and MySQL 5.7
    
    codespell didn't like `froms` as it thinks it is a typo of forms, and
    in most other cases it would be, except here. Codespell doesn't currently
    have a method of ignoring a _single_ line without ignoring the word
    everywhere (which we don't want to do) so I have to ignore the exact
    _line_. Sad panda
---
 .codespellignorelines   |   2 +
 .pre-commit-config.yaml |   1 +
 .rat-excludes           |   1 +
 airflow/utils/db.py     | 148 ++++++++++++++++++++++++------------------------
 4 files changed, 77 insertions(+), 75 deletions(-)

diff --git a/.codespellignorelines b/.codespellignorelines
new file mode 100644
index 0000000..7b8a9bf
--- /dev/null
+++ b/.codespellignorelines
@@ -0,0 +1,2 @@
+            f"DELETE {source_table} FROM { ', '.join(_from_name(tbl) for tbl in stmt.froms) }"
+        for frm in source_query.selectable.froms:
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7aebe58..11bb0b8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -280,6 +280,7 @@ repos:
         args:
           - --ignore-words=docs/spelling_wordlist.txt
           - --skip=docs/*/commits.rst,airflow/providers/*/*.rst,*.lock,INTHEWILD.md,*.min.js,docs/apache-airflow/pipeline_example.csv
+          - --exclude-file=.codespellignorelines
   - repo: local
     hooks:
       - id: autoflake
diff --git a/.rat-excludes b/.rat-excludes
index 3372441..94a9731 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -10,6 +10,7 @@
 .coverage
 .coveragerc
 .codecov.yml
+.codespellignorelines
 .eslintrc
 .eslintignore
 .flake8
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 600925a..9e8e326 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -22,7 +22,7 @@ import os
 import sys
 import time
 from tempfile import gettempdir
-from typing import Any, Callable, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Tuple
 
 from sqlalchemy import Table, exc, func, inspect, or_, text
 from sqlalchemy.orm.session import Session
@@ -62,6 +62,10 @@ from airflow.utils import helpers
 from airflow.utils.session import NEW_SESSION, create_session, provide_session  # noqa: F401
 from airflow.version import version
 
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Query
+
+
 log = logging.getLogger(__name__)
 
 
@@ -799,69 +803,7 @@ def _format_dangling_error(source_table, target_table, invalid_count, reason):
     )
 
 
-def _move_dangling_run_data_to_new_table(session: Session, source_table: "Table", target_table_name: str):
-    where_clause = "where dag_id is null or run_id is null or execution_date is null"
-    _move_dangling_table(session, source_table, target_table_name, where_clause)
-
-
-def _move_dangling_table(session, source_table: "Table", target_table_name: str, where_clause: str):
-    dialect_name = session.get_bind().dialect.name
-
-    delete_where = " AND ".join(
-        f"{source_table.name}.{c.name} = d.{c.name}" for c in source_table.primary_key.columns
-    )
-    if dialect_name == "mssql":
-        session.execute(
-            text(f"select source.* into {target_table_name} from {source_table} as source {where_clause}")
-        )
-        session.execute(
-            text(
-                f"delete from {source_table} from {source_table} join {target_table_name} AS d ON "
-                + delete_where
-            )
-        )
-    else:
-        if dialect_name == "mysql":
-            # CREATE TABLE AS SELECT must be broken into two queries for  MySQL as the single query
-            # approach fails when replication is enabled ("Statement violates GTID consistency")```
-            session.execute(text(f"create table {target_table_name} like {source_table}"))
-            session.execute(
-                text(
-                    f"INSERT INTO {target_table_name} select source.* from {source_table} as source "
-                    + where_clause
-                )
-            )
-        # Postgres and SQLite have the same CREATE TABLE a AS SELECT ... syntax
-        else:
-            session.execute(
-                text(
-                    f"create table {target_table_name} as select source.* from {source_table} as source "
-                    + where_clause
-                )
-            )
-
-        # But different join-delete syntax.
-        if dialect_name == "mysql":
-            session.execute(
-                text(
-                    f"delete {source_table} from {source_table} join {target_table_name} as d on "
-                    + delete_where
-                )
-            )
-        elif dialect_name == "sqlite":
-            session.execute(
-                text(
-                    f"delete from {source_table} where ROWID in (select {source_table}.ROWID from "
-                    f"{source_table} as source join {target_table_name} as d on {delete_where})"
-                )
-            )
-        else:
-            session.execute(
-                text(f"delete from {source_table} using {target_table_name} as d where {delete_where}")
-            )
-
-
-def check_run_id_null(session: Session) -> Iterable[str]:
+def check_run_id_null(session) -> Iterable[str]:
     import sqlalchemy.schema
 
     metadata = sqlalchemy.schema.MetaData(session.bind)
@@ -891,16 +833,67 @@ def check_run_id_null(session: Session) -> Iterable[str]:
                 reason="with a NULL dag_id, run_id, or execution_date",
             )
             return
-        _move_dangling_run_data_to_new_table(session, dagrun_table, dagrun_dangling_table_name)
+        _move_dangling_data_to_new_table(
+            session,
+            dagrun_table,
+            dagrun_table.select(invalid_dagrun_filter),
+            dagrun_dangling_table_name,
+        )
 
 
-def _move_dangling_task_data_to_new_table(session, source_table: "Table", target_table_name: str):
-    where_clause = """
-        left join dag_run as dr
-        on (source.dag_id = dr.dag_id and source.execution_date = dr.execution_date)
-        where dr.id is null
-    """
-    _move_dangling_table(session, source_table, target_table_name, where_clause)
+def _move_dangling_data_to_new_table(
+    session, source_table: "Table", source_query: "Query", target_table_name: str
+):
+    from sqlalchemy import column, select, table
+    from sqlalchemy.sql.selectable import Join
+
+    bind = session.get_bind()
+    dialect_name = bind.dialect.name
+
+    # First: Create moved rows from new table
+    if dialect_name == "mssql":
+        cte = source_query.cte("source")
+        moved_data_tbl = table(target_table_name, *(column(c.name) for c in cte.columns))
+        ins = moved_data_tbl.insert().from_select(list(cte.columns), select([cte]))
+
+        stmt = ins.compile(bind=session.get_bind())
+        cte_sql = stmt.ctes[cte]
+
+        session.execute(f"WITH {cte_sql} SELECT source.* INTO {target_table_name} FROM source")
+    else:
+        # Postgres, MySQL and SQLite all support the same "create as select"
+        session.execute(
+            f"CREATE TABLE {target_table_name} AS {source_query.selectable.compile(bind=session.get_bind())}"
+        )
+
+    # Second: Now delete rows we've moved
+    try:
+        clause = source_query.whereclause
+    except AttributeError:
+        clause = source_query._whereclause
+
+    if dialect_name == "sqlite":
+        subq = source_query.selectable.with_only_columns([text(f'{source_table}.ROWID')])
+        delete = source_table.delete().where(column('ROWID').in_(subq))
+    elif dialect_name in ("mysql", "mssql"):
+        # This is not foolproof! But it works for the limited queries (with no params) that we use here
+        stmt = source_query.selectable
+
+        def _from_name(from_) -> str:
+            if isinstance(from_, Join):
+                return str(from_.compile(bind=bind))
+            return str(from_)
+
+        delete = (
+            f"DELETE {source_table} FROM { ', '.join(_from_name(tbl) for tbl in stmt.froms) }"
+            f" WHERE {clause.compile(bind=bind)}"
+        )
+    else:
+        for frm in source_query.selectable.froms:
+            if hasattr(frm, 'onclause'):  # Table, or JOIN?
+                clause &= frm.onclause
+        delete = source_table.delete(clause)
+    session.execute(delete)
 
 
 def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str]:
@@ -945,12 +938,12 @@ def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str
             source_table.c.dag_id == dagrun_table.c.dag_id,
             source_table.c.execution_date == dagrun_table.c.execution_date,
         )
-        invalid_row_count = (
+        invalid_rows_query = (
             session.query(source_table.c.dag_id, source_table.c.task_id, source_table.c.execution_date)
             .select_from(outerjoin(source_table, dagrun_table, source_to_dag_run_join_cond))
             .filter(dagrun_table.c.dag_id.is_(None))
-            .count()
         )
+        invalid_row_count = invalid_rows_query.count()
         if invalid_row_count <= 0:
             continue
 
@@ -964,7 +957,12 @@ def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str
             )
             errored = True
             continue
-        _move_dangling_task_data_to_new_table(session, source_table, dangling_table_name)
+        _move_dangling_data_to_new_table(
+            session,
+            source_table,
+            invalid_rows_query.with_entities(*source_table.columns),
+            dangling_table_name,
+        )
 
     if errored:
         session.rollback()