Posted to commits@airflow.apache.org by as...@apache.org on 2022/09/06 18:48:29 UTC
[airflow] branch main updated: Support multiple DagProcessors parsing files from different locations. (#25935)
This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f878854fff Support multiple DagProcessors parsing files from different locations. (#25935)
f878854fff is described below
commit f878854fff0f5bc577eab70c61e69e353aed43c3
Author: mhenc <mh...@google.com>
AuthorDate: Tue Sep 6 20:48:00 2022 +0200
Support multiple DagProcessors parsing files from different locations. (#25935)
* Separate dag processors
* Introduce [scheduler]dag_stale_not_seen_duration configuration option
* Remove DagProcessorDirectory class and pass dag_directory as parameter
* Rename dag_directory column to processor_subdir in CallbackRequests
Co-authored-by: Ash Berlin-Taylor <as...@firemirror.com>
---
airflow/callbacks/callback_requests.py | 28 +++-
airflow/config_templates/config.yml | 8 ++
airflow/config_templates/default_airflow.cfg | 4 +
airflow/dag_processing/manager.py | 71 ++++++----
airflow/dag_processing/processor.py | 15 +-
airflow/jobs/scheduler_job.py | 48 ++++++-
..._2_4_0_add_processor_subdir_to_dagmodel_and_.py | 76 ++++++++++
airflow/models/dag.py | 24 +++-
airflow/models/dagbag.py | 6 +-
airflow/models/dagrun.py | 12 ++
airflow/models/db_callback_request.py | 3 +-
airflow/models/serialized_dag.py | 46 ++++--
docs/apache-airflow/migrations-ref.rst | 5 +-
tests/callbacks/test_callback_requests.py | 12 +-
tests/conftest.py | 15 +-
tests/dag_processing/test_manager.py | 154 +++++++++++++++++++--
tests/dag_processing/test_processor.py | 90 +++++++-----
tests/jobs/test_scheduler_job.py | 77 +++++++++--
tests/models/test_dag.py | 6 +-
tests/models/test_dagbag.py | 6 +-
tests/models/test_dagrun.py | 6 +
tests/models/test_serialized_dag.py | 29 +++-
tests/test_utils/perf/perf_kit/sqlalchemy.py | 2 +-
tests/www/views/test_views_home.py | 2 +-
24 files changed, 613 insertions(+), 132 deletions(-)
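
For context before the diff, here is a minimal sketch (not part of this commit; paths hypothetical, and a configured Airflow environment is assumed) of what the change enables: each standalone DAG processor owns one directory, and every row it writes (DagModel, serialized DAGs, callbacks) is stamped with that directory so other processors leave it alone. The constructor arguments mirror the ones used in the tests further down; in a real deployment this is driven by the standalone `airflow dag-processor` command rather than constructed directly.

    from datetime import timedelta
    from pathlib import Path

    from airflow.dag_processing.manager import DagFileProcessorManager

    # Hypothetical per-team DAG folder; this manager only touches rows whose
    # processor_subdir matches its own directory.
    manager_a = DagFileProcessorManager(
        dag_directory=Path("/opt/airflow/team_a/dags"),
        max_runs=1,
        processor_timeout=timedelta(minutes=10),
        signal_conn=None,  # standalone mode: no scheduler pipe
        dag_ids=[],
        pickle_dags=False,
        async_mode=False,
    )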
diff --git a/airflow/callbacks/callback_requests.py b/airflow/callbacks/callback_requests.py
index b04a201c08..3e274368fb 100644
--- a/airflow/callbacks/callback_requests.py
+++ b/airflow/callbacks/callback_requests.py
@@ -28,10 +28,17 @@ class CallbackRequest:
:param full_filepath: File Path to use to run the callback
:param msg: Additional Message that can be used for logging
+ :param processor_subdir: Directory used by the DAG processor when it parsed the DAG.
"""
- def __init__(self, full_filepath: str, msg: Optional[str] = None):
+ def __init__(
+ self,
+ full_filepath: str,
+ processor_subdir: Optional[str] = None,
+ msg: Optional[str] = None,
+ ):
self.full_filepath = full_filepath
+ self.processor_subdir = processor_subdir
self.msg = msg
def __eq__(self, other):
@@ -60,6 +67,7 @@ class TaskCallbackRequest(CallbackRequest):
:param simple_task_instance: Simplified Task Instance representation
:param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback
:param msg: Additional Message that can be used for logging to determine failure/zombie
+ :param processor_subdir: Directory used by the DAG processor when it parsed the DAG.
"""
def __init__(
@@ -67,9 +75,10 @@ class TaskCallbackRequest(CallbackRequest):
full_filepath: str,
simple_task_instance: "SimpleTaskInstance",
is_failure_callback: Optional[bool] = True,
+ processor_subdir: Optional[str] = None,
msg: Optional[str] = None,
):
- super().__init__(full_filepath=full_filepath, msg=msg)
+ super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg)
self.simple_task_instance = simple_task_instance
self.is_failure_callback = is_failure_callback
@@ -94,6 +103,7 @@ class DagCallbackRequest(CallbackRequest):
:param full_filepath: File Path to use to run the callback
:param dag_id: DAG ID
:param run_id: Run ID for the DagRun
+ :param processor_subdir: Directory used by the DAG processor when it parsed the DAG.
:param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback
:param msg: Additional Message that can be used for logging
"""
@@ -103,10 +113,11 @@ class DagCallbackRequest(CallbackRequest):
full_filepath: str,
dag_id: str,
run_id: str,
+ processor_subdir: Optional[str],
is_failure_callback: Optional[bool] = True,
msg: Optional[str] = None,
):
- super().__init__(full_filepath=full_filepath, msg=msg)
+ super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg)
self.dag_id = dag_id
self.run_id = run_id
self.is_failure_callback = is_failure_callback
@@ -118,8 +129,15 @@ class SlaCallbackRequest(CallbackRequest):
:param full_filepath: File Path to use to run the callback
:param dag_id: DAG ID
+ :param processor_subdir: Directory used by the DAG processor when it parsed the DAG.
"""
- def __init__(self, full_filepath: str, dag_id: str, msg: Optional[str] = None):
- super().__init__(full_filepath, msg)
+ def __init__(
+ self,
+ full_filepath: str,
+ dag_id: str,
+ processor_subdir: Optional[str],
+ msg: Optional[str] = None,
+ ):
+ super().__init__(full_filepath, processor_subdir=processor_subdir, msg=msg)
self.dag_id = dag_id
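As a quick illustration (a sketch, not taken from this commit; the path and IDs are made up), the new field rides along through the existing JSON round trip, which is what lets a standalone processor later claim only its own callbacks:

    from airflow.callbacks.callback_requests import DagCallbackRequest

    request = DagCallbackRequest(
        full_filepath="/opt/airflow/team_a/dags/etl.py",
        dag_id="etl",
        run_id="manual__2022-09-06T00:00:00",
        processor_subdir="/opt/airflow/team_a/dags",
        is_failure_callback=True,
        msg="timed_out",
    )
    # The subdir survives (de)serialization together with the rest of the request.
    restored = DagCallbackRequest.from_json(request.to_json())
    assert restored.processor_subdir == "/opt/airflow/team_a/dags"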
diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index 13b299d1fe..4537b62d6b 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -2121,6 +2121,14 @@
type: integer
example: ~
default: "20"
+ - name: dag_stale_not_seen_duration
+ description: |
+ Only applicable if `[scheduler]standalone_dag_processor` is true.
+ Time in seconds after which DAGs that were not updated by the DAG processor are deactivated.
+ version_added: 2.4.0
+ type: integer
+ example: ~
+ default: "600"
- name: use_job_schedule
description: |
Turn off scheduler use of cron intervals by setting this to False.
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg
index 119873d9f2..7cd116369e 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -1077,6 +1077,10 @@ standalone_dag_processor = False
# in database. Contains maximum number of callbacks that are fetched during a single loop.
max_callbacks_per_loop = 20
+# Only applicable if `[scheduler]standalone_dag_processor` is true.
+# Time in seconds after which DAGs that were not updated by the DAG processor are deactivated.
+dag_stale_not_seen_duration = 600
+
# Turn off scheduler use of cron intervals by setting this to False.
# DAGs submitted manually in the web UI or with trigger_dag will still run.
use_job_schedule = True
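The new option follows the usual Airflow configuration precedence, so it can also be set through the standard environment-variable form; a small sketch (the value shown is just the default from this commit):

    import os

    # Equivalent to setting dag_stale_not_seen_duration = 600 under [scheduler]
    # in airflow.cfg.
    os.environ["AIRFLOW__SCHEDULER__DAG_STALE_NOT_SEEN_DURATION"] = "600"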
diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py
index ecd1f24434..fd1098899e 100644
--- a/airflow/dag_processing/manager.py
+++ b/airflow/dag_processing/manager.py
@@ -31,7 +31,8 @@ from collections import defaultdict
from datetime import datetime, timedelta
from importlib import import_module
from multiprocessing.connection import Connection as MultiprocessingConnection
-from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union, cast
+from pathlib import Path
+from typing import Any, Dict, List, NamedTuple, Optional, Union, cast
from setproctitle import setproctitle
from sqlalchemy.orm import Session
@@ -57,9 +58,6 @@ from airflow.utils.process_utils import (
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import prohibit_commit, skip_locked, with_row_locks
-if TYPE_CHECKING:
- import pathlib
-
class DagParsingStat(NamedTuple):
"""Information on processing progress"""
@@ -107,7 +105,7 @@ class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin):
def __init__(
self,
- dag_directory: str,
+ dag_directory: os.PathLike,
max_runs: int,
processor_timeout: timedelta,
dag_ids: Optional[List[str]],
@@ -116,7 +114,7 @@ class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin):
):
super().__init__()
self._file_path_queue: List[str] = []
- self._dag_directory: str = dag_directory
+ self._dag_directory: os.PathLike = dag_directory
self._max_runs = max_runs
self._processor_timeout = processor_timeout
self._dag_ids = dag_ids
@@ -205,7 +203,7 @@ class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin):
@staticmethod
def _run_processor_manager(
- dag_directory: str,
+ dag_directory: os.PathLike,
max_runs: int,
processor_timeout: timedelta,
signal_conn: MultiprocessingConnection,
@@ -368,7 +366,7 @@ class DagFileProcessorManager(LoggingMixin):
def __init__(
self,
- dag_directory: Union[str, "pathlib.Path"],
+ dag_directory: os.PathLike,
max_runs: int,
processor_timeout: timedelta,
dag_ids: Optional[List[str]],
@@ -379,7 +377,6 @@ class DagFileProcessorManager(LoggingMixin):
super().__init__()
self._file_paths: List[str] = []
self._file_path_queue: List[str] = []
- self._dag_directory = dag_directory
self._max_runs = max_runs
# signal_conn is None for dag_processor_standalone mode.
self._direct_scheduler_conn = signal_conn
@@ -387,6 +384,7 @@ class DagFileProcessorManager(LoggingMixin):
self._dag_ids = dag_ids
self._async_mode = async_mode
self._parsing_start_time: Optional[int] = None
+ self._dag_directory = dag_directory
# Set the signal conn in to non-blocking mode, so that attempting to
# send when the buffer is full errors, rather than hangs for-ever
@@ -397,6 +395,7 @@ class DagFileProcessorManager(LoggingMixin):
if self._async_mode and self._direct_scheduler_conn is not None:
os.set_blocking(self._direct_scheduler_conn.fileno(), False)
+ self.standalone_dag_processor = conf.getboolean("scheduler", "standalone_dag_processor")
self._parallelism = conf.getint('scheduler', 'parsing_processes')
if (
conf.get_mandatory_value('database', 'sql_alchemy_conn').startswith('sqlite')
@@ -498,11 +497,13 @@ class DagFileProcessorManager(LoggingMixin):
fp: self.get_last_finish_time(fp) for fp in self.file_paths if self.get_last_finish_time(fp)
}
to_deactivate = set()
- dags_parsed = (
- session.query(DagModel.dag_id, DagModel.fileloc, DagModel.last_parsed_time)
- .filter(DagModel.is_active)
- .all()
+ query = session.query(DagModel.dag_id, DagModel.fileloc, DagModel.last_parsed_time).filter(
+ DagModel.is_active
)
+ if self.standalone_dag_processor:
+ query = query.filter(DagModel.processor_subdir == self.get_dag_directory())
+ dags_parsed = query.all()
+
for dag in dags_parsed:
# The largest valid difference between a DagFileStat's last_finished_time and a DAG's
# last_parsed_time is _processor_timeout. Longer than that indicates that the DAG is
@@ -540,7 +541,7 @@ class DagFileProcessorManager(LoggingMixin):
self._refresh_dag_dir()
self.prepare_file_path_queue()
max_callbacks_per_loop = conf.getint("scheduler", "max_callbacks_per_loop")
- standalone_dag_processor = conf.getboolean("scheduler", "standalone_dag_processor")
+
if self._async_mode:
# If we're in async mode, we can start up straight away. If we're
# in sync mode we need to be told to start a "loop"
@@ -591,7 +592,7 @@ class DagFileProcessorManager(LoggingMixin):
self.waitables.pop(sentinel)
self._processors.pop(processor.file_path)
- if standalone_dag_processor:
+ if self.standalone_dag_processor:
self._fetch_callbacks(max_callbacks_per_loop)
self._deactivate_stale_dags()
DagWarning.purge_inactive_dag_warnings()
@@ -661,11 +662,12 @@ class DagFileProcessorManager(LoggingMixin):
"""Fetches callbacks from database and add them to the internal queue for execution."""
self.log.debug("Fetching callbacks from the database.")
with prohibit_commit(session) as guard:
- query = (
- session.query(DbCallbackRequest)
- .order_by(DbCallbackRequest.priority_weight.asc())
- .limit(max_callbacks)
- )
+ query = session.query(DbCallbackRequest)
+ if self.standalone_dag_processor:
+ query = query.filter(
+ DbCallbackRequest.processor_subdir == self.get_dag_directory(),
+ )
+ query = query.order_by(DbCallbackRequest.priority_weight.asc()).limit(max_callbacks)
callbacks = with_row_locks(
query, of=DbCallbackRequest, session=session, **skip_locked(session=session)
).all()
@@ -743,7 +745,10 @@ class DagFileProcessorManager(LoggingMixin):
else:
dag_filelocs.append(fileloc)
- SerializedDagModel.remove_deleted_dags(dag_filelocs)
+ SerializedDagModel.remove_deleted_dags(
+ alive_dag_filelocs=dag_filelocs,
+ processor_subdir=self.get_dag_directory(),
+ )
DagModel.deactivate_deleted_dags(self._file_paths)
from airflow.models.dagcode import DagCode
@@ -913,6 +918,16 @@ class DagFileProcessorManager(LoggingMixin):
stat = self._file_stats.get(file_path)
return stat.run_count if stat else 0
+ def get_dag_directory(self) -> str:
+ """
+ Return the dag_directory as a string.
+ :rtype: str
+ """
+ if isinstance(self._dag_directory, Path):
+ return str(self._dag_directory.resolve())
+ else:
+ return str(self._dag_directory)
+
def set_file_paths(self, new_file_paths):
"""
Update this with a new set of paths to DAG definition files.
@@ -986,10 +1001,14 @@ class DagFileProcessorManager(LoggingMixin):
self.log.debug("%s file paths queued for processing", len(self._file_path_queue))
@staticmethod
- def _create_process(file_path, pickle_dags, dag_ids, callback_requests):
+ def _create_process(file_path, pickle_dags, dag_ids, dag_directory, callback_requests):
"""Creates DagFileProcessorProcess instance."""
return DagFileProcessorProcess(
- file_path=file_path, pickle_dags=pickle_dags, dag_ids=dag_ids, callback_requests=callback_requests
+ file_path=file_path,
+ pickle_dags=pickle_dags,
+ dag_ids=dag_ids,
+ dag_directory=dag_directory,
+ callback_requests=callback_requests,
)
def start_new_processes(self):
@@ -1002,7 +1021,11 @@ class DagFileProcessorManager(LoggingMixin):
callback_to_execute_for_file = self._callback_to_execute[file_path]
processor = self._create_process(
- file_path, self._pickle_dags, self._dag_ids, callback_to_execute_for_file
+ file_path,
+ self._pickle_dags,
+ self._dag_ids,
+ self.get_dag_directory(),
+ callback_to_execute_for_file,
)
del self._callback_to_execute[file_path]
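Condensed from the two query changes above, the standalone-mode filtering boils down to one extra predicate; a sketch assuming an initialized metadata database and a hypothetical directory:

    from airflow.models.dag import DagModel
    from airflow.utils.session import create_session

    dag_directory = "/opt/airflow/team_a/dags"  # hypothetical get_dag_directory() result

    with create_session() as session:
        # Same shape as _deactivate_stale_dags: only DAGs this processor owns
        # are candidates for deactivation.
        dags_parsed = (
            session.query(DagModel.dag_id, DagModel.fileloc, DagModel.last_parsed_time)
            .filter(DagModel.is_active)
            .filter(DagModel.processor_subdir == dag_directory)
            .all()
        )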
diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py
index b7dac828e2..fa1eb46c29 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -76,12 +76,14 @@ class DagFileProcessorProcess(LoggingMixin, MultiprocessingStartMethodMixin):
file_path: str,
pickle_dags: bool,
dag_ids: Optional[List[str]],
+ dag_directory: str,
callback_requests: List[CallbackRequest],
):
super().__init__()
self._file_path = file_path
self._pickle_dags = pickle_dags
self._dag_ids = dag_ids
+ self._dag_directory = dag_directory
self._callback_requests = callback_requests
# The process that was launched to process the given file.
@@ -111,6 +113,7 @@ class DagFileProcessorProcess(LoggingMixin, MultiprocessingStartMethodMixin):
pickle_dags: bool,
dag_ids: Optional[List[str]],
thread_name: str,
+ dag_directory: str,
callback_requests: List[CallbackRequest],
) -> None:
"""
@@ -154,7 +157,11 @@ class DagFileProcessorProcess(LoggingMixin, MultiprocessingStartMethodMixin):
threading.current_thread().name = thread_name
log.info("Started process (PID=%s) to work on %s", os.getpid(), file_path)
- dag_file_processor = DagFileProcessor(dag_ids=dag_ids, log=log)
+ dag_file_processor = DagFileProcessor(
+ dag_ids=dag_ids,
+ dag_directory=dag_directory,
+ log=log,
+ )
result: Tuple[int, int] = dag_file_processor.process_file(
file_path=file_path,
pickle_dags=pickle_dags,
@@ -188,6 +195,7 @@ class DagFileProcessorProcess(LoggingMixin, MultiprocessingStartMethodMixin):
self._pickle_dags,
self._dag_ids,
f"DagFileProcessor{self._instance_id}",
+ self._dag_directory,
self._callback_requests,
),
name=f"DagFileProcessor{self._instance_id}-Process",
@@ -356,10 +364,11 @@ class DagFileProcessor(LoggingMixin):
UNIT_TEST_MODE: bool = conf.getboolean('core', 'UNIT_TEST_MODE')
- def __init__(self, dag_ids: Optional[List[str]], log: logging.Logger):
+ def __init__(self, dag_ids: Optional[List[str]], dag_directory: str, log: logging.Logger):
super().__init__()
self.dag_ids = dag_ids
self._log = log
+ self._dag_directory = dag_directory
self.dag_warnings: Set[Tuple[str, str]] = set()
@provide_session
@@ -766,7 +775,7 @@ class DagFileProcessor(LoggingMixin):
session.commit()
# Save individual DAGs in the ORM
- dagbag.sync_to_db(session)
+ dagbag.sync_to_db(processor_subdir=self._dag_directory, session=session)
session.commit()
if pickle_dags:
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index d8b8073674..50dd18d2c2 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -26,6 +26,7 @@ import time
import warnings
from collections import defaultdict
from datetime import datetime, timedelta
+from pathlib import Path
from typing import TYPE_CHECKING, Collection, DefaultDict, Dict, Iterator, List, Optional, Set, Tuple
from sqlalchemy import func, not_, or_, text
@@ -141,6 +142,7 @@ class SchedulerJob(BaseJob):
# How many seconds do we wait for tasks to heartbeat before mark them as zombies.
self._zombie_threshold_secs = conf.getint('scheduler', 'scheduler_zombie_task_threshold')
self._standalone_dag_processor = conf.getboolean("scheduler", "standalone_dag_processor")
+ self._dag_stale_not_seen_duration = conf.getint("scheduler", "dag_stale_not_seen_duration")
self.do_pickle = do_pickle
super().__init__(*args, **kwargs)
@@ -685,6 +687,7 @@ class SchedulerJob(BaseJob):
full_filepath=ti.dag_model.fileloc,
simple_task_instance=SimpleTaskInstance.from_ti(ti),
msg=msg % (ti, state, ti.state, info),
+ processor_subdir=ti.dag_model.processor_subdir,
)
self.executor.send_callback(request)
else:
@@ -708,7 +711,7 @@ class SchedulerJob(BaseJob):
processor_timeout = timedelta(seconds=processor_timeout_seconds)
if not self._standalone_dag_processor:
self.processor_agent = DagFileProcessorAgent(
- dag_directory=self.subdir,
+ dag_directory=Path(self.subdir),
max_runs=self.num_times_parse_dags,
processor_timeout=processor_timeout,
dag_ids=[],
@@ -834,6 +837,12 @@ class SchedulerJob(BaseJob):
)
timers.call_regular_interval(60.0, self._update_dag_run_state_for_paused_dags)
+ if self._standalone_dag_processor:
+ timers.call_regular_interval(
+ conf.getfloat('scheduler', 'deactivate_stale_dags_interval', fallback=60.0),
+ self._cleanup_stale_dags,
+ )
+
for loop_count in itertools.count(start=1):
with Stats.timer() as timer:
@@ -1260,6 +1269,7 @@ class SchedulerJob(BaseJob):
dag_id=dag.dag_id,
run_id=dag_run.run_id,
is_failure_callback=True,
+ processor_subdir=dag_model.processor_subdir,
msg='timed_out',
)
@@ -1322,7 +1332,12 @@ class SchedulerJob(BaseJob):
self.log.debug("Skipping SLA check for %s because DAG is not scheduled", dag)
return
- request = SlaCallbackRequest(full_filepath=dag.fileloc, dag_id=dag.dag_id)
+ dag_model = DagModel.get_dagmodel(dag.dag_id)
+ request = SlaCallbackRequest(
+ full_filepath=dag.fileloc,
+ dag_id=dag.dag_id,
+ processor_subdir=dag_model.processor_subdir,
+ )
self.executor.send_callback(request)
@provide_session
@@ -1485,11 +1500,11 @@ class SchedulerJob(BaseJob):
zombie_message_details = self._generate_zombie_message_details(ti)
request = TaskCallbackRequest(
full_filepath=file_loc,
+ processor_subdir=ti.dag_model.processor_subdir,
simple_task_instance=SimpleTaskInstance.from_ti(ti),
msg=str(zombie_message_details),
)
-
- self.log.error("Detected zombie job: %s", request.msg)
+ self.log.error("Detected zombie job: %s", request)
self.executor.send_callback(request)
Stats.incr('zombies_killed')
@@ -1509,3 +1524,28 @@ class SchedulerJob(BaseJob):
zombie_message_details["External Executor Id"] = ti.external_executor_id
return zombie_message_details
+
+ @provide_session
+ def _cleanup_stale_dags(self, session: Session = NEW_SESSION) -> None:
+ """
+ Find all DAGs that were not updated by the DAG processor recently and mark them as inactive.
+
+ In case one of the DAG processors is stopped (when there are multiple of them
+ for different DAG folders), its DAGs would otherwise never be marked as inactive.
+ Also remove the DAGs from the SerializedDag table.
+ Executed on schedule only if [scheduler]standalone_dag_processor is True.
+ """
+ self.log.debug("Checking DAGs not parsed within the last %s seconds.", self._dag_stale_not_seen_duration)
+ limit_lpt = timezone.utcnow() - timedelta(seconds=self._dag_stale_not_seen_duration)
+ stale_dags = (
+ session.query(DagModel).filter(DagModel.is_active, DagModel.last_parsed_time < limit_lpt).all()
+ )
+ if not stale_dags:
+ self.log.debug("No stale DAGs found.")
+ return
+
+ self.log.info("Found (%d) stale DAGs not parsed after %s.", len(stale_dags), limit_lpt)
+ for dag in stale_dags:
+ dag.is_active = False
+ SerializedDagModel.remove_dag(dag_id=dag.dag_id, session=session)
+ session.flush()
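The cleanup above reduces to a small, self-contained rule; sketched here (metadata DB assumed, threshold hard-coded to the 600-second default):

    from datetime import timedelta

    from airflow.models.dag import DagModel
    from airflow.models.serialized_dag import SerializedDagModel
    from airflow.utils import timezone
    from airflow.utils.session import create_session

    limit_lpt = timezone.utcnow() - timedelta(seconds=600)  # dag_stale_not_seen_duration
    with create_session() as session:
        stale_dags = (
            session.query(DagModel).filter(DagModel.is_active, DagModel.last_parsed_time < limit_lpt).all()
        )
        for dag in stale_dags:
            dag.is_active = False  # stop scheduling it
            SerializedDagModel.remove_dag(dag_id=dag.dag_id, session=session)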
diff --git a/airflow/migrations/versions/0117_2_4_0_add_processor_subdir_to_dagmodel_and_.py b/airflow/migrations/versions/0117_2_4_0_add_processor_subdir_to_dagmodel_and_.py
new file mode 100644
index 0000000000..ae16a3fffb
--- /dev/null
+++ b/airflow/migrations/versions/0117_2_4_0_add_processor_subdir_to_dagmodel_and_.py
@@ -0,0 +1,76 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Add processor_subdir column to DagModel, SerializedDagModel and CallbackRequest tables.
+
+Revision ID: ecb43d2a1842
+Revises: 1486deb605b4
+Create Date: 2022-08-26 11:30:11.249580
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = 'ecb43d2a1842'
+down_revision = '1486deb605b4'
+branch_labels = None
+depends_on = None
+airflow_version = '2.4.0'
+
+
+def upgrade():
+ """Apply add processor_subdir to DagModel and SerializedDagModel"""
+ conn = op.get_bind()
+
+ with op.batch_alter_table('dag') as batch_op:
+ if conn.dialect.name == "mysql":
+ batch_op.add_column(sa.Column('processor_subdir', sa.Text(length=2000), nullable=True))
+ else:
+ batch_op.add_column(sa.Column('processor_subdir', sa.String(length=2000), nullable=True))
+
+ with op.batch_alter_table('serialized_dag') as batch_op:
+ if conn.dialect.name == "mysql":
+ batch_op.add_column(sa.Column('processor_subdir', sa.Text(length=2000), nullable=True))
+ else:
+ batch_op.add_column(sa.Column('processor_subdir', sa.String(length=2000), nullable=True))
+
+ with op.batch_alter_table('callback_request') as batch_op:
+ batch_op.drop_column('dag_directory')
+ if conn.dialect.name == "mysql":
+ batch_op.add_column(sa.Column('processor_subdir', sa.Text(length=2000), nullable=True))
+ else:
+ batch_op.add_column(sa.Column('processor_subdir', sa.String(length=2000), nullable=True))
+
+
+def downgrade():
+ """Unapply Add processor_subdir to DagModel and SerializedDagModel"""
+ conn = op.get_bind()
+ with op.batch_alter_table('dag', schema=None) as batch_op:
+ batch_op.drop_column('processor_subdir')
+
+ with op.batch_alter_table('serialized_dag', schema=None) as batch_op:
+ batch_op.drop_column('processor_subdir')
+
+ with op.batch_alter_table('callback_request') as batch_op:
+ batch_op.drop_column('processor_subdir')
+ if conn.dialect.name == "mysql":
+ batch_op.add_column(sa.Column('dag_directory', sa.Text(length=1000), nullable=True))
+ else:
+ batch_op.add_column(sa.Column('dag_directory', sa.String(length=1000), nullable=True))
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index b3b10b6d66..c7336b8aaa 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2532,18 +2532,27 @@ class DAG(LoggingMixin):
@classmethod
@provide_session
- def bulk_sync_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION):
+ def bulk_sync_to_db(
+ cls,
+ dags: Collection["DAG"],
+ session=NEW_SESSION,
+ ):
"""This method is deprecated in favor of bulk_write_to_db"""
warnings.warn(
"This method is deprecated and will be removed in a future version. Please use bulk_write_to_db",
RemovedInAirflow3Warning,
stacklevel=2,
)
- return cls.bulk_write_to_db(dags, session)
+ return cls.bulk_write_to_db(dags=dags, session=session)
@classmethod
@provide_session
- def bulk_write_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION):
+ def bulk_write_to_db(
+ cls,
+ dags: Collection["DAG"],
+ processor_subdir: Optional[str] = None,
+ session=NEW_SESSION,
+ ):
"""
Ensure the DagModel rows for the given dags are up-to-date in the dag table in the DB, including
calculated fields.
@@ -2624,6 +2633,7 @@ class DAG(LoggingMixin):
orm_dag.has_task_concurrency_limits = any(t.max_active_tis_per_dag is not None for t in dag.tasks)
orm_dag.schedule_interval = dag.schedule_interval
orm_dag.timetable_description = dag.timetable.description
+ orm_dag.processor_subdir = processor_subdir
run: Optional[DagRun] = most_recent_runs.get(dag.dag_id)
if run is None:
@@ -2729,10 +2739,10 @@ class DAG(LoggingMixin):
session.flush()
for dag in dags:
- cls.bulk_write_to_db(dag.subdags, session=session)
+ cls.bulk_write_to_db(dag.subdags, processor_subdir=processor_subdir, session=session)
@provide_session
- def sync_to_db(self, session=NEW_SESSION):
+ def sync_to_db(self, processor_subdir: Optional[str] = None, session=NEW_SESSION):
"""
Save attributes about this DAG to the DB. Note that this method
can be called for both DAGs and SubDAGs. A SubDag is actually a
@@ -2740,7 +2750,7 @@ class DAG(LoggingMixin):
:return: None
"""
- self.bulk_write_to_db([self], session)
+ self.bulk_write_to_db([self], processor_subdir=processor_subdir, session=session)
def get_default_view(self):
"""This is only there for backward compatible jinja2 templates"""
@@ -2977,6 +2987,8 @@ class DagModel(Base):
# packaged DAG, it will point to the subpath of the DAG within the
# associated zip.
fileloc = Column(String(2000))
+ # The base directory used by Dag Processor that parsed this dag.
+ processor_subdir = Column(String(2000), nullable=True)
# String representing the owners
owners = Column(String(2000))
# Description of the dag
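Usage-wise (a sketch with a hypothetical path, assuming an initialized metadata database), callers now pass the parsing directory down so the DagModel row records which processor owns the DAG:

    from airflow.models.dag import DAG

    dag = DAG(dag_id="example_team_a", schedule_interval=None)
    dag.fileloc = "/opt/airflow/team_a/dags/example_team_a.py"  # hypothetical
    # Stamps DagModel.processor_subdir alongside the usual synced fields.
    dag.sync_to_db(processor_subdir="/opt/airflow/team_a/dags")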
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index 3183f8a8f1..3f8d6e57ca 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -575,7 +575,7 @@ class DagBag(LoggingMixin):
return report
@provide_session
- def sync_to_db(self, session: Session = None):
+ def sync_to_db(self, processor_subdir: Optional[str] = None, session: Session = None):
"""Save attributes about list of DAG to the DB."""
# To avoid circular import - airflow.models.dagbag -> airflow.models.dag -> airflow.models.dagbag
from airflow.models.dag import DAG
@@ -622,7 +622,9 @@ class DagBag(LoggingMixin):
for dag in self.dags.values():
serialize_errors.extend(_serialize_dag_capturing_errors(dag, session))
- DAG.bulk_write_to_db(self.dags.values(), session=session)
+ DAG.bulk_write_to_db(
+ self.dags.values(), processor_subdir=processor_subdir, session=session
+ )
except OperationalError:
session.rollback()
raise
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 41c9729282..65f92fd8c8 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -572,11 +572,15 @@ class DagRun(Base, LoggingMixin):
if execute_callbacks:
dag.handle_callback(self, success=False, reason='task_failure', session=session)
elif dag.has_on_failure_callback:
+ from airflow.models.dag import DagModel
+
+ dag_model = DagModel.get_dagmodel(dag.dag_id, session)
callback = DagCallbackRequest(
full_filepath=dag.fileloc,
dag_id=self.dag_id,
run_id=self.run_id,
is_failure_callback=True,
+ processor_subdir=dag_model.processor_subdir,
msg='task_failure',
)
@@ -587,11 +591,15 @@ class DagRun(Base, LoggingMixin):
if execute_callbacks:
dag.handle_callback(self, success=True, reason='success', session=session)
elif dag.has_on_success_callback:
+ from airflow.models.dag import DagModel
+
+ dag_model = DagModel.get_dagmodel(dag.dag_id, session)
callback = DagCallbackRequest(
full_filepath=dag.fileloc,
dag_id=self.dag_id,
run_id=self.run_id,
is_failure_callback=False,
+ processor_subdir=dag_model.processor_subdir,
msg='success',
)
@@ -602,11 +610,15 @@ class DagRun(Base, LoggingMixin):
if execute_callbacks:
dag.handle_callback(self, success=False, reason='all_tasks_deadlocked', session=session)
elif dag.has_on_failure_callback:
+ from airflow.models.dag import DagModel
+
+ dag_model = DagModel.get_dagmodel(dag.dag_id, session)
callback = DagCallbackRequest(
full_filepath=dag.fileloc,
dag_id=self.dag_id,
run_id=self.run_id,
is_failure_callback=True,
+ processor_subdir=dag_model.processor_subdir,
msg='all_tasks_deadlocked',
)
diff --git a/airflow/models/db_callback_request.py b/airflow/models/db_callback_request.py
index 4fdd36a71b..1f6ee4cd85 100644
--- a/airflow/models/db_callback_request.py
+++ b/airflow/models/db_callback_request.py
@@ -36,11 +36,12 @@ class DbCallbackRequest(Base):
priority_weight = Column(Integer(), nullable=False)
callback_data = Column(ExtendedJSON, nullable=False)
callback_type = Column(String(20), nullable=False)
- dag_directory = Column(String(1000), nullable=True)
+ processor_subdir = Column(String(2000), nullable=True)
def __init__(self, priority_weight: int, callback: CallbackRequest):
self.created_at = timezone.utcnow()
self.priority_weight = priority_weight
+ self.processor_subdir = callback.processor_subdir
self.callback_data = callback.to_json()
self.callback_type = callback.__class__.__name__
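Together with the callback classes above, persisting a request now carries the subdir automatically; a brief sketch (hypothetical path and IDs):

    from airflow.callbacks.callback_requests import SlaCallbackRequest
    from airflow.models.db_callback_request import DbCallbackRequest

    callback = SlaCallbackRequest(
        full_filepath="/opt/airflow/team_a/dags/etl.py",  # hypothetical
        dag_id="etl",
        processor_subdir="/opt/airflow/team_a/dags",
    )
    # The constructor copies processor_subdir from the callback, so the
    # standalone processor's fetch query can filter on it later.
    db_request = DbCallbackRequest(priority_weight=10, callback=callback)
    assert db_request.processor_subdir == "/opt/airflow/team_a/dags"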
diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py
index 9114af987c..baf3691f50 100644
--- a/airflow/models/serialized_dag.py
+++ b/airflow/models/serialized_dag.py
@@ -25,7 +25,7 @@ from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
import sqlalchemy_jsonfield
-from sqlalchemy import BigInteger, Column, Index, LargeBinary, String, and_
+from sqlalchemy import BigInteger, Column, Index, LargeBinary, String, and_, or_
from sqlalchemy.orm import Session, backref, foreign, relationship
from sqlalchemy.sql.expression import func, literal
@@ -72,6 +72,7 @@ class SerializedDagModel(Base):
_data_compressed = Column('data_compressed', LargeBinary, nullable=True)
last_updated = Column(UtcDateTime, nullable=False)
dag_hash = Column(String(32), nullable=False)
+ processor_subdir = Column(String(2000), nullable=True)
__table_args__ = (Index('idx_fileloc_hash', fileloc_hash, unique=False),)
@@ -92,11 +93,12 @@ class SerializedDagModel(Base):
load_op_links = True
- def __init__(self, dag: DAG):
+ def __init__(self, dag: DAG, processor_subdir: Optional[str] = None):
self.dag_id = dag.dag_id
self.fileloc = dag.fileloc
self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
self.last_updated = timezone.utcnow()
+ self.processor_subdir = processor_subdir
dag_data = SerializedDAG.to_dict(dag)
dag_data_json = json.dumps(dag_data, sort_keys=True).encode("utf-8")
@@ -119,7 +121,13 @@ class SerializedDagModel(Base):
@classmethod
@provide_session
- def write_dag(cls, dag: DAG, min_update_interval: Optional[int] = None, session: Session = None) -> bool:
+ def write_dag(
+ cls,
+ dag: DAG,
+ min_update_interval: Optional[int] = None,
+ processor_subdir: Optional[str] = None,
+ session: Session = None,
+ ) -> bool:
"""Serializes a DAG and writes it into database.
If the record already exists, it checks if the Serialized DAG changed or not. If it is
changed, it updates the record, ignores otherwise.
@@ -151,10 +159,16 @@ class SerializedDagModel(Base):
return False
log.debug("Checking if DAG (%s) changed", dag.dag_id)
- new_serialized_dag = cls(dag)
- serialized_dag_hash_from_db = session.query(cls.dag_hash).filter(cls.dag_id == dag.dag_id).scalar()
+ new_serialized_dag = cls(dag, processor_subdir)
+ serialized_dag_db = (
+ session.query(cls.dag_hash, cls.processor_subdir).filter(cls.dag_id == dag.dag_id).first()
+ )
- if serialized_dag_hash_from_db == new_serialized_dag.dag_hash:
+ if (
+ serialized_dag_db is not None
+ and serialized_dag_db.dag_hash == new_serialized_dag.dag_hash
+ and serialized_dag_db.processor_subdir == new_serialized_dag.processor_subdir
+ ):
log.debug("Serialized DAG (%s) is unchanged. Skipping writing to DB", dag.dag_id)
return False
@@ -222,7 +236,9 @@ class SerializedDagModel(Base):
@classmethod
@provide_session
- def remove_deleted_dags(cls, alive_dag_filelocs: List[str], session=None):
+ def remove_deleted_dags(
+ cls, alive_dag_filelocs: List[str], processor_subdir: Optional[str] = None, session=None
+ ):
"""Deletes DAGs not included in alive_dag_filelocs.
:param alive_dag_filelocs: file paths of alive DAGs
@@ -236,7 +252,14 @@ class SerializedDagModel(Base):
session.execute(
cls.__table__.delete().where(
- and_(cls.fileloc_hash.notin_(alive_fileloc_hashes), cls.fileloc.notin_(alive_dag_filelocs))
+ and_(
+ cls.fileloc_hash.notin_(alive_fileloc_hashes),
+ cls.fileloc.notin_(alive_dag_filelocs),
+ or_(
+ cls.processor_subdir.is_(None),
+ cls.processor_subdir == processor_subdir,
+ ),
+ )
)
)
@@ -281,7 +304,7 @@ class SerializedDagModel(Base):
@staticmethod
@provide_session
- def bulk_sync_to_db(dags: List[DAG], session: Session = None):
+ def bulk_sync_to_db(dags: List[DAG], processor_subdir: Optional[str] = None, session: Session = None):
"""
Saves DAGs as Serialized DAG objects in the database. Each
DAG is saved in a separate database query.
@@ -293,7 +316,10 @@ class SerializedDagModel(Base):
for dag in dags:
if not dag.is_subdag:
SerializedDagModel.write_dag(
- dag, min_update_interval=MIN_SERIALIZED_DAG_UPDATE_INTERVAL, session=session
+ dag=dag,
+ min_update_interval=MIN_SERIALIZED_DAG_UPDATE_INTERVAL,
+ processor_subdir=processor_subdir,
+ session=session,
)
@classmethod
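And on the serialized side, the same directory is passed when writing and used as a guard when pruning; a minimal sketch (hypothetical directory, metadata database assumed):

    from airflow.models.dag import DAG
    from airflow.models.serialized_dag import SerializedDagModel

    subdir = "/opt/airflow/team_a/dags"  # hypothetical
    dag = DAG(dag_id="example_team_a", schedule_interval=None)
    dag.fileloc = f"{subdir}/example_team_a.py"
    # Re-serializes only if the DAG hash *or* the owning directory changed.
    SerializedDagModel.write_dag(dag=dag, processor_subdir=subdir)
    # Prunes rows whose files are gone, but never rows owned by another processor.
    SerializedDagModel.remove_deleted_dags(alive_dag_filelocs=[dag.fileloc], processor_subdir=subdir)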
diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst
index 28b5928738..cd4a608fe3 100644
--- a/docs/apache-airflow/migrations-ref.rst
+++ b/docs/apache-airflow/migrations-ref.rst
@@ -27,7 +27,10 @@ Here's the list of all the Database Migrations that are executed via when you ru
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
| Revision ID | Revises ID | Airflow Version | Description |
+=================================+===================+===================+==============================================================+
-| ``1486deb605b4`` (head) | ``f4ff391becb5`` | ``2.4.0`` | add dag_owner_attributes table |
+| ``ecb43d2a1842`` (head) | ``1486deb605b4`` | ``2.4.0`` | Add processor_subdir column to DagModel, SerializedDagModel |
+| | | | and CallbackRequest tables. |
++---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
+| ``1486deb605b4`` | ``f4ff391becb5`` | ``2.4.0`` | add dag_owner_attributes table |
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
| ``f4ff391becb5`` | ``0038cd0c28b4`` | ``2.4.0`` | Remove smart sensors |
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
diff --git a/tests/callbacks/test_callback_requests.py b/tests/callbacks/test_callback_requests.py
index 3764f19c4c..8819571f57 100644
--- a/tests/callbacks/test_callback_requests.py
+++ b/tests/callbacks/test_callback_requests.py
@@ -46,6 +46,7 @@ class TestCallbackRequest:
TaskCallbackRequest(
full_filepath="filepath",
simple_task_instance=SimpleTaskInstance.from_ti(ti=TI),
+ processor_subdir='/test_dir',
is_failure_callback=True,
),
TaskCallbackRequest,
@@ -55,11 +56,19 @@ class TestCallbackRequest:
full_filepath="filepath",
dag_id="fake_dag",
run_id="fake_run",
+ processor_subdir='/test_dir',
is_failure_callback=False,
),
DagCallbackRequest,
),
- (SlaCallbackRequest(full_filepath="filepath", dag_id="fake_dag"), SlaCallbackRequest),
+ (
+ SlaCallbackRequest(
+ full_filepath="filepath",
+ dag_id="fake_dag",
+ processor_subdir='/test_dir',
+ ),
+ SlaCallbackRequest,
+ ),
]
)
def test_from_json(self, input, request_class):
@@ -76,6 +85,7 @@ class TestCallbackRequest:
input = TaskCallbackRequest(
full_filepath="filepath",
simple_task_instance=SimpleTaskInstance.from_ti(ti),
+ processor_subdir='/test_dir',
is_failure_callback=True,
)
json_str = input.to_json()
diff --git a/tests/conftest.py b/tests/conftest.py
index 2d0b90d821..b974bf25a0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -518,11 +518,13 @@ def dag_maker(request):
return
dag.clear(session=self.session)
- dag.sync_to_db(self.session)
+ dag.sync_to_db(processor_subdir=self.processor_subdir, session=self.session)
self.dag_model = self.session.query(DagModel).get(dag.dag_id)
if self.want_serialized:
- self.serialized_model = SerializedDagModel(dag)
+ self.serialized_model = SerializedDagModel(
+ dag, processor_subdir=self.dag_model.processor_subdir
+ )
self.session.merge(self.serialized_model)
serialized_dag = self._serialized_dag()
self.dagbag.bag_dag(serialized_dag, root_dag=serialized_dag)
@@ -578,7 +580,13 @@ def dag_maker(request):
)
def __call__(
- self, dag_id='test_dag', serialized=want_serialized, fileloc=None, session=None, **kwargs
+ self,
+ dag_id='test_dag',
+ serialized=want_serialized,
+ fileloc=None,
+ processor_subdir=None,
+ session=None,
+ **kwargs,
):
from airflow import settings
from airflow.models import DAG
@@ -606,6 +614,7 @@ def dag_maker(request):
self.dag = DAG(dag_id, **self.kwargs)
self.dag.fileloc = fileloc or request.module.__file__
self.want_serialized = serialized
+ self.processor_subdir = processor_subdir
return self
diff --git a/tests/dag_processing/test_manager.py b/tests/dag_processing/test_manager.py
index f88c27dafd..70208312e9 100644
--- a/tests/dag_processing/test_manager.py
+++ b/tests/dag_processing/test_manager.py
@@ -66,8 +66,8 @@ DEFAULT_DATE = timezone.datetime(2016, 1, 1)
class FakeDagFileProcessorRunner(DagFileProcessorProcess):
# This fake processor will return the zombies it received in constructor
# as its processing result w/o actually parsing anything.
- def __init__(self, file_path, pickle_dags, dag_ids, callbacks):
- super().__init__(file_path, pickle_dags, dag_ids, callbacks)
+ def __init__(self, file_path, pickle_dags, dag_ids, dag_directory, callbacks):
+ super().__init__(file_path, pickle_dags, dag_ids, dag_directory, callbacks)
# We need a "real" selectable handle for waitable_handle to work
readable, writable = multiprocessing.Pipe(duplex=False)
writable.send('abc')
@@ -95,11 +95,12 @@ class FakeDagFileProcessorRunner(DagFileProcessorProcess):
return self._result
@staticmethod
- def _create_process(file_path, callback_requests, dag_ids, pickle_dags):
+ def _create_process(file_path, callback_requests, dag_ids, dag_directory, pickle_dags):
return FakeDagFileProcessorRunner(
file_path,
pickle_dags,
dag_ids,
+ dag_directory,
callback_requests,
)
@@ -504,7 +505,6 @@ class TestDagFileProcessorManager:
)
assert serialized_dag_count == 1
- manager._file_stats[test_dag_path] = stat
manager._deactivate_stale_dags()
active_dag_count = (
@@ -521,6 +521,62 @@ class TestDagFileProcessorManager:
)
assert serialized_dag_count == 0
+ @conf_vars(
+ {
+ ('core', 'load_examples'): 'False',
+ ('scheduler', 'standalone_dag_processor'): 'True',
+ }
+ )
+ def test_deactivate_stale_dags_standalone_mode(self):
+ """
+ Ensure only DAGs from the current dag_directory are updated.
+ """
+ dag_directory = 'directory'
+ manager = DagFileProcessorManager(
+ dag_directory=dag_directory,
+ max_runs=1,
+ processor_timeout=timedelta(minutes=10),
+ signal_conn=MagicMock(),
+ dag_ids=[],
+ pickle_dags=False,
+ async_mode=True,
+ )
+
+ test_dag_path = str(TEST_DAG_FOLDER / 'test_example_bash_operator.py')
+ dagbag = DagBag(test_dag_path, read_dags_from_db=False)
+ other_test_dag_path = str(TEST_DAG_FOLDER / 'test_scheduler_dags.py')
+ other_dagbag = DagBag(other_test_dag_path, read_dags_from_db=False)
+
+ with create_session() as session:
+ # Add stale DAG to the DB
+ dag = dagbag.get_dag('test_example_bash_operator')
+ dag.last_parsed_time = timezone.utcnow()
+ dag.sync_to_db(processor_subdir=dag_directory)
+
+ # Add a stale DAG from another directory to the DB
+ other_dag = other_dagbag.get_dag('test_start_date_scheduling')
+ other_dag.last_parsed_time = timezone.utcnow()
+ other_dag.sync_to_db(processor_subdir='other')
+
+ # Add DAG to the file_parsing_stats
+ stat = DagFileStat(
+ num_dags=1,
+ import_errors=0,
+ last_finish_time=timezone.utcnow() + timedelta(hours=1),
+ last_duration=1,
+ run_count=1,
+ )
+ manager._file_paths = [test_dag_path]
+ manager._file_stats[test_dag_path] = stat
+
+ active_dag_count = session.query(func.count(DagModel.dag_id)).filter(DagModel.is_active).scalar()
+ assert active_dag_count == 2
+
+ manager._deactivate_stale_dags()
+
+ active_dag_count = session.query(func.count(DagModel.dag_id)).filter(DagModel.is_active).scalar()
+ assert active_dag_count == 1
+
@mock.patch(
"airflow.dag_processing.processor.DagFileProcessorProcess.waitable_handle", new_callable=PropertyMock
)
@@ -539,7 +595,13 @@ class TestDagFileProcessorManager:
async_mode=True,
)
- processor = DagFileProcessorProcess('abc.txt', False, [], [])
+ processor = DagFileProcessorProcess(
+ file_path='abc.txt',
+ pickle_dags=False,
+ dag_ids=[],
+ dag_directory=TEST_DAG_FOLDER,
+ callback_requests=[],
+ )
processor._start_time = timezone.make_aware(datetime.min)
manager._processors = {'abc.txt': processor}
manager.waitables[3] = processor
@@ -554,7 +616,7 @@ class TestDagFileProcessorManager:
def test_kill_timed_out_processors_no_kill(self, mock_dag_file_processor, mock_pid):
mock_pid.return_value = 1234
manager = DagFileProcessorManager(
- dag_directory='directory',
+ dag_directory=TEST_DAG_FOLDER,
max_runs=1,
processor_timeout=timedelta(seconds=5),
signal_conn=MagicMock(),
@@ -563,7 +625,13 @@ class TestDagFileProcessorManager:
async_mode=True,
)
- processor = DagFileProcessorProcess('abc.txt', False, [], [])
+ processor = DagFileProcessorProcess(
+ file_path='abc.txt',
+ pickle_dags=False,
+ dag_ids=[],
+ dag_directory=str(TEST_DAG_FOLDER),
+ callback_requests=[],
+ )
processor._start_time = timezone.make_aware(datetime.max)
manager._processors = {'abc.txt': processor}
manager._kill_timed_out_processors()
@@ -757,17 +825,20 @@ class TestDagFileProcessorManager:
dag_id="test_start_date_scheduling",
full_filepath=str(dag_filepath),
is_failure_callback=True,
+ processor_subdir=str(tmpdir),
run_id='123',
)
callback2 = DagCallbackRequest(
dag_id="test_start_date_scheduling",
full_filepath=str(dag_filepath),
is_failure_callback=True,
+ processor_subdir=str(tmpdir),
run_id='456',
)
callback3 = SlaCallbackRequest(
dag_id="test_start_date_scheduling",
full_filepath=str(dag_filepath),
+ processor_subdir=str(tmpdir),
)
with create_session() as session:
@@ -777,7 +848,7 @@ class TestDagFileProcessorManager:
child_pipe, parent_pipe = multiprocessing.Pipe()
manager = DagFileProcessorManager(
- dag_directory=tmpdir,
+ dag_directory=str(tmpdir),
max_runs=1,
processor_timeout=timedelta(days=365),
signal_conn=child_pipe,
@@ -790,6 +861,50 @@ class TestDagFileProcessorManager:
self.run_processor_manager_one_loop(manager, parent_pipe)
assert session.query(DbCallbackRequest).count() == 0
+ @conf_vars(
+ {
+ ('core', 'load_examples'): 'False',
+ ('scheduler', 'standalone_dag_processor'): 'True',
+ }
+ )
+ def test_fetch_callbacks_for_current_dag_directory_only(self, tmpdir):
+ """Test DagFileProcessorManager._fetch_callbacks method"""
+ dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py"
+
+ callback1 = DagCallbackRequest(
+ dag_id="test_start_date_scheduling",
+ full_filepath=str(dag_filepath),
+ is_failure_callback=True,
+ processor_subdir=str(tmpdir),
+ run_id='123',
+ )
+ callback2 = DagCallbackRequest(
+ dag_id="test_start_date_scheduling",
+ full_filepath=str(dag_filepath),
+ is_failure_callback=True,
+ processor_subdir="/some/other/dir/",
+ run_id='456',
+ )
+
+ with create_session() as session:
+ session.add(DbCallbackRequest(callback=callback1, priority_weight=11))
+ session.add(DbCallbackRequest(callback=callback2, priority_weight=10))
+
+ child_pipe, parent_pipe = multiprocessing.Pipe()
+ manager = DagFileProcessorManager(
+ dag_directory=tmpdir,
+ max_runs=1,
+ processor_timeout=timedelta(days=365),
+ signal_conn=child_pipe,
+ dag_ids=[],
+ pickle_dags=False,
+ async_mode=False,
+ )
+
+ with create_session() as session:
+ self.run_processor_manager_one_loop(manager, parent_pipe)
+ assert session.query(DbCallbackRequest).count() == 1
+
@conf_vars(
{
('scheduler', 'standalone_dag_processor'): 'True',
@@ -808,12 +923,13 @@ class TestDagFileProcessorManager:
full_filepath=str(dag_filepath),
is_failure_callback=True,
run_id=str(i),
+ processor_subdir=str(tmpdir),
)
session.add(DbCallbackRequest(callback=callback, priority_weight=i))
child_pipe, parent_pipe = multiprocessing.Pipe()
manager = DagFileProcessorManager(
- dag_directory=tmpdir,
+ dag_directory=str(tmpdir),
max_runs=1,
processor_timeout=timedelta(days=365),
signal_conn=child_pipe,
@@ -844,6 +960,7 @@ class TestDagFileProcessorManager:
dag_id="test_start_date_scheduling",
full_filepath=str(dag_filepath),
is_failure_callback=True,
+ processor_subdir=str(tmpdir),
run_id='123',
)
session.add(DbCallbackRequest(callback=callback, priority_weight=10))
@@ -884,6 +1001,7 @@ class TestDagFileProcessorManager:
dag_id="dag1",
run_id="run1",
is_failure_callback=False,
+ processor_subdir=tmpdir,
msg=None,
)
dag1_req2 = DagCallbackRequest(
@@ -891,16 +1009,26 @@ class TestDagFileProcessorManager:
dag_id="dag1",
run_id="run1",
is_failure_callback=False,
+ processor_subdir=tmpdir,
msg=None,
)
- dag1_sla1 = SlaCallbackRequest(full_filepath="/green_eggs/ham/file1.py", dag_id="dag1")
- dag1_sla2 = SlaCallbackRequest(full_filepath="/green_eggs/ham/file1.py", dag_id="dag1")
+ dag1_sla1 = SlaCallbackRequest(
+ full_filepath="/green_eggs/ham/file1.py",
+ dag_id="dag1",
+ processor_subdir=tmpdir,
+ )
+ dag1_sla2 = SlaCallbackRequest(
+ full_filepath="/green_eggs/ham/file1.py",
+ dag_id="dag1",
+ processor_subdir=tmpdir,
+ )
dag2_req1 = DagCallbackRequest(
full_filepath="/green_eggs/ham/file2.py",
dag_id="dag2",
run_id="run1",
is_failure_callback=False,
+ processor_subdir=tmpdir,
msg=None,
)
@@ -946,10 +1074,6 @@ class TestDagFileProcessorAgent(unittest.TestCase):
for mod in remove_list:
del sys.modules[mod]
- @staticmethod
- def _processor_factory(file_path, zombies, dag_ids, pickle_dags):
- return DagFileProcessorProcess(file_path, pickle_dags, dag_ids, zombies)
-
def test_reload_module(self):
"""
Configure the context to have logging.logging_config_class set to a fake logging
diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py
index d007851092..b0e09d0417 100644
--- a/tests/dag_processing/test_processor.py
+++ b/tests/dag_processing/test_processor.py
@@ -97,8 +97,10 @@ class TestDagFileProcessor:
self.scheduler_job = None
self.clean_db()
- def _process_file(self, file_path, session):
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ def _process_file(self, file_path, dag_directory, session):
+ dag_file_processor = DagFileProcessor(
+ dag_ids=[], dag_directory=str(dag_directory), log=mock.MagicMock()
+ )
dag_file_processor.process_file(file_path, [], False, session)
@@ -124,7 +126,9 @@ class TestDagFileProcessor:
session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ dag_file_processor = DagFileProcessor(
+ dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock()
+ )
dag_file_processor.manage_slas(dag=dag, session=session)
assert sla_callback.called
@@ -153,7 +157,9 @@ class TestDagFileProcessor:
session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ dag_file_processor = DagFileProcessor(
+ dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock()
+ )
dag_file_processor.manage_slas(dag=dag, session=session)
sla_callback.assert_not_called()
@@ -192,7 +198,9 @@ class TestDagFileProcessor:
)
# Now call manage_slas and see if the sla_miss callback gets called
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ dag_file_processor = DagFileProcessor(
+ dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock()
+ )
dag_file_processor.manage_slas(dag=dag, session=session)
sla_callback.assert_not_called()
@@ -220,7 +228,9 @@ class TestDagFileProcessor:
session.merge(ti)
session.flush()
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ dag_file_processor = DagFileProcessor(
+ dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock()
+ )
dag_file_processor.manage_slas(dag=dag, session=session)
sla_miss_count = (
session.query(SlaMiss)
@@ -264,7 +274,7 @@ class TestDagFileProcessor:
# Now call manage_slas and see if the sla_miss callback gets called
mock_log = mock.MagicMock()
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log)
+ dag_file_processor = DagFileProcessor(dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock_log)
dag_file_processor.manage_slas(dag=dag, session=session)
assert sla_callback.called
mock_log.exception.assert_called_once_with(
@@ -294,7 +304,9 @@ class TestDagFileProcessor:
session.merge(SlaMiss(task_id='sla_missed', dag_id='test_sla_miss', execution_date=test_start_date))
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ dag_file_processor = DagFileProcessor(
+ dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock()
+ )
dag_file_processor.manage_slas(dag=dag, session=session)
@@ -333,7 +345,7 @@ class TestDagFileProcessor:
session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
mock_log = mock.MagicMock()
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log)
+ dag_file_processor = DagFileProcessor(dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock_log)
dag_file_processor.manage_slas(dag=dag, session=session)
mock_log.exception.assert_called_once_with(
@@ -364,13 +376,15 @@ class TestDagFileProcessor:
)
mock_log = mock.MagicMock()
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log)
+ dag_file_processor = DagFileProcessor(dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock_log)
dag_file_processor.manage_slas(dag=dag, session=session)
@patch.object(TaskInstance, 'handle_failure')
def test_execute_on_failure_callbacks(self, mock_ti_handle_failure):
dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False)
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ dag_file_processor = DagFileProcessor(
+ dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock()
+ )
with create_session() as session:
session.query(TaskInstance).delete()
dag = dagbag.get_dag('example_branch_operator')
@@ -401,7 +415,9 @@ class TestDagFileProcessor:
@patch.object(TaskInstance, 'handle_failure')
def test_execute_on_failure_callbacks_without_dag(self, mock_ti_handle_failure, has_serialized_dag):
dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False)
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ dag_file_processor = DagFileProcessor(
+ dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock()
+ )
with create_session() as session:
session.query(TaskInstance).delete()
dag = dagbag.get_dag('example_branch_operator')
@@ -431,7 +447,9 @@ class TestDagFileProcessor:
def test_failure_callbacks_should_not_drop_hostname(self):
dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False)
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ dag_file_processor = DagFileProcessor(
+ dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock()
+ )
dag_file_processor.UNIT_TEST_MODE = False
with create_session() as session:
@@ -462,7 +480,9 @@ class TestDagFileProcessor:
callback_file = tmp_path.joinpath("callback.txt")
callback_file.touch()
monkeypatch.setenv("AIRFLOW_CALLBACK_FILE", str(callback_file))
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ dag_file_processor = DagFileProcessor(
+ dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock()
+ )
dag = get_test_dag('test_on_failure_callback')
task = dag.get_task(task_id='test_on_failure_callback_task')
@@ -496,7 +516,7 @@ class TestDagFileProcessor:
unparseable_file.writelines(UNPARSEABLE_DAG_FILE_CONTENTS)
with create_session() as session:
- self._process_file(unparseable_filename, session)
+ self._process_file(unparseable_filename, dag_directory=tmpdir, session=session)
import_errors = session.query(errors.ImportError).all()
assert len(import_errors) == 1
@@ -513,7 +533,7 @@ class TestDagFileProcessor:
zip_file.writestr(TEMP_DAG_FILENAME, UNPARSEABLE_DAG_FILE_CONTENTS)
with create_session() as session:
- self._process_file(zip_filename, session)
+ self._process_file(zip_filename, dag_directory=tmpdir, session=session)
import_errors = session.query(errors.ImportError).all()
assert len(import_errors) == 1
@@ -530,7 +550,7 @@ class TestDagFileProcessor:
for line in main_dag:
next_dag.write(line)
# first we parse the dag
- self._process_file(temp_dagfile, session)
+ self._process_file(temp_dagfile, dag_directory=tmpdir, session=session)
# assert DagModel.has_import_errors is false
dm = session.query(DagModel).filter(DagModel.fileloc == temp_dagfile).first()
assert not dm.has_import_errors
@@ -538,7 +558,7 @@ class TestDagFileProcessor:
with open(temp_dagfile, 'a') as file:
file.writelines(UNPARSEABLE_DAG_FILE_CONTENTS)
- self._process_file(temp_dagfile, session)
+ self._process_file(temp_dagfile, dag_directory=tmpdir, session=session)
import_errors = session.query(errors.ImportError).all()
assert len(import_errors) == 1
@@ -555,7 +575,7 @@ class TestDagFileProcessor:
parseable_file.writelines(PARSEABLE_DAG_FILE_CONTENTS)
with create_session() as session:
- self._process_file(parseable_filename, session)
+ self._process_file(parseable_filename, dag_directory=tmpdir, session=session)
import_errors = session.query(errors.ImportError).all()
assert len(import_errors) == 0
@@ -568,7 +588,7 @@ class TestDagFileProcessor:
zip_file.writestr(TEMP_DAG_FILENAME, PARSEABLE_DAG_FILE_CONTENTS)
with create_session() as session:
- self._process_file(zip_filename, session)
+ self._process_file(zip_filename, dag_directory=tmpdir, session=session)
import_errors = session.query(errors.ImportError).all()
assert len(import_errors) == 0
@@ -583,14 +603,14 @@ class TestDagFileProcessor:
with open(unparseable_filename, 'w') as unparseable_file:
unparseable_file.writelines(UNPARSEABLE_DAG_FILE_CONTENTS)
session = settings.Session()
- self._process_file(unparseable_filename, session)
+ self._process_file(unparseable_filename, dag_directory=tmpdir, session=session)
# Generate replacement import error (the error will be on the second line now)
with open(unparseable_filename, 'w') as unparseable_file:
unparseable_file.writelines(
PARSEABLE_DAG_FILE_CONTENTS + os.linesep + UNPARSEABLE_DAG_FILE_CONTENTS
)
- self._process_file(unparseable_filename, session)
+ self._process_file(unparseable_filename, dag_directory=tmpdir, session=session)
import_errors = session.query(errors.ImportError).all()
@@ -612,7 +632,7 @@ class TestDagFileProcessor:
with open(filename_to_parse, 'w') as file_to_parse:
file_to_parse.writelines(UNPARSEABLE_DAG_FILE_CONTENTS)
session = settings.Session()
- self._process_file(filename_to_parse, session)
+ self._process_file(filename_to_parse, dag_directory=tmpdir, session=session)
import_error_1 = (
session.query(errors.ImportError).filter(errors.ImportError.filename == filename_to_parse).one()
@@ -620,7 +640,7 @@ class TestDagFileProcessor:
# process the file multiple times
for _ in range(10):
- self._process_file(filename_to_parse, session)
+ self._process_file(filename_to_parse, dag_directory=tmpdir, session=session)
import_error_2 = (
session.query(errors.ImportError).filter(errors.ImportError.filename == filename_to_parse).one()
@@ -636,12 +656,12 @@ class TestDagFileProcessor:
with open(filename_to_parse, 'w') as file_to_parse:
file_to_parse.writelines(UNPARSEABLE_DAG_FILE_CONTENTS)
session = settings.Session()
- self._process_file(filename_to_parse, session)
+ self._process_file(filename_to_parse, dag_directory=tmpdir, session=session)
# Remove the import error from the file
with open(filename_to_parse, 'w') as file_to_parse:
file_to_parse.writelines(PARSEABLE_DAG_FILE_CONTENTS)
- self._process_file(filename_to_parse, session)
+ self._process_file(filename_to_parse, dag_directory=tmpdir, session=session)
import_errors = session.query(errors.ImportError).all()
@@ -656,7 +676,7 @@ class TestDagFileProcessor:
zip_filename = os.path.join(tmpdir, "test_zip.zip")
with ZipFile(zip_filename, "w") as zip_file:
zip_file.writestr(TEMP_DAG_FILENAME, UNPARSEABLE_DAG_FILE_CONTENTS)
- self._process_file(zip_filename, session)
+ self._process_file(zip_filename, dag_directory=tmpdir, session=session)
import_errors = session.query(errors.ImportError).all()
assert len(import_errors) == 1
@@ -664,7 +684,7 @@ class TestDagFileProcessor:
# Remove the import error from the file
with ZipFile(zip_filename, "w") as zip_file:
zip_file.writestr(TEMP_DAG_FILENAME, 'import os # airflow DAG')
- self._process_file(zip_filename, session)
+ self._process_file(zip_filename, dag_directory=tmpdir, session=session)
import_errors = session.query(errors.ImportError).all()
assert len(import_errors) == 0
@@ -677,7 +697,7 @@ class TestDagFileProcessor:
unparseable_file.writelines(INVALID_DAG_WITH_DEPTH_FILE_CONTENTS)
with create_session() as session:
- self._process_file(unparseable_filename, session)
+ self._process_file(unparseable_filename, dag_directory=tmpdir, session=session)
import_errors = session.query(errors.ImportError).all()
assert len(import_errors) == 1
@@ -703,7 +723,7 @@ class TestDagFileProcessor:
unparseable_file.writelines(INVALID_DAG_WITH_DEPTH_FILE_CONTENTS)
with create_session() as session:
- self._process_file(unparseable_filename, session)
+ self._process_file(unparseable_filename, dag_directory=tmpdir, session=session)
import_errors = session.query(errors.ImportError).all()
assert len(import_errors) == 1
@@ -726,7 +746,7 @@ class TestDagFileProcessor:
invalid_zip_file.writestr(TEMP_DAG_FILENAME, INVALID_DAG_WITH_DEPTH_FILE_CONTENTS)
with create_session() as session:
- self._process_file(invalid_zip_filename, session)
+ self._process_file(invalid_zip_filename, dag_directory=tmpdir, session=session)
import_errors = session.query(errors.ImportError).all()
assert len(import_errors) == 1
@@ -753,7 +773,7 @@ class TestDagFileProcessor:
invalid_zip_file.writestr(TEMP_DAG_FILENAME, INVALID_DAG_WITH_DEPTH_FILE_CONTENTS)
with create_session() as session:
- self._process_file(invalid_zip_filename, session)
+ self._process_file(invalid_zip_filename, dag_directory=tmpdir, session=session)
import_errors = session.query(errors.ImportError).all()
assert len(import_errors) == 1
@@ -779,7 +799,7 @@ class TestProcessorAgent:
def test_error_when_waiting_in_async_mode(self, tmp_path):
self.processor_agent = DagFileProcessorAgent(
- dag_directory=str(tmp_path),
+ dag_directory=tmp_path,
max_runs=1,
processor_timeout=datetime.timedelta(1),
dag_ids=[],
@@ -792,7 +812,7 @@ class TestProcessorAgent:
def test_default_multiprocessing_behaviour(self, tmp_path):
self.processor_agent = DagFileProcessorAgent(
- dag_directory=str(tmp_path),
+ dag_directory=tmp_path,
max_runs=1,
processor_timeout=datetime.timedelta(1),
dag_ids=[],
@@ -806,7 +826,7 @@ class TestProcessorAgent:
@conf_vars({("core", "mp_start_method"): "spawn"})
def test_spawn_multiprocessing_behaviour(self, tmp_path):
self.processor_agent = DagFileProcessorAgent(
- dag_directory=str(tmp_path),
+ dag_directory=tmp_path,
max_runs=1,
processor_timeout=datetime.timedelta(1),
dag_ids=[],
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index b299b852c2..c46f846b3e 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -136,7 +136,6 @@ class TestSchedulerJob:
# Speed up some tests by not running the tasks, just look at what we
# enqueue!
self.null_exec: Optional[MockExecutor] = MockExecutor()
-
# Since we don't want to store the code for the DAG defined in this file
with patch('airflow.dag_processing.manager.SerializedDagModel.remove_deleted_dags'), patch(
'airflow.models.dag.DagCode.bulk_sync_to_db'
@@ -343,6 +342,7 @@ class TestSchedulerJob:
mock_task_callback.assert_called_once_with(
full_filepath=dag.fileloc,
simple_task_instance=mock.ANY,
+ processor_subdir=None,
msg='Executor reports task instance '
'<TaskInstance: test_process_executor_events_with_callback.dummy_task test [queued]> '
'finished (failed) although the task says its queued. (Info: None) '
@@ -1610,6 +1610,7 @@ class TestSchedulerJob:
dag_id='test_scheduler_verify_max_active_runs_and_dagrun_timeout',
start_date=DEFAULT_DATE,
max_active_runs=1,
+ processor_subdir=TEST_DAG_FOLDER,
dagrun_timeout=datetime.timedelta(seconds=60),
) as dag:
EmptyOperator(task_id='dummy')
@@ -1657,6 +1658,7 @@ class TestSchedulerJob:
dag_id=dr.dag_id,
is_failure_callback=True,
run_id=dr.run_id,
+ processor_subdir=TEST_DAG_FOLDER,
msg="timed_out",
)
@@ -1674,6 +1676,7 @@ class TestSchedulerJob:
with dag_maker(
dag_id='test_scheduler_fail_dagrun_timeout',
dagrun_timeout=datetime.timedelta(seconds=60),
+ processor_subdir=TEST_DAG_FOLDER,
session=session,
):
EmptyOperator(task_id='dummy')
@@ -1697,6 +1700,7 @@ class TestSchedulerJob:
dag_id=dr.dag_id,
is_failure_callback=True,
run_id=dr.run_id,
+ processor_subdir=TEST_DAG_FOLDER,
msg="timed_out",
)
@@ -1751,6 +1755,7 @@ class TestSchedulerJob:
dag_id='test_dagrun_callbacks_are_called',
on_success_callback=lambda x: print("success"),
on_failure_callback=lambda x: print("failed"),
+ processor_subdir=TEST_DAG_FOLDER,
) as dag:
EmptyOperator(task_id='dummy')
@@ -1773,6 +1778,7 @@ class TestSchedulerJob:
dag_id=dr.dag_id,
is_failure_callback=bool(state == State.FAILED),
run_id=dr.run_id,
+ processor_subdir=TEST_DAG_FOLDER,
msg=expected_callback_msg,
)
@@ -1786,6 +1792,7 @@ class TestSchedulerJob:
dag_id='test_dagrun_timeout_callbacks_are_stored_in_database',
on_failure_callback=lambda x: print("failed"),
dagrun_timeout=timedelta(hours=1),
+ processor_subdir=TEST_DAG_FOLDER,
) as dag:
EmptyOperator(task_id='empty')
@@ -1812,6 +1819,7 @@ class TestSchedulerJob:
dag_id=dr.dag_id,
is_failure_callback=True,
run_id=dr.run_id,
+ processor_subdir=TEST_DAG_FOLDER,
msg='timed_out',
)
@@ -2996,7 +3004,11 @@ class TestSchedulerJob:
def test_send_sla_callbacks_to_processor_sla_with_task_slas(self, schedule, dag_maker):
"""Test SLA Callbacks are sent to the DAG Processor when SLAs are defined on tasks"""
dag_id = 'test_send_sla_callbacks_to_processor_sla_with_task_slas'
- with dag_maker(dag_id=dag_id, schedule=schedule) as dag:
+ with dag_maker(
+ dag_id=dag_id,
+ schedule=schedule,
+ processor_subdir=TEST_DAG_FOLDER,
+ ) as dag:
EmptyOperator(task_id='task1', sla=timedelta(seconds=60))
with patch.object(settings, "CHECK_SLAS", True):
@@ -3005,7 +3017,11 @@ class TestSchedulerJob:
self.scheduler_job._send_sla_callbacks_to_processor(dag)
- expected_callback = SlaCallbackRequest(full_filepath=dag.fileloc, dag_id=dag.dag_id)
+ expected_callback = SlaCallbackRequest(
+ full_filepath=dag.fileloc,
+ dag_id=dag.dag_id,
+ processor_subdir=TEST_DAG_FOLDER,
+ )
self.scheduler_job.executor.callback_sink.send.assert_called_once_with(expected_callback)
@pytest.mark.parametrize(
@@ -3284,7 +3300,7 @@ class TestSchedulerJob:
)
assert dr is not None
# Run DAG.bulk_write_to_db -- this is run inside DagFileProcessor.process_file
- DAG.bulk_write_to_db([dag], session)
+ DAG.bulk_write_to_db([dag], session=session)
# Test that 'dag_model.next_dagrun' has not been changed because of the newly created,
# externally triggered DagRun.
@@ -4134,7 +4150,7 @@ class TestSchedulerJob:
session = settings.Session()
session.query(LocalTaskJob).delete()
dag = dagbag.get_dag('test_example_bash_operator')
- dag.sync_to_db()
+ dag.sync_to_db(processor_subdir=TEST_DAG_FOLDER)
dag_run = dag.create_dagrun(
state=DagRunState.RUNNING,
@@ -4156,14 +4172,6 @@ class TestSchedulerJob:
ti.job_id = local_job.id
session.flush()
- expected_failure_callback_requests = [
- TaskCallbackRequest(
- full_filepath=dag.fileloc,
- simple_task_instance=SimpleTaskInstance.from_ti(ti),
- msg="Message",
- )
- ]
-
self.scheduler_job = SchedulerJob(subdir=os.devnull)
self.scheduler_job.executor = MockExecutor()
self.scheduler_job.processor_agent = mock.MagicMock()
@@ -4171,10 +4179,53 @@ class TestSchedulerJob:
self.scheduler_job._find_zombies(session=session)
self.scheduler_job.executor.callback_sink.send.assert_called_once()
+
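+    # Build the expected request only after _find_zombies has run, so the generated zombie message details match.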
+ expected_failure_callback_requests = [
+ TaskCallbackRequest(
+ full_filepath=dag.fileloc,
+ simple_task_instance=SimpleTaskInstance.from_ti(ti),
+ processor_subdir=TEST_DAG_FOLDER,
+ msg=str(self.scheduler_job._generate_zombie_message_details(ti)),
+ )
+ ]
callback_requests = self.scheduler_job.executor.callback_sink.send.call_args[0]
+ assert len(callback_requests) == 1
assert {zombie.simple_task_instance.key for zombie in expected_failure_callback_requests} == {
result.simple_task_instance.key for result in callback_requests
}
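+    # The task-instance keys were compared above; null them out on both sides so the remaining fields can be compared directly.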
+ expected_failure_callback_requests[0].simple_task_instance = None
+ callback_requests[0].simple_task_instance = None
+ assert expected_failure_callback_requests[0] == callback_requests[0]
+
+ def test_cleanup_stale_dags(self):
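+        """DAGs that have not been re-parsed recently are marked inactive by _cleanup_stale_dags."""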
+ dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False)
+ with create_session() as session:
+ dag = dagbag.get_dag('test_example_bash_operator')
+ dag.sync_to_db()
+ dm = DagModel.get_current('test_example_bash_operator')
+ # Make it "stale".
+ dm.last_parsed_time = timezone.utcnow() - timedelta(minutes=11)
+ session.merge(dm)
+
+ # This one should remain active.
+ dag = dagbag.get_dag('test_start_date_scheduling')
+ dag.sync_to_db()
+
+ session.flush()
+
+ self.scheduler_job = SchedulerJob(subdir=os.devnull)
+ self.scheduler_job.executor = MockExecutor()
+ self.scheduler_job.processor_agent = mock.MagicMock()
+
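+        # Both DAGs were just synced, so both start out active.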
+ active_dag_count = session.query(func.count(DagModel.dag_id)).filter(DagModel.is_active).scalar()
+ assert active_dag_count == 2
+
+ self.scheduler_job._cleanup_stale_dags(session)
+
+ session.flush()
+
+ active_dag_count = session.query(func.count(DagModel.dag_id)).filter(DagModel.is_active).scalar()
+ assert active_dag_count == 1
@mock.patch.object(settings, 'USE_JOB_SCHEDULE', False)
def run_scheduler_until_dagrun_terminal(self, job: SchedulerJob):
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 173a08fb70..78ec31d3f4 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -798,7 +798,7 @@ class TestDag:
session = settings.Session()
dag.clear()
- DAG.bulk_write_to_db([dag], session)
+ DAG.bulk_write_to_db([dag], session=session)
model = session.query(DagModel).get((dag.dag_id,))
@@ -832,7 +832,7 @@ class TestDag:
session = settings.Session()
dag.clear()
- DAG.bulk_write_to_db([dag], session)
+ DAG.bulk_write_to_db([dag], session=session)
model = session.query(DagModel).get((dag.dag_id,))
@@ -871,7 +871,7 @@ class TestDag:
EmptyOperator(task_id=task_id, dag=dag2, outlets=[Dataset(uri1, extra={"should": "be used"})])
session = settings.Session()
dag1.clear()
- DAG.bulk_write_to_db([dag1, dag2], session)
+ DAG.bulk_write_to_db([dag1, dag2], session=session)
session.commit()
stored_datasets = {x.uri: x for x in session.query(DatasetModel).all()}
d1 = stored_datasets[d1.uri]
diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py
index c2a4f2aafa..b9759f64a0 100644
--- a/tests/models/test_dagbag.py
+++ b/tests/models/test_dagbag.py
@@ -815,9 +815,9 @@ class TestDagBag:
# Test that 3 attempts were made to run 'DAG.bulk_write_to_db' successfully
mock_bulk_write_to_db.assert_has_calls(
[
- mock.call(mock.ANY, session=mock.ANY),
- mock.call(mock.ANY, session=mock.ANY),
- mock.call(mock.ANY, session=mock.ANY),
+ mock.call(mock.ANY, processor_subdir=None, session=mock.ANY),
+ mock.call(mock.ANY, processor_subdir=None, session=mock.ANY),
+ mock.call(mock.ANY, processor_subdir=None, session=mock.ANY),
]
)
# Assert that rollback is called twice (i.e. whenever OperationalError occurs)
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 5e7c9e8046..60858765ff 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -426,6 +426,8 @@ class TestDagRun:
start_date=datetime.datetime(2017, 1, 1),
on_success_callback=on_success_callable,
)
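+            # Persist the DAG so the stored processor_subdir is attached to the success callback below.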
+ DAG.bulk_write_to_db(dags=[dag], processor_subdir='/tmp/test', session=session)
+
dag_task1 = EmptyOperator(task_id='test_state_succeeded1', dag=dag)
dag_task2 = EmptyOperator(task_id='test_state_succeeded2', dag=dag)
dag_task1.set_downstream(dag_task2)
@@ -449,6 +451,7 @@ class TestDagRun:
dag_id="test_dagrun_update_state_with_handle_callback_success",
run_id=dag_run.run_id,
is_failure_callback=False,
+ processor_subdir='/tmp/test',
msg="success",
)
@@ -461,6 +464,8 @@ class TestDagRun:
start_date=datetime.datetime(2017, 1, 1),
on_failure_callback=on_failure_callable,
)
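+            # As above, persist the DAG so processor_subdir is recorded for the failure callback.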
+ DAG.bulk_write_to_db(dags=[dag], processor_subdir='/tmp/test', session=session)
+
dag_task1 = EmptyOperator(task_id='test_state_succeeded1', dag=dag)
dag_task2 = EmptyOperator(task_id='test_state_failed2', dag=dag)
dag_task1.set_downstream(dag_task2)
@@ -484,6 +489,7 @@ class TestDagRun:
dag_id="test_dagrun_update_state_with_handle_callback_failure",
run_id=dag_run.run_id,
is_failure_callback=True,
+ processor_subdir='/tmp/test',
msg="task_failure",
)
diff --git a/tests/models/test_serialized_dag.py b/tests/models/test_serialized_dag.py
index c9b3a63cfe..121474747e 100644
--- a/tests/models/test_serialized_dag.py
+++ b/tests/models/test_serialized_dag.py
@@ -90,7 +90,7 @@ class SerializedDagModelTest(unittest.TestCase):
# Verifies JSON schema.
SerializedDAG.validate_schema(result.data)
- def test_serialized_dag_is_updated_only_if_dag_is_changed(self):
+ def test_serialized_dag_is_updated_if_dag_is_changed(self):
"""Test Serialized DAG is updated if DAG is changed"""
example_dags = make_example_dags(example_dags_module)
example_bash_op_dag = example_dags.get("example_bash_operator")
@@ -121,6 +121,33 @@ class SerializedDagModelTest(unittest.TestCase):
assert s_dag_2.data["dag"]["tags"] == ["example", "example2", "new_tag"]
assert dag_updated is True
+ def test_serialized_dag_is_updated_if_processor_subdir_changed(self):
+ """Test Serialized DAG is updated if processor_subdir is changed"""
+ example_dags = make_example_dags(example_dags_module)
+ example_bash_op_dag = example_dags.get("example_bash_operator")
+ dag_updated = SDM.write_dag(dag=example_bash_op_dag, processor_subdir='/tmp/test')
+ assert dag_updated is True
+
+ with create_session() as session:
+ s_dag = session.query(SDM).get(example_bash_op_dag.dag_id)
+
+        # Test that if neither the DAG nor the processor_subdir has changed, the Serialized DAG
+        # is not re-written and the last_updated column is not updated
+ dag_updated = SDM.write_dag(dag=example_bash_op_dag, processor_subdir='/tmp/test')
+ s_dag_1 = session.query(SDM).get(example_bash_op_dag.dag_id)
+
+ assert s_dag_1.dag_hash == s_dag.dag_hash
+ assert s_dag.last_updated == s_dag_1.last_updated
+ assert dag_updated is False
+ session.flush()
+
+        # Change the processor_subdir; the DAG itself is unchanged
+ dag_updated = SDM.write_dag(dag=example_bash_op_dag, processor_subdir='/tmp/other')
+ s_dag_2 = session.query(SDM).get(example_bash_op_dag.dag_id)
+
+ assert s_dag.processor_subdir != s_dag_2.processor_subdir
+ assert dag_updated is True
+
def test_read_dags(self):
"""DAGs can be read from database."""
example_dags = self._write_example_dags()
diff --git a/tests/test_utils/perf/perf_kit/sqlalchemy.py b/tests/test_utils/perf/perf_kit/sqlalchemy.py
index 37cf0fe14e..480c2e7e26 100644
--- a/tests/test_utils/perf/perf_kit/sqlalchemy.py
+++ b/tests/test_utils/perf/perf_kit/sqlalchemy.py
@@ -231,7 +231,7 @@ if __name__ == "__main__":
},
):
log = logging.getLogger(__name__)
- processor = DagFileProcessor(dag_ids=[], log=log)
+ processor = DagFileProcessor(dag_ids=[], dag_directory="/tmp", log=log)
dag_file = os.path.join(os.path.dirname(__file__), os.path.pardir, "dags", "elastic_dag.py")
processor.process_file(file_path=dag_file, callback_requests=[])
diff --git a/tests/www/views/test_views_home.py b/tests/www/views/test_views_home.py
index e17586e5f8..27c37a9292 100644
--- a/tests/www/views/test_views_home.py
+++ b/tests/www/views/test_views_home.py
@@ -121,7 +121,7 @@ TEST_FILTER_DAG_IDS = ['filter_test_1', 'filter_test_2', 'a_first_dag_id_asc']
def _process_file(file_path, session):
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ dag_file_processor = DagFileProcessor(dag_ids=[], dag_directory='/tmp', log=mock.MagicMock())
dag_file_processor.process_file(file_path, [], False, session)