You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/08/15 18:44:38 UTC

[airflow] 02/45: Don't rely on current ORM structure for db clean command (#23574)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-3-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 658f5abe60a98049be8ea904b9855615c8bb5d2f
Author: Daniel Standish <15...@users.noreply.github.com>
AuthorDate: Fri Jun 17 12:46:16 2022 -0700

    Don't rely on current ORM structure for db clean command (#23574)
    
    For the db clean command, by not relying on the ORM models, we will be able to use the command even when the metadata database is not yet upgraded to the version of Airflow you have installed.
    
    Additionally, we archive all rows before deletion.
    
    (cherry picked from commit 95bd6b71cc9f5da377e272707f7b68000d980939)
---
 airflow/cli/cli_parser.py             |   6 +
 airflow/cli/commands/db_command.py    |   1 +
 airflow/utils/db.py                   |  17 ++-
 airflow/utils/db_cleanup.py           | 270 ++++++++++++++++++++--------------
 docs/apache-airflow/usage-cli.rst     |   2 +
 newsfragments/23574.feature.rst       |   1 +
 tests/cli/commands/test_db_command.py |  34 ++++-
 tests/test_utils/db.py                |  10 +-
 tests/utils/test_db_cleanup.py        |  44 +++++-
 9 files changed, 257 insertions(+), 128 deletions(-)

diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py
index 0789b4ee88..d494aa0f06 100644
--- a/airflow/cli/cli_parser.py
+++ b/airflow/cli/cli_parser.py
@@ -434,6 +434,11 @@ ARG_DB_DRY_RUN = Arg(
     help="Perform a dry run",
     action="store_true",
 )
+ARG_DB_SKIP_ARCHIVE = Arg(
+    ("--skip-archive",),
+    help="Don't preserve purged records in an archive table.",
+    action="store_true",
+)
 
 
 # pool
@@ -1454,6 +1459,7 @@ DB_COMMANDS = (
             ARG_DB_CLEANUP_TIMESTAMP,
             ARG_VERBOSE,
             ARG_YES,
+            ARG_DB_SKIP_ARCHIVE,
         ),
     ),
 )
diff --git a/airflow/cli/commands/db_command.py b/airflow/cli/commands/db_command.py
index c9201ad59b..5f6a84c8a4 100644
--- a/airflow/cli/commands/db_command.py
+++ b/airflow/cli/commands/db_command.py
@@ -198,4 +198,5 @@ def cleanup_tables(args):
         clean_before_timestamp=args.clean_before_timestamp,
         verbose=args.verbose,
         confirm=not args.yes,
+        skip_archive=args.skip_archive,
     )
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 7bdd33fb93..a86222e3b8 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -870,7 +870,7 @@ def check_conn_id_duplicates(session: Session) -> Iterable[str]:
         )
 
 
-def reflect_tables(tables: List[Union[Base, str]], session):
+def reflect_tables(tables: Optional[List[Union[Base, str]]], session):
     """
     When running checks prior to upgrades, we use reflection to determine current state of the
     database.
@@ -881,12 +881,15 @@ def reflect_tables(tables: List[Union[Base, str]], session):
 
     metadata = sqlalchemy.schema.MetaData(session.bind)
 
-    for tbl in tables:
-        try:
-            table_name = tbl if isinstance(tbl, str) else tbl.__tablename__
-            metadata.reflect(only=[table_name], extend_existing=True, resolve_fks=False)
-        except exc.InvalidRequestError:
-            continue
+    if tables is None:
+        metadata.reflect(resolve_fks=False)
+    else:
+        for tbl in tables:
+            try:
+                table_name = tbl if isinstance(tbl, str) else tbl.__tablename__
+                metadata.reflect(only=[table_name], extend_existing=True, resolve_fks=False)
+            except exc.InvalidRequestError:
+                continue
     return metadata
 
 
diff --git a/airflow/utils/db_cleanup.py b/airflow/utils/db_cleanup.py
index b02d08503f..f77ae52a60 100644
--- a/airflow/utils/db_cleanup.py
+++ b/airflow/utils/db_cleanup.py
@@ -21,38 +21,24 @@ This module took inspiration from the community maintenance dag
 """
 
 import logging
-from contextlib import AbstractContextManager
+from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional
 
 from pendulum import DateTime
-from sqlalchemy import and_, false, func
+from sqlalchemy import and_, column, false, func, table, text
 from sqlalchemy.exc import OperationalError, ProgrammingError
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.orm import Query, Session, aliased
+from sqlalchemy.sql.expression import ClauseElement, Executable, tuple_
 
 from airflow.cli.simple_table import AirflowConsole
-from airflow.jobs.base_job import BaseJob
-from airflow.models import (
-    Base,
-    DagModel,
-    DagRun,
-    DbCallbackRequest,
-    ImportError as models_ImportError,
-    Log,
-    RenderedTaskInstanceFields,
-    SensorInstance,
-    SlaMiss,
-    TaskFail,
-    TaskInstance,
-    TaskReschedule,
-    XCom,
-)
+from airflow.models import Base
 from airflow.utils import timezone
+from airflow.utils.db import reflect_tables
 from airflow.utils.session import NEW_SESSION, provide_session
 
-if TYPE_CHECKING:
-    from sqlalchemy.orm import Query, Session
-    from sqlalchemy.orm.attributes import InstrumentedAttribute
-    from sqlalchemy.sql.schema import Column
+logger = logging.getLogger(__name__)
 
 
 @dataclass
@@ -60,115 +46,155 @@ class _TableConfig:
     """
     Config class for performing cleanup on a table
 
-    :param orm_model: the table
-    :param recency_column: date column to filter by
+    :param table_name: the table
+    :param extra_columns: any columns besides recency_column_name that we'll need in queries
+    :param recency_column_name: date column to filter by
     :param keep_last: whether the last record should be kept even if it's older than clean_before_timestamp
     :param keep_last_filters: the "keep last" functionality will preserve the most recent record
         in the table.  to ignore certain records even if they are the latest in the table, you can
         supply additional filters here (e.g. externally triggered dag runs)
     :param keep_last_group_by: if keeping the last record, can keep the last record for each group
-    :param warn_if_missing: If True, then we'll suppress "table missing" exception and log a warning.
-        If False then the exception will go uncaught.
     """
 
-    orm_model: Base
-    recency_column: Union["Column", "InstrumentedAttribute"]
+    table_name: str
+    recency_column_name: str
+    extra_columns: Optional[List[str]] = None
     keep_last: bool = False
     keep_last_filters: Optional[Any] = None
     keep_last_group_by: Optional[Any] = None
-    warn_if_missing: bool = False
+
+    def __post_init__(self):
+        self.recency_column = column(self.recency_column_name)
+        self.orm_model: Base = table(
+            self.table_name, *[column(x) for x in self.extra_columns or []], self.recency_column
+        )
 
     def __lt__(self, other):
-        return self.orm_model.__tablename__ < other.orm_model.__tablename__
+        return self.table_name < other.table_name
 
     @property
     def readable_config(self):
         return dict(
-            table=self.orm_model.__tablename__,
+            table=self.orm_model.name,
             recency_column=str(self.recency_column),
             keep_last=self.keep_last,
             keep_last_filters=[str(x) for x in self.keep_last_filters] if self.keep_last_filters else None,
             keep_last_group_by=str(self.keep_last_group_by),
-            warn_if_missing=str(self.warn_if_missing),
         )
 
 
 config_list: List[_TableConfig] = [
-    _TableConfig(orm_model=BaseJob, recency_column=BaseJob.latest_heartbeat),
-    _TableConfig(orm_model=DagModel, recency_column=DagModel.last_parsed_time),
+    _TableConfig(table_name='job', recency_column_name='latest_heartbeat'),
+    _TableConfig(table_name='dag', recency_column_name='last_parsed_time'),
     _TableConfig(
-        orm_model=DagRun,
-        recency_column=DagRun.start_date,
+        table_name='dag_run',
+        recency_column_name='start_date',
+        extra_columns=['dag_id', 'external_trigger'],
         keep_last=True,
-        keep_last_filters=[DagRun.external_trigger == false()],
-        keep_last_group_by=DagRun.dag_id,
-    ),
-    _TableConfig(orm_model=models_ImportError, recency_column=models_ImportError.timestamp),
-    _TableConfig(orm_model=Log, recency_column=Log.dttm),
-    _TableConfig(
-        orm_model=RenderedTaskInstanceFields, recency_column=RenderedTaskInstanceFields.execution_date
+        keep_last_filters=[column('external_trigger') == false()],
+        keep_last_group_by=['dag_id'],
     ),
-    _TableConfig(
-        orm_model=SensorInstance, recency_column=SensorInstance.updated_at
-    ),  # TODO: add FK to task instance / dag so we can remove here
-    _TableConfig(orm_model=SlaMiss, recency_column=SlaMiss.timestamp),
-    _TableConfig(orm_model=TaskFail, recency_column=TaskFail.start_date),
-    _TableConfig(orm_model=TaskInstance, recency_column=TaskInstance.start_date),
-    _TableConfig(orm_model=TaskReschedule, recency_column=TaskReschedule.start_date),
-    _TableConfig(orm_model=XCom, recency_column=XCom.timestamp),
-    _TableConfig(orm_model=DbCallbackRequest, recency_column=XCom.timestamp),
+    _TableConfig(table_name='import_error', recency_column_name='timestamp'),
+    _TableConfig(table_name='log', recency_column_name='dttm'),
+    _TableConfig(table_name='rendered_task_instance_fields', recency_column_name='execution_date'),
+    _TableConfig(table_name='sensor_instance', recency_column_name='updated_at'),
+    _TableConfig(table_name='sla_miss', recency_column_name='timestamp'),
+    _TableConfig(table_name='task_fail', recency_column_name='start_date'),
+    _TableConfig(table_name='task_instance', recency_column_name='start_date'),
+    _TableConfig(table_name='task_reschedule', recency_column_name='start_date'),
+    _TableConfig(table_name='xcom', recency_column_name='timestamp'),
+    _TableConfig(table_name='callback_request', recency_column_name='created_at'),
+    _TableConfig(table_name='celery_taskmeta', recency_column_name='date_done'),
+    _TableConfig(table_name='celery_tasksetmeta', recency_column_name='date_done'),
 ]
-try:
-    from celery.backends.database.models import Task, TaskSet
-
-    config_list.extend(
-        [
-            _TableConfig(orm_model=Task, recency_column=Task.date_done, warn_if_missing=True),
-            _TableConfig(orm_model=TaskSet, recency_column=TaskSet.date_done, warn_if_missing=True),
-        ]
-    )
-except ImportError:
-    pass
 
-config_dict: Dict[str, _TableConfig] = {x.orm_model.__tablename__: x for x in sorted(config_list)}
+config_dict: Dict[str, _TableConfig] = {x.orm_model.name: x for x in sorted(config_list)}
 
 
-def _print_entities(*, query: "Query", print_rows=False):
+def _check_for_rows(*, query: "Query", print_rows=False):  # returns the number of matching rows
     num_entities = query.count()
     print(f"Found {num_entities} rows meeting deletion criteria.")
-    if not print_rows:
-        return
-    max_rows_to_print = 100
-    if num_entities > 0:
-        print(f"Printing first {max_rows_to_print} rows.")
-    logger.debug("print entities query: %s", query)
-    for entry in query.limit(max_rows_to_print):
-        print(entry.__dict__)
+    if print_rows:
+        max_rows_to_print = 100
+        if num_entities > 0:
+            print(f"Printing first {max_rows_to_print} rows.")
+        logger.debug("print entities query: %s", query)
+        for entry in query.limit(max_rows_to_print):
+            print(entry)  # rows come from a text() entity, not ORM objects, so they have no __dict__
+    return num_entities
+
 
+def _do_delete(*, query, orm_model, skip_archive, session):
+    import re
+    from datetime import datetime
 
-def _do_delete(*, query, session):
     print("Performing Delete...")
     # using bulk delete
-    query.delete(synchronize_session=False)
+    # archive first: CTAS copies every matching row into a new, timestamped table
+    timestamp_str = re.sub(r'[^\d]', '', datetime.utcnow().isoformat())[:14]
+    target_table_name = f'_airflow_deleted__{orm_model.name}__{timestamp_str}'
+    print(f"Moving data to table {target_table_name}")
+    stmt = CreateTableAs(target_table_name, query.selectable)
+    logger.debug("ctas query:\n%s", stmt.compile())
+    session.execute(stmt)
+    session.commit()
+
+    # delete from the source table every row whose primary key is now in the archive table
+    metadata = reflect_tables([orm_model.name, target_table_name], session)
+    source_table = metadata.tables[orm_model.name]
+    target_table = metadata.tables[target_table_name]
+    logger.debug("rows moved; purging from %s", source_table.name)
+    bind = session.get_bind()
+    dialect_name = bind.dialect.name
+    if dialect_name == 'sqlite':  # sqlite has no multi-table DELETE; match primary keys with IN
+        pk_cols = source_table.primary_key.columns
+        delete = source_table.delete().where(
+            tuple_(*pk_cols).in_(
+                session.query(*[target_table.c[x.name] for x in pk_cols]).subquery()
+            )
+        )
+    else:  # WHERE references the archive table, so SQLAlchemy renders a multi-table DELETE
+        delete = source_table.delete().where(
+            and_(col == target_table.c[col.name] for col in source_table.primary_key.columns)
+        )
+    logger.debug("delete statement:\n%s", delete.compile())
+    session.execute(delete)
+    session.commit()
+    if skip_archive:  # archive was still needed above to identify which rows to delete
+        target_table.drop(bind=bind)
     session.commit()
     print("Finished Performing Delete")
 
 
-def _subquery_keep_last(*, recency_column, keep_last_filters, keep_last_group_by, session):
-    subquery = session.query(func.max(recency_column))
+def _subquery_keep_last(*, recency_column, keep_last_filters, group_by_columns, max_date_colname, session):
+    subquery = session.query(*(group_by_columns or []), func.max(recency_column).label(max_date_colname))
 
     if keep_last_filters is not None:
         for entry in keep_last_filters:
             subquery = subquery.filter(entry)
 
-    if keep_last_group_by is not None:
-        subquery = subquery.group_by(keep_last_group_by)
+    if group_by_columns is not None:
+        subquery = subquery.group_by(*group_by_columns)
+
+    return subquery.subquery(name='latest')  # group-by cols + max(recency_column) labelled max_date_colname
 
-    # We nest this subquery to work around a MySQL "table specified twice" issue
-    # See https://github.com/teamclairvoyant/airflow-maintenance-dags/issues/41
-    # and https://github.com/teamclairvoyant/airflow-maintenance-dags/pull/57/files.
-    subquery = subquery.from_self()
-    return subquery
+
+class CreateTableAs(Executable, ClauseElement):
+    """Executable SQLAlchemy clause for ``CREATE TABLE <name> AS <query>``; rendered via ``@compiles``."""
+
+    def __init__(self, name, query):
+        self.name = name  # interpolated into the rendered SQL unquoted
+        self.query = query  # a selectable, e.g. ``Query.selectable``
+
+
+@compiles(CreateTableAs)  # default rendering for every dialect without a specific override
+def _compile_create_table_as__other(element, compiler, **kw):
+    return f"CREATE TABLE {element.name} AS {compiler.process(element.query)}"
+
+
+@compiles(CreateTableAs, 'mssql')  # SQL Server lacks CREATE TABLE ... AS SELECT; use SELECT ... INTO
+def _compile_create_table_as__mssql(element, compiler, **kw):
+    return f"WITH cte AS ( {compiler.process(element.query)} ) SELECT * INTO {element.name} FROM cte"
 
 
 def _build_query(
@@ -182,23 +208,33 @@ def _build_query(
     session,
     **kwargs,
 ):
-    query = session.query(orm_model)
-    conditions = [recency_column < clean_before_timestamp]
+    base_table_alias = 'base'
+    base_table = aliased(orm_model, name=base_table_alias)
+    query = session.query(base_table).with_entities(text(f"{base_table_alias}.*"))
+    base_table_recency_col = base_table.c[recency_column.name]
+    conditions = [base_table_recency_col < clean_before_timestamp]
     if keep_last:
+        max_date_col_name = 'max_date_per_group'
+        group_by_columns = [column(x) for x in keep_last_group_by]
         subquery = _subquery_keep_last(
             recency_column=recency_column,
             keep_last_filters=keep_last_filters,
-            keep_last_group_by=keep_last_group_by,
+            group_by_columns=group_by_columns,
+            max_date_colname=max_date_col_name,
             session=session,
         )
-        conditions.append(recency_column.notin_(subquery))
+        query = query.select_from(base_table).outerjoin(
+            subquery,
+            and_(
+                *[base_table.c[x] == subquery.c[x] for x in keep_last_group_by],
+                base_table_recency_col == column(max_date_col_name),
+            ),
+        )
+        conditions.append(column(max_date_col_name).is_(None))
     query = query.filter(and_(*conditions))
     return query
 
 
-logger = logging.getLogger(__file__)
-
-
 def _cleanup_table(
     *,
     orm_model,
@@ -209,12 +245,13 @@ def _cleanup_table(
     clean_before_timestamp,
     dry_run=True,
     verbose=False,
+    skip_archive=False,
     session=None,
     **kwargs,
 ):
     print()
     if dry_run:
-        print(f"Performing dry run for table {orm_model.__tablename__!r}")
+        print(f"Performing dry run for table {orm_model.name}")
     query = _build_query(
         orm_model=orm_model,
         recency_column=recency_column,
@@ -224,12 +261,14 @@ def _cleanup_table(
         clean_before_timestamp=clean_before_timestamp,
         session=session,
     )
+    logger.debug("old rows query:\n%s", query.selectable.compile())
+    print(f"Checking table {orm_model.name}")
+    num_rows = _check_for_rows(query=query, print_rows=False)
 
-    _print_entities(query=query, print_rows=False)
+    if num_rows and not dry_run:
+        _do_delete(query=query, orm_model=orm_model, skip_archive=skip_archive, session=session)
 
-    if not dry_run:
-        _do_delete(query=query, session=session)
-        session.commit()
+    session.commit()
 
 
 def _confirm_delete(*, date: DateTime, tables: List[str]):
@@ -251,19 +290,20 @@ def _print_config(*, configs: Dict[str, _TableConfig]):
     AirflowConsole().print_as_table(data=data)
 
 
-class _warn_if_missing(AbstractContextManager):
-    def __init__(self, table, suppress):
-        self.table = table
-        self.suppress = suppress
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exctype, excinst, exctb):
-        caught_error = exctype is not None and issubclass(exctype, (OperationalError, ProgrammingError))
-        if caught_error:
-            logger.warning("Table %r not found.  Skipping.", self.table)
-        return caught_error
+@contextmanager
+def _suppress_with_logging(table, session):
+    """
+    Suppress OperationalError/ProgrammingError raised in the block, logging a warning naming ``table``.
+    The traceback is logged at debug level and, if the session is still active, it is rolled back.
+    """
+    try:
+        yield
+    except (OperationalError, ProgrammingError):  # any other exception propagates
+        logger.warning("Encountered error when attempting to clean table '%s'. ", table)
+        logger.debug("Traceback for table '%s'", table, exc_info=True)
+        if session.is_active:
+            logger.debug('Rolling back transaction')
+            session.rollback()
 
 
 @provide_session
@@ -274,6 +314,7 @@ def run_cleanup(
     dry_run: bool = False,
     verbose: bool = False,
     confirm: bool = True,
+    skip_archive: bool = False,
     session: 'Session' = NEW_SESSION,
 ):
     """
@@ -292,6 +333,7 @@ def run_cleanup(
     :param dry_run: If true, print rows meeting deletion criteria
     :param verbose: If true, may provide more detailed output.
     :param confirm: Require user input to confirm before processing deletions.
+    :param skip_archive: Set to True if you don't want the purged rows preserved in an archive table.
     :param session: Session representing connection to the metadata database.
     """
     clean_before_timestamp = timezone.coerce_datetime(clean_before_timestamp)
@@ -306,12 +348,18 @@ def run_cleanup(
         _print_config(configs=effective_config_dict)
     if not dry_run and confirm:
         _confirm_delete(date=clean_before_timestamp, tables=list(effective_config_dict.keys()))
+    existing_tables = reflect_tables(tables=None, session=session).tables
     for table_name, table_config in effective_config_dict.items():
-        with _warn_if_missing(table_name, table_config.warn_if_missing):
+        if table_name not in existing_tables:
+            logger.warning("Table %s not found.  Skipping.", table_name)
+            continue
+        with _suppress_with_logging(table_name, session):
             _cleanup_table(
                 clean_before_timestamp=clean_before_timestamp,
                 dry_run=dry_run,
                 verbose=verbose,
                 **table_config.__dict__,
+                skip_archive=skip_archive,
                 session=session,
             )
+            session.commit()
diff --git a/docs/apache-airflow/usage-cli.rst b/docs/apache-airflow/usage-cli.rst
index 0e7b1b5455..c14efacb1d 100644
--- a/docs/apache-airflow/usage-cli.rst
+++ b/docs/apache-airflow/usage-cli.rst
@@ -215,6 +215,8 @@ You can optionally provide a list of tables to perform deletes on. If no list of
 
 You can use the ``--dry-run`` option to print the row counts in the primary tables to be cleaned.
 
+By default, ``db clean`` archives purged rows in tables named ``_airflow_deleted__<table>__<timestamp>``.  If you don't want the data preserved this way, supply the ``--skip-archive`` argument.
+
 Beware cascading deletes
 ^^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/newsfragments/23574.feature.rst b/newsfragments/23574.feature.rst
new file mode 100644
index 0000000000..805b7b18bd
--- /dev/null
+++ b/newsfragments/23574.feature.rst
@@ -0,0 +1 @@
+Command ``airflow db clean`` now archives data before purging.
diff --git a/tests/cli/commands/test_db_command.py b/tests/cli/commands/test_db_command.py
index 125e5d7c3e..e6e93f6a1c 100644
--- a/tests/cli/commands/test_db_command.py
+++ b/tests/cli/commands/test_db_command.py
@@ -293,6 +293,7 @@ class TestCLIDBClean:
             clean_before_timestamp=pendulum.parse(timestamp, tz=timezone),
             verbose=False,
             confirm=False,
+            skip_archive=False,
         )
 
     @pytest.mark.parametrize('timezone', ['UTC', 'Europe/Berlin', 'America/Los_Angeles'])
@@ -312,13 +313,14 @@ class TestCLIDBClean:
             clean_before_timestamp=pendulum.parse(timestamp),
             verbose=False,
             confirm=False,
+            skip_archive=False,
         )
 
     @pytest.mark.parametrize('confirm_arg, expected', [(['-y'], False), ([], True)])
     @patch('airflow.cli.commands.db_command.run_cleanup')
     def test_confirm(self, run_cleanup_mock, confirm_arg, expected):
         """
-        When tz included in the string then default timezone should not be used.
+        When ``-y`` provided, ``confirm`` should be false.
         """
         args = self.parser.parse_args(
             [
@@ -337,6 +339,33 @@ class TestCLIDBClean:
             clean_before_timestamp=pendulum.parse('2021-01-01 00:00:00Z'),
             verbose=False,
             confirm=expected,
+            skip_archive=False,
+        )
+
+    @pytest.mark.parametrize('extra_arg, expected', [(['--skip-archive'], True), ([], False)])
+    @patch('airflow.cli.commands.db_command.run_cleanup')
+    def test_skip_archive(self, run_cleanup_mock, extra_arg, expected):
+        """
+        When ``--skip-archive`` provided, ``skip_archive`` should be True (False otherwise).
+        """
+        args = self.parser.parse_args(
+            [
+                'db',
+                'clean',
+                '--clean-before-timestamp',
+                '2021-01-01',
+                *extra_arg,
+            ]
+        )
+        db_command.cleanup_tables(args)
+
+        run_cleanup_mock.assert_called_once_with(
+            table_names=None,
+            dry_run=False,
+            clean_before_timestamp=pendulum.parse('2021-01-01 00:00:00Z'),
+            verbose=False,
+            confirm=True,
+            skip_archive=expected,
         )
 
     @pytest.mark.parametrize('dry_run_arg, expected', [(['--dry-run'], True), ([], False)])
@@ -362,6 +391,7 @@ class TestCLIDBClean:
             clean_before_timestamp=pendulum.parse('2021-01-01 00:00:00Z'),
             verbose=False,
             confirm=True,
+            skip_archive=False,
         )
 
     @pytest.mark.parametrize(
@@ -389,6 +419,7 @@ class TestCLIDBClean:
             clean_before_timestamp=pendulum.parse('2021-01-01 00:00:00Z'),
             verbose=False,
             confirm=True,
+            skip_archive=False,
         )
 
     @pytest.mark.parametrize('extra_args, expected', [(['--verbose'], True), ([], False)])
@@ -414,4 +445,5 @@ class TestCLIDBClean:
             clean_before_timestamp=pendulum.parse('2021-01-01 00:00:00Z'),
             verbose=expected,
             confirm=True,
+            skip_archive=False,
         )
diff --git a/tests/test_utils/db.py b/tests/test_utils/db.py
index b7502fc52b..ae4a1d6598 100644
--- a/tests/test_utils/db.py
+++ b/tests/test_utils/db.py
@@ -38,7 +38,7 @@ from airflow.models import (
 from airflow.models.dagcode import DagCode
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.security.permissions import RESOURCE_DAG_PREFIX
-from airflow.utils.db import add_default_pool_if_not_exists, create_default_connections
+from airflow.utils.db import add_default_pool_if_not_exists, create_default_connections, reflect_tables
 from airflow.utils.session import create_session
 from airflow.www.fab_security.sqla.models import Permission, Resource, assoc_permission_role
 
@@ -57,6 +57,14 @@ def clear_db_dags():
         session.query(DagModel).delete()
 
 
+def drop_tables_with_prefix(prefix):  # drop every reflected table whose name starts with ``prefix``
+    with create_session() as session:
+        metadata = reflect_tables(None, session)
+        for table_name, table in metadata.tables.items():
+            if table_name.startswith(prefix):
+                table.drop(bind=session.get_bind())  # explicit bind; don't rely on bound MetaData
+
+
 def clear_db_serialized_dags():
     with create_session() as session:
         session.query(SerializedDagModel).delete()
diff --git a/tests/utils/test_db_cleanup.py b/tests/utils/test_db_cleanup.py
index 8d227df6e5..e335cdb251 100644
--- a/tests/utils/test_db_cleanup.py
+++ b/tests/utils/test_db_cleanup.py
@@ -30,7 +30,7 @@ from airflow.models import DagModel, DagRun, TaskInstance
 from airflow.operators.python import PythonOperator
 from airflow.utils.db_cleanup import _build_query, _cleanup_table, config_dict, run_cleanup
 from airflow.utils.session import create_session
-from tests.test_utils.db import clear_db_dags, clear_db_runs
+from tests.test_utils.db import clear_db_dags, clear_db_runs, drop_tables_with_prefix
 
 
 @pytest.fixture(autouse=True)
@@ -44,6 +44,10 @@ def clean_database():
 
 
 class TestDBCleanup:
+    @pytest.fixture(autouse=True)
+    def clear_airflow_tables(self):
+        drop_tables_with_prefix('_airflow_')
+
     @pytest.mark.parametrize(
         'kwargs, called',
         [
@@ -68,6 +72,27 @@ class TestDBCleanup:
         else:
             confirm_delete_mock.assert_not_called()
 
+    @pytest.mark.parametrize(
+        'kwargs, should_skip',
+        [
+            param(dict(skip_archive=True), True, id='true'),
+            param(dict(), False, id='not supplied'),
+            param(dict(skip_archive=False), False, id='false'),
+        ],
+    )
+    @patch('airflow.utils.db_cleanup._cleanup_table')
+    def test_run_cleanup_skip_archive(self, cleanup_table_mock, kwargs, should_skip):
+        """test that ``skip_archive`` is forwarded to ``_cleanup_table`` (False when not supplied)"""
+        run_cleanup(
+            clean_before_timestamp=None,
+            table_names=['log'],
+            dry_run=None,
+            verbose=None,
+            confirm=False,
+            **kwargs,
+        )
+        assert cleanup_table_mock.call_args[1]['skip_archive'] is should_skip
+
     @pytest.mark.parametrize(
         'table_names',
         [
@@ -95,12 +120,14 @@ class TestDBCleanup:
         [None, True, False],
     )
     @patch('airflow.utils.db_cleanup._build_query', MagicMock())
-    @patch('airflow.utils.db_cleanup._print_entities', MagicMock())
-    @patch('airflow.utils.db_cleanup._do_delete')
     @patch('airflow.utils.db_cleanup._confirm_delete', MagicMock())
-    def test_run_cleanup_dry_run(self, do_delete, dry_run):
+    @patch('airflow.utils.db_cleanup._check_for_rows')
+    @patch('airflow.utils.db_cleanup._do_delete')
+    def test_run_cleanup_dry_run(self, do_delete, check_rows_mock, dry_run):
         """Delete should only be called when not dry_run"""
+        check_rows_mock.return_value = 10
         base_kwargs = dict(
+            table_names=['log'],
             clean_before_timestamp=None,
             dry_run=dry_run,
             verbose=None,
@@ -135,7 +162,7 @@ class TestDBCleanup:
         dag run is kept.
 
         """
-        base_date = pendulum.DateTime(2022, 1, 1, tzinfo=pendulum.timezone('America/Los_Angeles'))
+        base_date = pendulum.DateTime(2022, 1, 1, tzinfo=pendulum.timezone('UTC'))
         create_tis(
             base_date=base_date,
             num_tis=10,
@@ -175,7 +202,7 @@ class TestDBCleanup:
         associated dag runs should remain.
 
         """
-        base_date = pendulum.DateTime(2022, 1, 1, tzinfo=pendulum.timezone('America/Los_Angeles'))
+        base_date = pendulum.DateTime(2022, 1, 1, tzinfo=pendulum.timezone('UTC'))
         num_tis = 10
         create_tis(
             base_date=base_date,
@@ -189,13 +216,14 @@ class TestDBCleanup:
                 clean_before_timestamp=clean_before_date,
                 dry_run=False,
                 session=session,
+                table_names=['dag_run', 'task_instance'],
             )
             model = config_dict[table_name].orm_model
             expected_remaining = num_tis - expected_to_delete
             assert len(session.query(model).all()) == expected_remaining
-            if model == TaskInstance:
+            if model.name == 'task_instance':
                 assert len(session.query(DagRun).all()) == num_tis
-            elif model == DagRun:
+            elif model.name == 'dag_run':
                 assert len(session.query(TaskInstance).all()) == expected_remaining
             else:
                 raise Exception("unexpected")