Posted to commits@airflow.apache.org by ka...@apache.org on 2021/08/13 23:15:04 UTC
[airflow] 03/09: Move DagFileProcessor and DagFileProcessorProcess out of scheduler_job.py (#16581)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 5bcc620e3f6a5adc17222f598fdc065a8cca53a2
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Fri Jun 25 05:36:56 2021 +0100
Move DagFileProcessor and DagFileProcessorProcess out of scheduler_job.py (#16581)
This change moves DagFileProcessor and DagFileProcessorProcess out of scheduler_job.py and into the new airflow/dag_processing package.
As part of the same change, airflow/utils/dag_processing.py is moved to airflow/dag_processing/manager.py.
(cherry picked from commit 88ee2aa7ddf91799f25add9c57e1ea128de2b7aa)
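
For code that imported these classes from their previous locations, the practical effect is an import-path update. A minimal sketch of the before/after imports, using only the module paths visible in the diff below:

    # Old locations (before this change):
    # from airflow.utils.dag_processing import DagFileProcessorAgent
    # from airflow.jobs.scheduler_job import DagFileProcessor, DagFileProcessorProcess

    # New locations introduced by this change:
    from airflow.dag_processing.manager import DagFileProcessorAgent
    from airflow.dag_processing.processor import DagFileProcessor, DagFileProcessorProcess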
---
.github/boring-cyborg.yml | 2 +-
airflow/dag_processing/__init__.py | 16 +
.../manager.py} | 0
airflow/dag_processing/processor.py | 650 ++++++++++++++++++
airflow/jobs/scheduler_job.py | 619 +----------------
tests/dag_processing/__init__.py | 16 +
.../test_manager.py} | 16 +-
tests/dag_processing/test_processor.py | 749 +++++++++++++++++++++
tests/jobs/test_scheduler_job.py | 700 +------------------
tests/test_utils/perf/perf_kit/python.py | 2 +-
tests/test_utils/perf/perf_kit/sqlalchemy.py | 2 +-
11 files changed, 1456 insertions(+), 1316 deletions(-)
diff --git a/.github/boring-cyborg.yml b/.github/boring-cyborg.yml
index d5f7632..8ae0532 100644
--- a/.github/boring-cyborg.yml
+++ b/.github/boring-cyborg.yml
@@ -157,7 +157,7 @@ labelPRBasedOnFilePath:
- airflow/executors/**/*
- airflow/jobs/**/*
- airflow/task/task_runner/**/*
- - airflow/utils/dag_processing.py
+ - airflow/dag_processing/**/*
- docs/apache-airflow/executor/**/*
- docs/apache-airflow/scheduler.rst
- tests/executors/**/*
diff --git a/airflow/dag_processing/__init__.py b/airflow/dag_processing/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/airflow/dag_processing/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/utils/dag_processing.py b/airflow/dag_processing/manager.py
similarity index 100%
rename from airflow/utils/dag_processing.py
rename to airflow/dag_processing/manager.py
diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py
new file mode 100644
index 0000000..44dc5f2
--- /dev/null
+++ b/airflow/dag_processing/processor.py
@@ -0,0 +1,650 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import logging
+import multiprocessing
+import os
+import signal
+import threading
+from contextlib import redirect_stderr, redirect_stdout, suppress
+from datetime import timedelta
+from multiprocessing.connection import Connection as MultiprocessingConnection
+from typing import List, Optional, Set, Tuple
+
+from setproctitle import setproctitle # pylint: disable=no-name-in-module
+from sqlalchemy import func, or_
+from sqlalchemy.orm.session import Session
+
+from airflow import models, settings
+from airflow.configuration import conf
+from airflow.dag_processing.manager import AbstractDagFileProcessorProcess
+from airflow.exceptions import AirflowException, TaskNotFound
+from airflow.models import DAG, DagModel, SlaMiss, errors
+from airflow.models.dagbag import DagBag
+from airflow.stats import Stats
+from airflow.utils import timezone
+from airflow.utils.callback_requests import (
+ CallbackRequest,
+ DagCallbackRequest,
+ SlaCallbackRequest,
+ TaskCallbackRequest,
+)
+from airflow.utils.email import get_email_address_list, send_email
+from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter, set_context
+from airflow.utils.mixins import MultiprocessingStartMethodMixin
+from airflow.utils.session import provide_session
+from airflow.utils.state import State
+
+TI = models.TaskInstance
+
+
+class DagFileProcessorProcess(AbstractDagFileProcessorProcess, LoggingMixin, MultiprocessingStartMethodMixin):
+ """Runs DAG processing in a separate process using DagFileProcessor
+
+ :param file_path: a Python file containing Airflow DAG definitions
+ :type file_path: str
+ :param pickle_dags: whether to serialize the DAG objects to the DB
+ :type pickle_dags: bool
+ :param dag_ids: If specified, only look at these DAG ID's
+ :type dag_ids: List[str]
+ :param callback_requests: failure callback to execute
+ :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest]
+ """
+
+ # Counter that increments every time an instance of this class is created
+ class_creation_counter = 0
+
+ def __init__(
+ self,
+ file_path: str,
+ pickle_dags: bool,
+ dag_ids: Optional[List[str]],
+ callback_requests: List[CallbackRequest],
+ ):
+ super().__init__()
+ self._file_path = file_path
+ self._pickle_dags = pickle_dags
+ self._dag_ids = dag_ids
+ self._callback_requests = callback_requests
+
+ # The process that was launched to process the given file.
+ self._process: Optional[multiprocessing.process.BaseProcess] = None
+ # The result of DagFileProcessor.process_file(file_path).
+ self._result: Optional[Tuple[int, int]] = None
+ # Whether the process is done running.
+ self._done = False
+ # When the process started.
+ self._start_time: Optional[datetime.datetime] = None
+ # This ID is used to uniquely name the process / thread that's launched
+ # by this processor instance.
+ self._instance_id = DagFileProcessorProcess.class_creation_counter
+
+ self._parent_channel: Optional[MultiprocessingConnection] = None
+ DagFileProcessorProcess.class_creation_counter += 1
+
+ @property
+ def file_path(self) -> str:
+ return self._file_path
+
+ @staticmethod
+ def _run_file_processor(
+ result_channel: MultiprocessingConnection,
+ parent_channel: MultiprocessingConnection,
+ file_path: str,
+ pickle_dags: bool,
+ dag_ids: Optional[List[str]],
+ thread_name: str,
+ callback_requests: List[CallbackRequest],
+ ) -> None:
+ """
+ Process the given file.
+
+ :param result_channel: the connection to use for passing back the result
+ :type result_channel: multiprocessing.Connection
+ :param parent_channel: the parent end of the channel to close in the child
+ :type parent_channel: multiprocessing.Connection
+ :param file_path: the file to process
+ :type file_path: str
+ :param pickle_dags: whether to pickle the DAGs found in the file and
+ save them to the DB
+ :type pickle_dags: bool
+ :param dag_ids: if specified, only examine DAG ID's that are
+ in this list
+ :type dag_ids: list[str]
+ :param thread_name: the name to use for the process that is launched
+ :type thread_name: str
+ :param callback_requests: failure callback to execute
+ :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest]
+ :return: the process that was launched
+ :rtype: multiprocessing.Process
+ """
+ # This helper runs in the newly created process
+ log: logging.Logger = logging.getLogger("airflow.processor")
+
+ # Since we share all open FDs from the parent, we need to close the parent side of the pipe here in
+ # the child, else it won't get closed properly until we exit.
+ log.info("Closing parent pipe")
+
+ parent_channel.close()
+ del parent_channel
+
+ set_context(log, file_path)
+ setproctitle(f"airflow scheduler - DagFileProcessor {file_path}")
+
+ try:
+ # redirect stdout/stderr to log
+ with redirect_stdout(StreamLogWriter(log, logging.INFO)), redirect_stderr(
+ StreamLogWriter(log, logging.WARN)
+ ), Stats.timer() as timer:
+ # Re-configure the ORM engine as there are issues with multiple processes
+ settings.configure_orm()
+
+ # Change the thread name to differentiate log lines. This is
+ # really a separate process, but changing the name of the
+ # process doesn't work, so changing the thread name instead.
+ threading.current_thread().name = thread_name
+
+ log.info("Started process (PID=%s) to work on %s", os.getpid(), file_path)
+ dag_file_processor = DagFileProcessor(dag_ids=dag_ids, log=log)
+ result: Tuple[int, int] = dag_file_processor.process_file(
+ file_path=file_path,
+ pickle_dags=pickle_dags,
+ callback_requests=callback_requests,
+ )
+ result_channel.send(result)
+ log.info("Processing %s took %.3f seconds", file_path, timer.duration)
+ except Exception: # pylint: disable=broad-except
+ # Log exceptions through the logging framework.
+ log.exception("Got an exception! Propagating...")
+ raise
+ finally:
+ # We re-initialized the ORM within this Process above so we need to
+ # tear it down manually here
+ settings.dispose_orm()
+
+ result_channel.close()
+
+ def start(self) -> None:
+ """Launch the process and start processing the DAG."""
+ start_method = self._get_multiprocessing_start_method()
+ context = multiprocessing.get_context(start_method)
+
+ _parent_channel, _child_channel = context.Pipe(duplex=False)
+ process = context.Process(
+ target=type(self)._run_file_processor,
+ args=(
+ _child_channel,
+ _parent_channel,
+ self.file_path,
+ self._pickle_dags,
+ self._dag_ids,
+ f"DagFileProcessor{self._instance_id}",
+ self._callback_requests,
+ ),
+ name=f"DagFileProcessor{self._instance_id}-Process",
+ )
+ self._process = process
+ self._start_time = timezone.utcnow()
+ process.start()
+
+ # Close the child side of the pipe now the subprocess has started -- otherwise this would prevent it
+ # from closing in some cases
+ _child_channel.close()
+ del _child_channel
+
+ # Don't store it on self until after we've started the child process - we don't want to keep it from
+ # getting GCd/closed
+ self._parent_channel = _parent_channel
+
+ def kill(self) -> None:
+ """Kill the process launched to process the file, and ensure consistent state."""
+ if self._process is None:
+ raise AirflowException("Tried to kill before starting!")
+ self._kill_process()
+
+ def terminate(self, sigkill: bool = False) -> None:
+ """
+ Terminate (and then kill) the process launched to process the file.
+
+ :param sigkill: whether to issue a SIGKILL if SIGTERM doesn't work.
+ :type sigkill: bool
+ """
+ if self._process is None or self._parent_channel is None:
+ raise AirflowException("Tried to call terminate before starting!")
+
+ self._process.terminate()
+ # Arbitrarily wait 5s for the process to die
+ with suppress(TimeoutError):
+ self._process._popen.wait(5) # type: ignore # pylint: disable=protected-access
+ if sigkill:
+ self._kill_process()
+ self._parent_channel.close()
+
+ def _kill_process(self) -> None:
+ if self._process is None:
+ raise AirflowException("Tried to kill process before starting!")
+
+ if self._process.is_alive() and self._process.pid:
+ self.log.warning("Killing DAGFileProcessorProcess (PID=%d)", self._process.pid)
+ os.kill(self._process.pid, signal.SIGKILL)
+ if self._parent_channel:
+ self._parent_channel.close()
+
+ @property
+ def pid(self) -> int:
+ """
+ :return: the PID of the process launched to process the given file
+ :rtype: int
+ """
+ if self._process is None or self._process.pid is None:
+ raise AirflowException("Tried to get PID before starting!")
+ return self._process.pid
+
+ @property
+ def exit_code(self) -> Optional[int]:
+ """
+ After the process is finished, this can be called to get the return code
+
+ :return: the exit code of the process
+ :rtype: int
+ """
+ if self._process is None:
+ raise AirflowException("Tried to get exit code before starting!")
+ if not self._done:
+ raise AirflowException("Tried to call retcode before process was finished!")
+ return self._process.exitcode
+
+ @property
+ def done(self) -> bool:
+ """
+ Check if the process launched to process this file is done.
+
+ :return: whether the process is finished running
+ :rtype: bool
+ """
+ if self._process is None or self._parent_channel is None:
+ raise AirflowException("Tried to see if it's done before starting!")
+
+ if self._done:
+ return True
+
+ if self._parent_channel.poll():
+ try:
+ self._result = self._parent_channel.recv()
+ self._done = True
+ self.log.debug("Waiting for %s", self._process)
+ self._process.join()
+ self._parent_channel.close()
+ return True
+ except EOFError:
+ # If we get an EOFError, it means the child end of the pipe has been closed. This only happens
+ # in the finally block. But due to a possible race condition, the process may have not yet
+ # terminated (it could be doing cleanup/python shutdown still). So we kill it here after a
+ # "suitable" timeout.
+ self._done = True
+ # Arbitrary timeout -- error/race condition only, so this doesn't need to be tunable.
+ self._process.join(timeout=5)
+ if self._process.is_alive():
+ # Didn't shut down cleanly - kill it
+ self._kill_process()
+
+ if not self._process.is_alive():
+ self._done = True
+ self.log.debug("Waiting for %s", self._process)
+ self._process.join()
+ self._parent_channel.close()
+ return True
+
+ return False
+
+ @property
+ def result(self) -> Optional[Tuple[int, int]]:
+ """
+ :return: result of running DagFileProcessor.process_file()
+ :rtype: tuple[int, int] or None
+ """
+ if not self.done:
+ raise AirflowException("Tried to get the result before it's done!")
+ return self._result
+
+ @property
+ def start_time(self) -> datetime.datetime:
+ """
+ :return: when this started to process the file
+ :rtype: datetime
+ """
+ if self._start_time is None:
+ raise AirflowException("Tried to get start time before it started!")
+ return self._start_time
+
+ @property
+ def waitable_handle(self):
+ return self._process.sentinel
+
+
+class DagFileProcessor(LoggingMixin):
+ """
+ Process a Python file containing Airflow DAGs.
+
+ This includes:
+
+ 1. Execute the file and look for DAG objects in the namespace.
+ 2. Execute any Callbacks if passed to DagFileProcessor.process_file
+ 3. Serialize the DAGs and save it to DB (or update existing record in the DB).
+ 4. Pickle the DAG and save it to the DB (if necessary).
+ 5. Record any errors importing the file into ORM
+
+ Returns a tuple of 'number of dags found' and 'the count of import errors'
+
+ :param dag_ids: If specified, only look at these DAG ID's
+ :type dag_ids: List[str]
+ :param log: Logger to save the processing process
+ :type log: logging.Logger
+ """
+
+ UNIT_TEST_MODE: bool = conf.getboolean('core', 'UNIT_TEST_MODE')
+
+ def __init__(self, dag_ids: Optional[List[str]], log: logging.Logger):
+ super().__init__()
+ self.dag_ids = dag_ids
+ self._log = log
+
+ @provide_session
+ def manage_slas(self, dag: DAG, session: Session = None) -> None:
+ """
+ Find all tasks that have SLAs defined, and send alert emails
+ where needed. New SLA misses are also recorded in the database.
+
+ We are assuming that the scheduler runs often, so we only check for
+ tasks that should have succeeded in the past hour.
+ """
+ self.log.info("Running SLA Checks for %s", dag.dag_id)
+ if not any(isinstance(ti.sla, timedelta) for ti in dag.tasks):
+ self.log.info("Skipping SLA check for %s because no tasks in DAG have SLAs", dag)
+ return
+
+ qry = (
+ session.query(TI.task_id, func.max(TI.execution_date).label('max_ti'))
+ .with_hint(TI, 'USE INDEX (PRIMARY)', dialect_name='mysql')
+ .filter(TI.dag_id == dag.dag_id)
+ .filter(or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED))
+ .filter(TI.task_id.in_(dag.task_ids))
+ .group_by(TI.task_id)
+ .subquery('sq')
+ )
+
+ max_tis: List[TI] = (
+ session.query(TI)
+ .filter(
+ TI.dag_id == dag.dag_id,
+ TI.task_id == qry.c.task_id,
+ TI.execution_date == qry.c.max_ti,
+ )
+ .all()
+ )
+
+ ts = timezone.utcnow()
+ for ti in max_tis:
+ task = dag.get_task(ti.task_id)
+ if task.sla and not isinstance(task.sla, timedelta):
+ raise TypeError(
+ f"SLA is expected to be timedelta object, got "
+ f"{type(task.sla)} in {task.dag_id}:{task.task_id}"
+ )
+
+ dttm = dag.following_schedule(ti.execution_date)
+ while dttm < timezone.utcnow():
+ following_schedule = dag.following_schedule(dttm)
+ if following_schedule + task.sla < timezone.utcnow():
+ session.merge(
+ SlaMiss(task_id=ti.task_id, dag_id=ti.dag_id, execution_date=dttm, timestamp=ts)
+ )
+ dttm = dag.following_schedule(dttm)
+ session.commit()
+
+ # pylint: disable=singleton-comparison
+ slas: List[SlaMiss] = (
+ session.query(SlaMiss)
+ .filter(SlaMiss.notification_sent == False, SlaMiss.dag_id == dag.dag_id) # noqa
+ .all()
+ )
+ # pylint: enable=singleton-comparison
+
+ if slas: # pylint: disable=too-many-nested-blocks
+ sla_dates: List[datetime.datetime] = [sla.execution_date for sla in slas]
+ fetched_tis: List[TI] = (
+ session.query(TI)
+ .filter(TI.state != State.SUCCESS, TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id)
+ .all()
+ )
+ blocking_tis: List[TI] = []
+ for ti in fetched_tis:
+ if ti.task_id in dag.task_ids:
+ ti.task = dag.get_task(ti.task_id)
+ blocking_tis.append(ti)
+ else:
+ session.delete(ti)
+ session.commit()
+
+ task_list = "\n".join(sla.task_id + ' on ' + sla.execution_date.isoformat() for sla in slas)
+ blocking_task_list = "\n".join(
+ ti.task_id + ' on ' + ti.execution_date.isoformat() for ti in blocking_tis
+ )
+ # Track whether email or any alert notification sent
+ # We consider email or the alert callback as notifications
+ email_sent = False
+ notification_sent = False
+ if dag.sla_miss_callback:
+ # Execute the alert callback
+ self.log.info('Calling SLA miss callback')
+ try:
+ dag.sla_miss_callback(dag, task_list, blocking_task_list, slas, blocking_tis)
+ notification_sent = True
+ except Exception: # pylint: disable=broad-except
+ self.log.exception("Could not call sla_miss_callback for DAG %s", dag.dag_id)
+ email_content = f"""\
+ Here's a list of tasks that missed their SLAs:
+ <pre><code>{task_list}\n<code></pre>
+ Blocking tasks:
+ <pre><code>{blocking_task_list}<code></pre>
+ Airflow Webserver URL: {conf.get(section='webserver', key='base_url')}
+ """
+
+ tasks_missed_sla = []
+ for sla in slas:
+ try:
+ task = dag.get_task(sla.task_id)
+ except TaskNotFound:
+ # task already deleted from DAG, skip it
+ self.log.warning(
+ "Task %s doesn't exist in DAG anymore, skipping SLA miss notification.", sla.task_id
+ )
+ continue
+ tasks_missed_sla.append(task)
+
+ emails: Set[str] = set()
+ for task in tasks_missed_sla:
+ if task.email:
+ if isinstance(task.email, str):
+ emails |= set(get_email_address_list(task.email))
+ elif isinstance(task.email, (list, tuple)):
+ emails |= set(task.email)
+ if emails:
+ try:
+ send_email(emails, f"[airflow] SLA miss on DAG={dag.dag_id}", email_content)
+ email_sent = True
+ notification_sent = True
+ except Exception: # pylint: disable=broad-except
+ Stats.incr('sla_email_notification_failure')
+ self.log.exception("Could not send SLA Miss email notification for DAG %s", dag.dag_id)
+ # If we sent any notification, update the sla_miss table
+ if notification_sent:
+ for sla in slas:
+ sla.email_sent = email_sent
+ sla.notification_sent = True
+ session.merge(sla)
+ session.commit()
+
+ @staticmethod
+ def update_import_errors(session: Session, dagbag: DagBag) -> None:
+ """
+ For the DAGs in the given DagBag, record any associated import errors and clears
+ errors for files that no longer have them. These are usually displayed through the
+ Airflow UI so that users know that there are issues parsing DAGs.
+
+ :param session: session for ORM operations
+ :type session: sqlalchemy.orm.session.Session
+ :param dagbag: DagBag containing DAGs with import errors
+ :type dagbag: airflow.DagBag
+ """
+ # Clear the errors of the processed files
+ for dagbag_file in dagbag.file_last_changed:
+ session.query(errors.ImportError).filter(errors.ImportError.filename == dagbag_file).delete()
+
+ # Add the errors of the processed files
+ for filename, stacktrace in dagbag.import_errors.items():
+ session.add(
+ errors.ImportError(filename=filename, timestamp=timezone.utcnow(), stacktrace=stacktrace)
+ )
+ session.commit()
+
+ @provide_session
+ def execute_callbacks(
+ self, dagbag: DagBag, callback_requests: List[CallbackRequest], session: Session = None
+ ) -> None:
+ """
+ Execute on failure callbacks. These objects can come from SchedulerJob or from
+ DagFileProcessorManager.
+
+ :param dagbag: Dag Bag of dags
+ :param callback_requests: failure callbacks to execute
+ :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest]
+ :param session: DB session.
+ """
+ for request in callback_requests:
+ self.log.debug("Processing Callback Request: %s", request)
+ try:
+ if isinstance(request, TaskCallbackRequest):
+ self._execute_task_callbacks(dagbag, request)
+ elif isinstance(request, SlaCallbackRequest):
+ self.manage_slas(dagbag.dags.get(request.dag_id))
+ elif isinstance(request, DagCallbackRequest):
+ self._execute_dag_callbacks(dagbag, request, session)
+ except Exception: # pylint: disable=broad-except
+ self.log.exception(
+ "Error executing %s callback for file: %s",
+ request.__class__.__name__,
+ request.full_filepath,
+ )
+
+ session.commit()
+
+ @provide_session
+ def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, session: Session):
+ dag = dagbag.dags[request.dag_id]
+ dag_run = dag.get_dagrun(execution_date=request.execution_date, session=session)
+ dag.handle_callback(
+ dagrun=dag_run, success=not request.is_failure_callback, reason=request.msg, session=session
+ )
+
+ def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest):
+ simple_ti = request.simple_task_instance
+ if simple_ti.dag_id in dagbag.dags:
+ dag = dagbag.dags[simple_ti.dag_id]
+ if simple_ti.task_id in dag.task_ids:
+ task = dag.get_task(simple_ti.task_id)
+ ti = TI(task, simple_ti.execution_date)
+ # Get properties needed for failure handling from SimpleTaskInstance.
+ ti.start_date = simple_ti.start_date
+ ti.end_date = simple_ti.end_date
+ ti.try_number = simple_ti.try_number
+ ti.state = simple_ti.state
+ ti.test_mode = self.UNIT_TEST_MODE
+ if request.is_failure_callback:
+ ti.handle_failure_with_callback(error=request.msg, test_mode=ti.test_mode)
+ self.log.info('Executed failure callback for %s in state %s', ti, ti.state)
+
+ @provide_session
+ def process_file(
+ self,
+ file_path: str,
+ callback_requests: List[CallbackRequest],
+ pickle_dags: bool = False,
+ session: Session = None,
+ ) -> Tuple[int, int]:
+ """
+ Process a Python file containing Airflow DAGs.
+
+ This includes:
+
+ 1. Execute the file and look for DAG objects in the namespace.
+ 2. Execute any Callbacks if passed to this method.
+ 3. Serialize the DAGs and save it to DB (or update existing record in the DB).
+ 4. Pickle the DAG and save it to the DB (if necessary).
+ 5. Record any errors importing the file into ORM
+
+ :param file_path: the path to the Python file that should be executed
+ :type file_path: str
+ :param callback_requests: failure callback to execute
+ :type callback_requests: List[airflow.utils.dag_processing.CallbackRequest]
+ :param pickle_dags: whether to serialize the DAGs found in the file and
+ save them to the DB
+ :type pickle_dags: bool
+ :param session: Sqlalchemy ORM Session
+ :type session: Session
+ :return: number of dags found, count of import errors
+ :rtype: Tuple[int, int]
+ """
+ self.log.info("Processing file %s for tasks to queue", file_path)
+
+ try:
+ dagbag = DagBag(file_path, include_examples=False, include_smart_sensor=False)
+ except Exception: # pylint: disable=broad-except
+ self.log.exception("Failed at reloading the DAG file %s", file_path)
+ Stats.incr('dag_file_refresh_error', 1, 1)
+ return 0, 0
+
+ if len(dagbag.dags) > 0:
+ self.log.info("DAG(s) %s retrieved from %s", dagbag.dags.keys(), file_path)
+ else:
+ self.log.warning("No viable dags retrieved from %s", file_path)
+ self.update_import_errors(session, dagbag)
+ return 0, len(dagbag.import_errors)
+
+ self.execute_callbacks(dagbag, callback_requests)
+
+ # Save individual DAGs in the ORM
+ dagbag.sync_to_db()
+
+ if pickle_dags:
+ paused_dag_ids = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids)
+
+ unpaused_dags: List[DAG] = [
+ dag for dag_id, dag in dagbag.dags.items() if dag_id not in paused_dag_ids
+ ]
+
+ for dag in unpaused_dags:
+ dag.pickle(session)
+
+ # Record import errors into the ORM
+ try:
+ self.update_import_errors(session, dagbag)
+ except Exception: # pylint: disable=broad-except
+ self.log.exception("Error logging import errors!")
+
+ return len(dagbag.dags), len(dagbag.import_errors)
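
The relocated classes keep the same runtime behaviour; the following is an illustrative sketch (not part of the commit) of how a single DAG file could be parsed through the new module, assuming a configured Airflow environment. The file path and the polling loop are made up for the example; the real DagFileProcessorManager drives completion via waitable_handle rather than sleeping:

    import time

    from airflow.dag_processing.processor import DagFileProcessorProcess

    # Parse one DAG file in a child process (illustrative path).
    processor = DagFileProcessorProcess(
        file_path="/files/dags/example_dag.py",
        pickle_dags=False,
        dag_ids=None,
        callback_requests=[],
    )
    processor.start()

    # Wait for the child to finish, then read the result of
    # DagFileProcessor.process_file(): (dags found, import error count).
    while not processor.done:
        time.sleep(0.1)
    dag_count, import_error_count = processor.result or (0, 0)
    print(f"parsed {dag_count} dag(s) with {import_error_count} import error(s)")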
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index fe8e0b0..5b24e00 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -23,15 +23,11 @@ import multiprocessing
import os
import signal
import sys
-import threading
import time
from collections import defaultdict
-from contextlib import redirect_stderr, redirect_stdout, suppress
from datetime import timedelta
-from multiprocessing.connection import Connection as MultiprocessingConnection
from typing import DefaultDict, Dict, Iterable, List, Optional, Set, Tuple
-from setproctitle import setproctitle
from sqlalchemy import and_, func, not_, or_, tuple_
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import load_only, selectinload
@@ -39,10 +35,13 @@ from sqlalchemy.orm.session import Session, make_transient
from airflow import models, settings
from airflow.configuration import conf
-from airflow.exceptions import AirflowException, SerializedDagNotFound, TaskNotFound
+from airflow.dag_processing.manager import DagFileProcessorAgent
+from airflow.dag_processing.processor import DagFileProcessorProcess
+from airflow.exceptions import SerializedDagNotFound
from airflow.executors.executor_loader import UNPICKLEABLE_EXECUTORS
from airflow.jobs.base_job import BaseJob
-from airflow.models import DAG, DagModel, SlaMiss, errors
+from airflow.models import DAG
+from airflow.models.dag import DagModel
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun
from airflow.models.serialized_dag import SerializedDagModel
@@ -50,17 +49,8 @@ from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKey
from airflow.stats import Stats
from airflow.ti_deps.dependencies_states import EXECUTION_STATES
from airflow.utils import timezone
-from airflow.utils.callback_requests import (
- CallbackRequest,
- DagCallbackRequest,
- SlaCallbackRequest,
- TaskCallbackRequest,
-)
-from airflow.utils.dag_processing import AbstractDagFileProcessorProcess, DagFileProcessorAgent
-from airflow.utils.email import get_email_address_list, send_email
+from airflow.utils.callback_requests import CallbackRequest, DagCallbackRequest, TaskCallbackRequest
from airflow.utils.event_scheduler import EventScheduler
-from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter, set_context
-from airflow.utils.mixins import MultiprocessingStartMethodMixin
from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries
from airflow.utils.session import create_session, provide_session
from airflow.utils.sqlalchemy import is_lock_not_available_error, prohibit_commit, skip_locked, with_row_locks
@@ -72,603 +62,6 @@ DR = models.DagRun
DM = models.DagModel
-class DagFileProcessorProcess(AbstractDagFileProcessorProcess, LoggingMixin, MultiprocessingStartMethodMixin):
- """Runs DAG processing in a separate process using DagFileProcessor
-
- :param file_path: a Python file containing Airflow DAG definitions
- :type file_path: str
- :param pickle_dags: whether to serialize the DAG objects to the DB
- :type pickle_dags: bool
- :param dag_ids: If specified, only look at these DAG ID's
- :type dag_ids: List[str]
- :param callback_requests: failure callback to execute
- :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest]
- """
-
- # Counter that increments every time an instance of this class is created
- class_creation_counter = 0
-
- def __init__(
- self,
- file_path: str,
- pickle_dags: bool,
- dag_ids: Optional[List[str]],
- callback_requests: List[CallbackRequest],
- ):
- super().__init__()
- self._file_path = file_path
- self._pickle_dags = pickle_dags
- self._dag_ids = dag_ids
- self._callback_requests = callback_requests
-
- # The process that was launched to process the given .
- self._process: Optional[multiprocessing.process.BaseProcess] = None
- # The result of DagFileProcessor.process_file(file_path).
- self._result: Optional[Tuple[int, int]] = None
- # Whether the process is done running.
- self._done = False
- # When the process started.
- self._start_time: Optional[datetime.datetime] = None
- # This ID is use to uniquely name the process / thread that's launched
- # by this processor instance
- self._instance_id = DagFileProcessorProcess.class_creation_counter
-
- self._parent_channel: Optional[MultiprocessingConnection] = None
- DagFileProcessorProcess.class_creation_counter += 1
-
- @property
- def file_path(self) -> str:
- return self._file_path
-
- @staticmethod
- def _run_file_processor(
- result_channel: MultiprocessingConnection,
- parent_channel: MultiprocessingConnection,
- file_path: str,
- pickle_dags: bool,
- dag_ids: Optional[List[str]],
- thread_name: str,
- callback_requests: List[CallbackRequest],
- ) -> None:
- """
- Process the given file.
-
- :param result_channel: the connection to use for passing back the result
- :type result_channel: multiprocessing.Connection
- :param parent_channel: the parent end of the channel to close in the child
- :type parent_channel: multiprocessing.Connection
- :param file_path: the file to process
- :type file_path: str
- :param pickle_dags: whether to pickle the DAGs found in the file and
- save them to the DB
- :type pickle_dags: bool
- :param dag_ids: if specified, only examine DAG ID's that are
- in this list
- :type dag_ids: list[str]
- :param thread_name: the name to use for the process that is launched
- :type thread_name: str
- :param callback_requests: failure callback to execute
- :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest]
- :return: the process that was launched
- :rtype: multiprocessing.Process
- """
- # This helper runs in the newly created process
- log: logging.Logger = logging.getLogger("airflow.processor")
-
- # Since we share all open FDs from the parent, we need to close the parent side of the pipe here in
- # the child, else it won't get closed properly until we exit.
- log.info("Closing parent pipe")
-
- parent_channel.close()
- del parent_channel
-
- set_context(log, file_path)
- setproctitle(f"airflow scheduler - DagFileProcessor {file_path}")
-
- try:
- # redirect stdout/stderr to log
- with redirect_stdout(StreamLogWriter(log, logging.INFO)), redirect_stderr(
- StreamLogWriter(log, logging.WARN)
- ), Stats.timer() as timer:
- # Re-configure the ORM engine as there are issues with multiple processes
- settings.configure_orm()
-
- # Change the thread name to differentiate log lines. This is
- # really a separate process, but changing the name of the
- # process doesn't work, so changing the thread name instead.
- threading.current_thread().name = thread_name
-
- log.info("Started process (PID=%s) to work on %s", os.getpid(), file_path)
- dag_file_processor = DagFileProcessor(dag_ids=dag_ids, log=log)
- result: Tuple[int, int] = dag_file_processor.process_file(
- file_path=file_path,
- pickle_dags=pickle_dags,
- callback_requests=callback_requests,
- )
- result_channel.send(result)
- log.info("Processing %s took %.3f seconds", file_path, timer.duration)
- except Exception: # pylint: disable=broad-except
- # Log exceptions through the logging framework.
- log.exception("Got an exception! Propagating...")
- raise
- finally:
- # We re-initialized the ORM within this Process above so we need to
- # tear it down manually here
- settings.dispose_orm()
-
- result_channel.close()
-
- def start(self) -> None:
- """Launch the process and start processing the DAG."""
- start_method = self._get_multiprocessing_start_method()
- context = multiprocessing.get_context(start_method)
-
- _parent_channel, _child_channel = context.Pipe(duplex=False)
- process = context.Process(
- target=type(self)._run_file_processor,
- args=(
- _child_channel,
- _parent_channel,
- self.file_path,
- self._pickle_dags,
- self._dag_ids,
- f"DagFileProcessor{self._instance_id}",
- self._callback_requests,
- ),
- name=f"DagFileProcessor{self._instance_id}-Process",
- )
- self._process = process
- self._start_time = timezone.utcnow()
- process.start()
-
- # Close the child side of the pipe now the subprocess has started -- otherwise this would prevent it
- # from closing in some cases
- _child_channel.close()
- del _child_channel
-
- # Don't store it on self until after we've started the child process - we don't want to keep it from
- # getting GCd/closed
- self._parent_channel = _parent_channel
-
- def kill(self) -> None:
- """Kill the process launched to process the file, and ensure consistent state."""
- if self._process is None:
- raise AirflowException("Tried to kill before starting!")
- self._kill_process()
-
- def terminate(self, sigkill: bool = False) -> None:
- """
- Terminate (and then kill) the process launched to process the file.
-
- :param sigkill: whether to issue a SIGKILL if SIGTERM doesn't work.
- :type sigkill: bool
- """
- if self._process is None or self._parent_channel is None:
- raise AirflowException("Tried to call terminate before starting!")
-
- self._process.terminate()
- # Arbitrarily wait 5s for the process to die
- with suppress(TimeoutError):
- self._process._popen.wait(5) # type: ignore # pylint: disable=protected-access
- if sigkill:
- self._kill_process()
- self._parent_channel.close()
-
- def _kill_process(self) -> None:
- if self._process is None:
- raise AirflowException("Tried to kill process before starting!")
-
- if self._process.is_alive() and self._process.pid:
- self.log.warning("Killing DAGFileProcessorProcess (PID=%d)", self._process.pid)
- os.kill(self._process.pid, signal.SIGKILL)
- if self._parent_channel:
- self._parent_channel.close()
-
- @property
- def pid(self) -> int:
- """
- :return: the PID of the process launched to process the given file
- :rtype: int
- """
- if self._process is None or self._process.pid is None:
- raise AirflowException("Tried to get PID before starting!")
- return self._process.pid
-
- @property
- def exit_code(self) -> Optional[int]:
- """
- After the process is finished, this can be called to get the return code
-
- :return: the exit code of the process
- :rtype: int
- """
- if self._process is None:
- raise AirflowException("Tried to get exit code before starting!")
- if not self._done:
- raise AirflowException("Tried to call retcode before process was finished!")
- return self._process.exitcode
-
- @property
- def done(self) -> bool:
- """
- Check if the process launched to process this file is done.
-
- :return: whether the process is finished running
- :rtype: bool
- """
- if self._process is None or self._parent_channel is None:
- raise AirflowException("Tried to see if it's done before starting!")
-
- if self._done:
- return True
-
- if self._parent_channel.poll():
- try:
- self._result = self._parent_channel.recv()
- self._done = True
- self.log.debug("Waiting for %s", self._process)
- self._process.join()
- self._parent_channel.close()
- return True
- except EOFError:
- # If we get an EOFError, it means the child end of the pipe has been closed. This only happens
- # in the finally block. But due to a possible race condition, the process may have not yet
- # terminated (it could be doing cleanup/python shutdown still). So we kill it here after a
- # "suitable" timeout.
- self._done = True
- # Arbitrary timeout -- error/race condition only, so this doesn't need to be tunable.
- self._process.join(timeout=5)
- if self._process.is_alive():
- # Didn't shut down cleanly - kill it
- self._kill_process()
-
- if not self._process.is_alive():
- self._done = True
- self.log.debug("Waiting for %s", self._process)
- self._process.join()
- self._parent_channel.close()
- return True
-
- return False
-
- @property
- def result(self) -> Optional[Tuple[int, int]]:
- """
- :return: result of running DagFileProcessor.process_file()
- :rtype: tuple[int, int] or None
- """
- if not self.done:
- raise AirflowException("Tried to get the result before it's done!")
- return self._result
-
- @property
- def start_time(self) -> datetime.datetime:
- """
- :return: when this started to process the file
- :rtype: datetime
- """
- if self._start_time is None:
- raise AirflowException("Tried to get start time before it started!")
- return self._start_time
-
- @property
- def waitable_handle(self):
- return self._process.sentinel
-
-
-class DagFileProcessor(LoggingMixin):
- """
- Process a Python file containing Airflow DAGs.
-
- This includes:
-
- 1. Execute the file and look for DAG objects in the namespace.
- 2. Execute any Callbacks if passed to DagFileProcessor.process_file
- 3. Serialize the DAGs and save it to DB (or update existing record in the DB).
- 4. Pickle the DAG and save it to the DB (if necessary).
- 5. Record any errors importing the file into ORM
-
- Returns a tuple of 'number of dags found' and 'the count of import errors'
-
- :param dag_ids: If specified, only look at these DAG ID's
- :type dag_ids: List[str]
- :param log: Logger to save the processing process
- :type log: logging.Logger
- """
-
- UNIT_TEST_MODE: bool = conf.getboolean('core', 'UNIT_TEST_MODE')
-
- def __init__(self, dag_ids: Optional[List[str]], log: logging.Logger):
- super().__init__()
- self.dag_ids = dag_ids
- self._log = log
-
- @provide_session
- def manage_slas(self, dag: DAG, session: Session = None) -> None:
- """
- Finding all tasks that have SLAs defined, and sending alert emails
- where needed. New SLA misses are also recorded in the database.
-
- We are assuming that the scheduler runs often, so we only check for
- tasks that should have succeeded in the past hour.
- """
- self.log.info("Running SLA Checks for %s", dag.dag_id)
- if not any(isinstance(ti.sla, timedelta) for ti in dag.tasks):
- self.log.info("Skipping SLA check for %s because no tasks in DAG have SLAs", dag)
- return
-
- qry = (
- session.query(TI.task_id, func.max(TI.execution_date).label('max_ti'))
- .with_hint(TI, 'USE INDEX (PRIMARY)', dialect_name='mysql')
- .filter(TI.dag_id == dag.dag_id)
- .filter(or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED))
- .filter(TI.task_id.in_(dag.task_ids))
- .group_by(TI.task_id)
- .subquery('sq')
- )
-
- max_tis: List[TI] = (
- session.query(TI)
- .filter(
- TI.dag_id == dag.dag_id,
- TI.task_id == qry.c.task_id,
- TI.execution_date == qry.c.max_ti,
- )
- .all()
- )
-
- ts = timezone.utcnow()
- for ti in max_tis:
- task = dag.get_task(ti.task_id)
- if task.sla and not isinstance(task.sla, timedelta):
- raise TypeError(
- f"SLA is expected to be timedelta object, got "
- f"{type(task.sla)} in {task.dag_id}:{task.task_id}"
- )
-
- dttm = dag.following_schedule(ti.execution_date)
- while dttm < timezone.utcnow():
- following_schedule = dag.following_schedule(dttm)
- if following_schedule + task.sla < timezone.utcnow():
- session.merge(
- SlaMiss(task_id=ti.task_id, dag_id=ti.dag_id, execution_date=dttm, timestamp=ts)
- )
- dttm = dag.following_schedule(dttm)
- session.commit()
-
- # pylint: disable=singleton-comparison
- slas: List[SlaMiss] = (
- session.query(SlaMiss)
- .filter(SlaMiss.notification_sent == False, SlaMiss.dag_id == dag.dag_id) # noqa
- .all()
- )
- # pylint: enable=singleton-comparison
-
- if slas: # pylint: disable=too-many-nested-blocks
- sla_dates: List[datetime.datetime] = [sla.execution_date for sla in slas]
- fetched_tis: List[TI] = (
- session.query(TI)
- .filter(TI.state != State.SUCCESS, TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id)
- .all()
- )
- blocking_tis: List[TI] = []
- for ti in fetched_tis:
- if ti.task_id in dag.task_ids:
- ti.task = dag.get_task(ti.task_id)
- blocking_tis.append(ti)
- else:
- session.delete(ti)
- session.commit()
-
- task_list = "\n".join(sla.task_id + ' on ' + sla.execution_date.isoformat() for sla in slas)
- blocking_task_list = "\n".join(
- ti.task_id + ' on ' + ti.execution_date.isoformat() for ti in blocking_tis
- )
- # Track whether email or any alert notification sent
- # We consider email or the alert callback as notifications
- email_sent = False
- notification_sent = False
- if dag.sla_miss_callback:
- # Execute the alert callback
- self.log.info('Calling SLA miss callback')
- try:
- dag.sla_miss_callback(dag, task_list, blocking_task_list, slas, blocking_tis)
- notification_sent = True
- except Exception: # pylint: disable=broad-except
- self.log.exception("Could not call sla_miss_callback for DAG %s", dag.dag_id)
- email_content = f"""\
- Here's a list of tasks that missed their SLAs:
- <pre><code>{task_list}\n<code></pre>
- Blocking tasks:
- <pre><code>{blocking_task_list}<code></pre>
- Airflow Webserver URL: {conf.get(section='webserver', key='base_url')}
- """
-
- tasks_missed_sla = []
- for sla in slas:
- try:
- task = dag.get_task(sla.task_id)
- except TaskNotFound:
- # task already deleted from DAG, skip it
- self.log.warning(
- "Task %s doesn't exist in DAG anymore, skipping SLA miss notification.", sla.task_id
- )
- continue
- tasks_missed_sla.append(task)
-
- emails: Set[str] = set()
- for task in tasks_missed_sla:
- if task.email:
- if isinstance(task.email, str):
- emails |= set(get_email_address_list(task.email))
- elif isinstance(task.email, (list, tuple)):
- emails |= set(task.email)
- if emails:
- try:
- send_email(emails, f"[airflow] SLA miss on DAG={dag.dag_id}", email_content)
- email_sent = True
- notification_sent = True
- except Exception: # pylint: disable=broad-except
- Stats.incr('sla_email_notification_failure')
- self.log.exception("Could not send SLA Miss email notification for DAG %s", dag.dag_id)
- # If we sent any notification, update the sla_miss table
- if notification_sent:
- for sla in slas:
- sla.email_sent = email_sent
- sla.notification_sent = True
- session.merge(sla)
- session.commit()
-
- @staticmethod
- def update_import_errors(session: Session, dagbag: DagBag) -> None:
- """
- For the DAGs in the given DagBag, record any associated import errors and clears
- errors for files that no longer have them. These are usually displayed through the
- Airflow UI so that users know that there are issues parsing DAGs.
-
- :param session: session for ORM operations
- :type session: sqlalchemy.orm.session.Session
- :param dagbag: DagBag containing DAGs with import errors
- :type dagbag: airflow.DagBag
- """
- # Clear the errors of the processed files
- for dagbag_file in dagbag.file_last_changed:
- session.query(errors.ImportError).filter(errors.ImportError.filename == dagbag_file).delete()
-
- # Add the errors of the processed files
- for filename, stacktrace in dagbag.import_errors.items():
- session.add(
- errors.ImportError(filename=filename, timestamp=timezone.utcnow(), stacktrace=stacktrace)
- )
- session.commit()
-
- @provide_session
- def execute_callbacks(
- self, dagbag: DagBag, callback_requests: List[CallbackRequest], session: Session = None
- ) -> None:
- """
- Execute on failure callbacks. These objects can come from SchedulerJob or from
- DagFileProcessorManager.
-
- :param dagbag: Dag Bag of dags
- :param callback_requests: failure callbacks to execute
- :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest]
- :param session: DB session.
- """
- for request in callback_requests:
- self.log.debug("Processing Callback Request: %s", request)
- try:
- if isinstance(request, TaskCallbackRequest):
- self._execute_task_callbacks(dagbag, request)
- elif isinstance(request, SlaCallbackRequest):
- self.manage_slas(dagbag.dags.get(request.dag_id))
- elif isinstance(request, DagCallbackRequest):
- self._execute_dag_callbacks(dagbag, request, session)
- except Exception: # pylint: disable=broad-except
- self.log.exception(
- "Error executing %s callback for file: %s",
- request.__class__.__name__,
- request.full_filepath,
- )
-
- session.commit()
-
- @provide_session
- def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, session: Session):
- dag = dagbag.dags[request.dag_id]
- dag_run = dag.get_dagrun(execution_date=request.execution_date, session=session)
- dag.handle_callback(
- dagrun=dag_run, success=not request.is_failure_callback, reason=request.msg, session=session
- )
-
- def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest):
- simple_ti = request.simple_task_instance
- if simple_ti.dag_id in dagbag.dags:
- dag = dagbag.dags[simple_ti.dag_id]
- if simple_ti.task_id in dag.task_ids:
- task = dag.get_task(simple_ti.task_id)
- ti = TI(task, simple_ti.execution_date)
- # Get properties needed for failure handling from SimpleTaskInstance.
- ti.start_date = simple_ti.start_date
- ti.end_date = simple_ti.end_date
- ti.try_number = simple_ti.try_number
- ti.state = simple_ti.state
- ti.test_mode = self.UNIT_TEST_MODE
- if request.is_failure_callback:
- ti.handle_failure_with_callback(error=request.msg, test_mode=ti.test_mode)
- self.log.info('Executed failure callback for %s in state %s', ti, ti.state)
-
- @provide_session
- def process_file(
- self,
- file_path: str,
- callback_requests: List[CallbackRequest],
- pickle_dags: bool = False,
- session: Session = None,
- ) -> Tuple[int, int]:
- """
- Process a Python file containing Airflow DAGs.
-
- This includes:
-
- 1. Execute the file and look for DAG objects in the namespace.
- 2. Execute any Callbacks if passed to this method.
- 3. Serialize the DAGs and save it to DB (or update existing record in the DB).
- 4. Pickle the DAG and save it to the DB (if necessary).
- 5. Record any errors importing the file into ORM
-
- :param file_path: the path to the Python file that should be executed
- :type file_path: str
- :param callback_requests: failure callback to execute
- :type callback_requests: List[airflow.utils.dag_processing.CallbackRequest]
- :param pickle_dags: whether serialize the DAGs found in the file and
- save them to the db
- :type pickle_dags: bool
- :param session: Sqlalchemy ORM Session
- :type session: Session
- :return: number of dags found, count of import errors
- :rtype: Tuple[int, int]
- """
- self.log.info("Processing file %s for tasks to queue", file_path)
-
- try:
- dagbag = DagBag(file_path, include_examples=False, include_smart_sensor=False)
- except Exception: # pylint: disable=broad-except
- self.log.exception("Failed at reloading the DAG file %s", file_path)
- Stats.incr('dag_file_refresh_error', 1, 1)
- return 0, 0
-
- if len(dagbag.dags) > 0:
- self.log.info("DAG(s) %s retrieved from %s", dagbag.dags.keys(), file_path)
- else:
- self.log.warning("No viable dags retrieved from %s", file_path)
- self.update_import_errors(session, dagbag)
- return 0, len(dagbag.import_errors)
-
- self.execute_callbacks(dagbag, callback_requests)
-
- # Save individual DAGs in the ORM
- dagbag.sync_to_db()
-
- if pickle_dags:
- paused_dag_ids = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids)
-
- unpaused_dags: List[DAG] = [
- dag for dag_id, dag in dagbag.dags.items() if dag_id not in paused_dag_ids
- ]
-
- for dag in unpaused_dags:
- dag.pickle(session)
-
- # Record import errors into the ORM
- try:
- self.update_import_errors(session, dagbag)
- except Exception: # pylint: disable=broad-except
- self.log.exception("Error logging import errors!")
-
- return len(dagbag.dags), len(dagbag.import_errors)
-
-
def _is_parent_process():
"""
Returns True if the current process is the parent process. False if the current process is a child
diff --git a/tests/dag_processing/__init__.py b/tests/dag_processing/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/tests/dag_processing/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/utils/test_dag_processing.py b/tests/dag_processing/test_manager.py
similarity index 99%
rename from tests/utils/test_dag_processing.py
rename to tests/dag_processing/test_manager.py
index 58ad010..0ab7f2b 100644
--- a/tests/utils/test_dag_processing.py
+++ b/tests/dag_processing/test_manager.py
@@ -34,20 +34,20 @@ import pytest
from freezegun import freeze_time
from airflow.configuration import conf
-from airflow.jobs.local_task_job import LocalTaskJob as LJ
-from airflow.jobs.scheduler_job import DagFileProcessorProcess
-from airflow.models import DagBag, DagModel, TaskInstance as TI
-from airflow.models.serialized_dag import SerializedDagModel
-from airflow.models.taskinstance import SimpleTaskInstance
-from airflow.utils import timezone
-from airflow.utils.callback_requests import CallbackRequest, TaskCallbackRequest
-from airflow.utils.dag_processing import (
+from airflow.dag_processing.manager import (
DagFileProcessorAgent,
DagFileProcessorManager,
DagFileStat,
DagParsingSignal,
DagParsingStat,
)
+from airflow.dag_processing.processor import DagFileProcessorProcess
+from airflow.jobs.local_task_job import LocalTaskJob as LJ
+from airflow.models import DagBag, DagModel, TaskInstance as TI
+from airflow.models.serialized_dag import SerializedDagModel
+from airflow.models.taskinstance import SimpleTaskInstance
+from airflow.utils import timezone
+from airflow.utils.callback_requests import CallbackRequest, TaskCallbackRequest
from airflow.utils.net import get_hostname
from airflow.utils.session import create_session
from airflow.utils.state import State
diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py
new file mode 100644
index 0000000..5953517
--- /dev/null
+++ b/tests/dag_processing/test_processor.py
@@ -0,0 +1,749 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# pylint: disable=attribute-defined-outside-init
+import datetime
+import os
+import unittest
+from datetime import timedelta
+from tempfile import NamedTemporaryFile
+from unittest import mock
+from unittest.mock import MagicMock, patch
+
+import pytest
+from parameterized import parameterized
+
+from airflow import settings
+from airflow.configuration import conf
+from airflow.dag_processing.processor import DagFileProcessor
+from airflow.jobs.scheduler_job import SchedulerJob
+from airflow.models import DAG, DagBag, DagModel, SlaMiss, TaskInstance
+from airflow.models.dagrun import DagRun
+from airflow.models.serialized_dag import SerializedDagModel
+from airflow.models.taskinstance import SimpleTaskInstance
+from airflow.operators.bash import BashOperator
+from airflow.operators.dummy import DummyOperator
+from airflow.serialization.serialized_objects import SerializedDAG
+from airflow.utils import timezone
+from airflow.utils.callback_requests import TaskCallbackRequest
+from airflow.utils.dates import days_ago
+from airflow.utils.session import create_session
+from airflow.utils.state import State
+from airflow.utils.types import DagRunType
+from tests.test_utils.config import conf_vars, env_vars
+from tests.test_utils.db import (
+ clear_db_dags,
+ clear_db_import_errors,
+ clear_db_jobs,
+ clear_db_pools,
+ clear_db_runs,
+ clear_db_serialized_dags,
+ clear_db_sla_miss,
+)
+from tests.test_utils.mock_executor import MockExecutor
+
+DEFAULT_DATE = timezone.datetime(2016, 1, 1)
+
+
+@pytest.fixture(scope="class")
+def disable_load_example():
+ with conf_vars({('core', 'load_examples'): 'false'}):
+ with env_vars({('core', 'load_examples'): 'false'}):
+ yield
+
+
+@pytest.mark.usefixtures("disable_load_example")
+class TestDagFileProcessor(unittest.TestCase):
+ @staticmethod
+ def clean_db():
+ clear_db_runs()
+ clear_db_pools()
+ clear_db_dags()
+ clear_db_sla_miss()
+ clear_db_import_errors()
+ clear_db_jobs()
+ clear_db_serialized_dags()
+
+ def setUp(self):
+ self.clean_db()
+
+ # Speed up some tests by not running the tasks, just look at what we
+ # enqueue!
+ self.null_exec = MockExecutor()
+ self.scheduler_job = None
+
+ def tearDown(self) -> None:
+ if self.scheduler_job and self.scheduler_job.processor_agent:
+ self.scheduler_job.processor_agent.end()
+ self.scheduler_job = None
+ self.clean_db()
+
+ def create_test_dag(self, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timedelta(hours=1), **kwargs):
+ dag = DAG(
+ dag_id='test_scheduler_reschedule',
+ start_date=start_date,
+ # Make sure it only creates a single DAG Run
+ end_date=end_date,
+ )
+ dag.clear()
+ dag.is_subdag = False
+ with create_session() as session:
+ orm_dag = DagModel(dag_id=dag.dag_id, is_paused=False)
+ session.merge(orm_dag)
+ session.commit()
+ return dag
+
+ @classmethod
+ def setUpClass(cls):
+ # Ensure the DAGs we are looking at from the DB are up-to-date
+ non_serialized_dagbag = DagBag(read_dags_from_db=False, include_examples=False)
+ non_serialized_dagbag.sync_to_db()
+ cls.dagbag = DagBag(read_dags_from_db=True)
+
+ def test_dag_file_processor_sla_miss_callback(self):
+ """
+ Test that the dag file processor calls the sla miss callback
+ """
+ session = settings.Session()
+
+ sla_callback = MagicMock()
+
+ # Create dag with a start of 1 day ago, but an sla of 0
+ # so we'll already have an sla_miss on the books.
+ test_start_date = days_ago(1)
+ dag = DAG(
+ dag_id='test_sla_miss',
+ sla_miss_callback=sla_callback,
+ default_args={'start_date': test_start_date, 'sla': datetime.timedelta()},
+ )
+
+ task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
+
+ session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))
+
+ session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
+
+ dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ dag_file_processor.manage_slas(dag=dag, session=session)
+
+ assert sla_callback.called
+
+ def test_dag_file_processor_sla_miss_callback_invalid_sla(self):
+ """
+ Test that the dag file processor does not call the sla miss callback when
+ given an invalid sla
+ """
+ session = settings.Session()
+
+ sla_callback = MagicMock()
+
+ # Create dag with a start of 1 day ago, but an sla of 0
+ # so we'll already have an sla_miss on the books.
+ # Pass anything besides a timedelta object to the sla argument.
+ test_start_date = days_ago(1)
+ dag = DAG(
+ dag_id='test_sla_miss',
+ sla_miss_callback=sla_callback,
+ default_args={'start_date': test_start_date, 'sla': None},
+ )
+
+ task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
+
+ session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))
+
+ session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
+
+ dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ dag_file_processor.manage_slas(dag=dag, session=session)
+ sla_callback.assert_not_called()
+
+ def test_dag_file_processor_sla_miss_callback_sent_notification(self):
+ """
+ Test that the dag file processor does not call the sla_miss_callback when a
+ notification has already been sent
+ """
+ session = settings.Session()
+
+ # Mock the callback function so we can verify that it was not called
+ sla_callback = MagicMock()
+
+ # Create a dag with a start date of 2 days ago and an sla of 1 day,
+ # so we'll already have an sla_miss on the books
+ test_start_date = days_ago(2)
+ dag = DAG(
+ dag_id='test_sla_miss',
+ sla_miss_callback=sla_callback,
+ default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
+ )
+
+ task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
+
+ # Create a TaskInstance for two days ago
+ session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))
+
+ # Create an SlaMiss where notification was sent, but email was not
+ session.merge(
+ SlaMiss(
+ task_id='dummy',
+ dag_id='test_sla_miss',
+ execution_date=test_start_date,
+ email_sent=False,
+ notification_sent=True,
+ )
+ )
+
+ # Now call manage_slas and see if the sla_miss callback gets called
+ dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ dag_file_processor.manage_slas(dag=dag, session=session)
+
+ sla_callback.assert_not_called()
+
+ def test_dag_file_processor_sla_miss_callback_exception(self):
+ """
+ Test that the dag file processor gracefully logs an exception if there is a problem
+ calling the sla_miss_callback
+ """
+ session = settings.Session()
+
+ sla_callback = MagicMock(side_effect=RuntimeError('Could not call function'))
+
+ test_start_date = days_ago(2)
+ dag = DAG(
+ dag_id='test_sla_miss',
+ sla_miss_callback=sla_callback,
+ default_args={'start_date': test_start_date},
+ )
+
+ task = DummyOperator(task_id='dummy', dag=dag, owner='airflow', sla=datetime.timedelta(hours=1))
+
+ session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
+
+ # Create an SlaMiss for the task so the callback has something to act on
+ session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
+
+ # Now call manage_slas and see if the sla_miss callback gets called
+ mock_log = mock.MagicMock()
+ dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log)
+ dag_file_processor.manage_slas(dag=dag, session=session)
+ assert sla_callback.called
+ mock_log.exception.assert_called_once_with(
+ 'Could not call sla_miss_callback for DAG %s', 'test_sla_miss'
+ )
+
+ @mock.patch('airflow.dag_processing.processor.send_email')
+ def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks(self, mock_send_email):
+ session = settings.Session()
+
+ test_start_date = days_ago(2)
+ dag = DAG(
+ dag_id='test_sla_miss',
+ default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
+ )
+
+ email1 = 'test1@test.com'
+ task = DummyOperator(
+ task_id='sla_missed', dag=dag, owner='airflow', email=email1, sla=datetime.timedelta(hours=1)
+ )
+
+ session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
+
+ email2 = 'test2@test.com'
+ DummyOperator(task_id='sla_not_missed', dag=dag, owner='airflow', email=email2)
+
+ session.merge(SlaMiss(task_id='sla_missed', dag_id='test_sla_miss', execution_date=test_start_date))
+
+ dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+
+ dag_file_processor.manage_slas(dag=dag, session=session)
+
+ assert len(mock_send_email.call_args_list) == 1
+
+ send_email_to = mock_send_email.call_args_list[0][0][0]
+ assert email1 in send_email_to
+ assert email2 not in send_email_to
+
+ @mock.patch('airflow.jobs.scheduler_job.Stats.incr')
+ @mock.patch("airflow.utils.email.send_email")
+ def test_dag_file_processor_sla_miss_email_exception(self, mock_send_email, mock_stats_incr):
+ """
+ Test that the dag file processor gracefully logs an exception if there is a problem
+ sending an email
+ """
+ session = settings.Session()
+
+ # Mock the callback function so we can verify that it was not called
+ mock_send_email.side_effect = RuntimeError('Could not send an email')
+
+ test_start_date = days_ago(2)
+ dag = DAG(
+ dag_id='test_sla_miss',
+ default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
+ )
+
+ task = DummyOperator(
+ task_id='dummy', dag=dag, owner='airflow', email='test@test.com', sla=datetime.timedelta(hours=1)
+ )
+
+ session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
+
+ # Create an SlaMiss record so that an email notification will be attempted
+ session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
+
+ mock_log = mock.MagicMock()
+ dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log)
+
+ dag_file_processor.manage_slas(dag=dag, session=session)
+ mock_log.exception.assert_called_once_with(
+ 'Could not send SLA Miss email notification for DAG %s', 'test_sla_miss'
+ )
+ mock_stats_incr.assert_called_once_with('sla_email_notification_failure')
+
+ def test_dag_file_processor_sla_miss_deleted_task(self):
+ """
+ Test that the dag file processor will not crash when trying to send
+ sla miss notification for a deleted task
+ """
+ session = settings.Session()
+
+ test_start_date = days_ago(2)
+ dag = DAG(
+ dag_id='test_sla_miss',
+ default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
+ )
+
+ task = DummyOperator(
+ task_id='dummy', dag=dag, owner='airflow', email='test@test.com', sla=datetime.timedelta(hours=1)
+ )
+
+ session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
+
+ # Create an SlaMiss for a task_id that no longer exists in the DAG
+ session.merge(
+ SlaMiss(task_id='dummy_deleted', dag_id='test_sla_miss', execution_date=test_start_date)
+ )
+
+ mock_log = mock.MagicMock()
+ dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log)
+ dag_file_processor.manage_slas(dag=dag, session=session)
+
+ @parameterized.expand(
+ [
+ [State.NONE, None, None],
+ [
+ State.UP_FOR_RETRY,
+ timezone.utcnow() - datetime.timedelta(minutes=30),
+ timezone.utcnow() - datetime.timedelta(minutes=15),
+ ],
+ [
+ State.UP_FOR_RESCHEDULE,
+ timezone.utcnow() - datetime.timedelta(minutes=30),
+ timezone.utcnow() - datetime.timedelta(minutes=15),
+ ],
+ ]
+ )
+ def test_dag_file_processor_process_task_instances(self, state, start_date, end_date):
+ """
+ Test that _schedule_dag_run schedules the expected task instances and
+ moves them to the SCHEDULED state.
+ """
+ dag = DAG(dag_id='test_scheduler_process_execute_task', start_date=DEFAULT_DATE)
+ BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo hi')
+
+ with create_session() as session:
+ orm_dag = DagModel(dag_id=dag.dag_id)
+ session.merge(orm_dag)
+
+ dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+ self.scheduler_job = SchedulerJob(subdir=os.devnull)
+ self.scheduler_job.processor_agent = mock.MagicMock()
+ self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
+ dag.clear()
+ dr = dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+ assert dr is not None
+
+ with create_session() as session:
+ ti = dr.get_task_instances(session=session)[0]
+ ti.state = state
+ ti.start_date = start_date
+ ti.end_date = end_date
+
+ count = self.scheduler_job._schedule_dag_run(dr, set(), session)
+ assert count == 1
+
+ session.refresh(ti)
+ assert ti.state == State.SCHEDULED
+
+ @parameterized.expand(
+ [
+ [State.NONE, None, None],
+ [
+ State.UP_FOR_RETRY,
+ timezone.utcnow() - datetime.timedelta(minutes=30),
+ timezone.utcnow() - datetime.timedelta(minutes=15),
+ ],
+ [
+ State.UP_FOR_RESCHEDULE,
+ timezone.utcnow() - datetime.timedelta(minutes=30),
+ timezone.utcnow() - datetime.timedelta(minutes=15),
+ ],
+ ]
+ )
+ def test_dag_file_processor_process_task_instances_with_task_concurrency(
+ self,
+ state,
+ start_date,
+ end_date,
+ ):
+ """
+ Test that _schedule_dag_run schedules the expected task instances when
+ the task has a task_concurrency limit.
+ """
+ dag = DAG(dag_id='test_scheduler_process_execute_task_with_task_concurrency', start_date=DEFAULT_DATE)
+ BashOperator(task_id='dummy', task_concurrency=2, dag=dag, owner='airflow', bash_command='echo Hi')
+
+ with create_session() as session:
+ orm_dag = DagModel(dag_id=dag.dag_id)
+ session.merge(orm_dag)
+
+ dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+ self.scheduler_job = SchedulerJob(subdir=os.devnull)
+ self.scheduler_job.processor_agent = mock.MagicMock()
+ self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
+ dag.clear()
+ dr = dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+ assert dr is not None
+
+ with create_session() as session:
+ ti = dr.get_task_instances(session=session)[0]
+ ti.state = state
+ ti.start_date = start_date
+ ti.end_date = end_date
+
+ count = self.scheduler_job._schedule_dag_run(dr, set(), session)
+ assert count == 1
+
+ session.refresh(ti)
+ assert ti.state == State.SCHEDULED
+
+ @parameterized.expand(
+ [
+ [State.NONE, None, None],
+ [
+ State.UP_FOR_RETRY,
+ timezone.utcnow() - datetime.timedelta(minutes=30),
+ timezone.utcnow() - datetime.timedelta(minutes=15),
+ ],
+ [
+ State.UP_FOR_RESCHEDULE,
+ timezone.utcnow() - datetime.timedelta(minutes=30),
+ timezone.utcnow() - datetime.timedelta(minutes=15),
+ ],
+ ]
+ )
+ def test_dag_file_processor_process_task_instances_depends_on_past(self, state, start_date, end_date):
+ """
+ Test that _schedule_dag_run schedules both task instances when
+ depends_on_past is set.
+ """
+ dag = DAG(
+ dag_id='test_scheduler_process_execute_task_depends_on_past',
+ start_date=DEFAULT_DATE,
+ default_args={
+ 'depends_on_past': True,
+ },
+ )
+ BashOperator(task_id='dummy1', dag=dag, owner='airflow', bash_command='echo hi')
+ BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo hi')
+
+ with create_session() as session:
+ orm_dag = DagModel(dag_id=dag.dag_id)
+ session.merge(orm_dag)
+
+ dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+ self.scheduler_job = SchedulerJob(subdir=os.devnull)
+ self.scheduler_job.processor_agent = mock.MagicMock()
+ self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
+ dag.clear()
+ dr = dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+ assert dr is not None
+
+ with create_session() as session:
+ tis = dr.get_task_instances(session=session)
+ for ti in tis:
+ ti.state = state
+ ti.start_date = start_date
+ ti.end_date = end_date
+
+ count = self.scheduler_job._schedule_dag_run(dr, set(), session)
+ assert count == 2
+
+ session.refresh(tis[0])
+ session.refresh(tis[1])
+ assert tis[0].state == State.SCHEDULED
+ assert tis[1].state == State.SCHEDULED
+
+ def test_scheduler_job_add_new_task(self):
+ """
+ Test if a task instance will be added if the dag is updated
+ """
+ dag = DAG(dag_id='test_scheduler_add_new_task', start_date=DEFAULT_DATE)
+ BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo test')
+
+ self.scheduler_job = SchedulerJob(subdir=os.devnull)
+ self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
+
+ # Since we don't want to store the code for the DAG defined in this file
+ with mock.patch.object(settings, "STORE_DAG_CODE", False):
+ self.scheduler_job.dagbag.sync_to_db()
+
+ session = settings.Session()
+ orm_dag = session.query(DagModel).get(dag.dag_id)
+ assert orm_dag is not None
+
+ if self.scheduler_job.processor_agent:
+ self.scheduler_job.processor_agent.end()
+ self.scheduler_job = SchedulerJob(subdir=os.devnull)
+ self.scheduler_job.processor_agent = mock.MagicMock()
+ dag = self.scheduler_job.dagbag.get_dag('test_scheduler_add_new_task', session=session)
+ self.scheduler_job._create_dag_runs([orm_dag], session)
+
+ drs = DagRun.find(dag_id=dag.dag_id, session=session)
+ assert len(drs) == 1
+ dr = drs[0]
+
+ tis = dr.get_task_instances()
+ assert len(tis) == 1
+
+ BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo test')
+ SerializedDagModel.write_dag(dag=dag)
+
+ scheduled_tis = self.scheduler_job._schedule_dag_run(dr, set(), session)
+ session.flush()
+ assert scheduled_tis == 2
+
+ drs = DagRun.find(dag_id=dag.dag_id, session=session)
+ assert len(drs) == 1
+ dr = drs[0]
+
+ tis = dr.get_task_instances()
+ assert len(tis) == 2
+
+ def test_runs_respected_after_clear(self):
+ """
+ Test that _schedule_dag_run only schedules task instances up to max_active_runs
+ (related to issue AIRFLOW-137)
+ """
+ dag = DAG(dag_id='test_scheduler_max_active_runs_respected_after_clear', start_date=DEFAULT_DATE)
+ dag.max_active_runs = 3
+
+ BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo Hi')
+
+ session = settings.Session()
+ orm_dag = DagModel(dag_id=dag.dag_id)
+ session.merge(orm_dag)
+ session.commit()
+ session.close()
+ dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+ self.scheduler_job = SchedulerJob(subdir=os.devnull)
+ self.scheduler_job.processor_agent = mock.MagicMock()
+ self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
+ dag.clear()
+
+ date = DEFAULT_DATE
+ dr1 = dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ execution_date=date,
+ state=State.RUNNING,
+ )
+ date = dag.following_schedule(date)
+ dr2 = dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ execution_date=date,
+ state=State.RUNNING,
+ )
+ date = dag.following_schedule(date)
+ dr3 = dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ execution_date=date,
+ state=State.RUNNING,
+ )
+
+ # First create up to 3 dagruns in RUNNING state.
+ assert dr1 is not None
+ assert dr2 is not None
+ assert dr3 is not None
+ assert len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING, session=session)) == 3
+
+ # Reduce max_active_runs to 1
+ dag.max_active_runs = 1
+
+ # and schedule them, so we can check how many task instances
+ # get scheduled (should be one, not 3)
+ with create_session() as session:
+ num_scheduled = self.scheduler_job._schedule_dag_run(dr1, set(), session)
+ assert num_scheduled == 1
+ num_scheduled = self.scheduler_job._schedule_dag_run(dr2, {dr1.execution_date}, session)
+ assert num_scheduled == 0
+ num_scheduled = self.scheduler_job._schedule_dag_run(dr3, {dr1.execution_date}, session)
+ assert num_scheduled == 0
+
+ @patch.object(TaskInstance, 'handle_failure_with_callback')
+ def test_execute_on_failure_callbacks(self, mock_ti_handle_failure):
+ dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False)
+ dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ with create_session() as session:
+ session.query(TaskInstance).delete()
+ dag = dagbag.get_dag('example_branch_operator')
+ task = dag.get_task(task_id='run_this_first')
+
+ ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING)
+
+ session.add(ti)
+ session.commit()
+
+ requests = [
+ TaskCallbackRequest(
+ full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message"
+ )
+ ]
+ dag_file_processor.execute_callbacks(dagbag, requests)
+ mock_ti_handle_failure.assert_called_once_with(
+ error="Message",
+ test_mode=conf.getboolean('core', 'unit_test_mode'),
+ )
+
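
# Illustrative sketch (not part of this commit) of the callback flow the test
# above exercises, using the import locations introduced by this change. It
# assumes an initialized Airflow metadata database; the file path "A" is the
# same placeholder the test uses.
from unittest import mock

from airflow.dag_processing.processor import DagFileProcessor
from airflow.models import DagBag, TaskInstance
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.utils import timezone
from airflow.utils.callback_requests import TaskCallbackRequest
from airflow.utils.state import State

dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False)
task = dagbag.get_dag("example_branch_operator").get_task("run_this_first")
ti = TaskInstance(task, timezone.datetime(2016, 1, 1), State.RUNNING)

request = TaskCallbackRequest(
    full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message"
)
processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
# Each request ends up in ti.handle_failure_with_callback(error="Message", ...).
processor.execute_callbacks(dagbag, [request])
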
+ def test_process_file_should_failure_callback(self):
+ dag_file = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)), '../dags/test_on_failure_callback.py'
+ )
+ dagbag = DagBag(dag_folder=dag_file, include_examples=False)
+ dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ with create_session() as session, NamedTemporaryFile(delete=False) as callback_file:
+ session.query(TaskInstance).delete()
+ dag = dagbag.get_dag('test_om_failure_callback_dag')
+ task = dag.get_task(task_id='test_om_failure_callback_task')
+
+ ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING)
+
+ session.add(ti)
+ session.commit()
+
+ requests = [
+ TaskCallbackRequest(
+ full_filepath=dag.full_filepath,
+ simple_task_instance=SimpleTaskInstance(ti),
+ msg="Message",
+ )
+ ]
+ callback_file.close()
+
+ with mock.patch.dict("os.environ", {"AIRFLOW_CALLBACK_FILE": callback_file.name}):
+ dag_file_processor.process_file(dag_file, requests)
+ with open(callback_file.name) as callback_file2:
+ content = callback_file2.read()
+ assert "Callback fired" == content
+ os.remove(callback_file.name)
+
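
# The fixture DAG in tests/dags/test_on_failure_callback.py is not shown in
# this diff. Below is a hypothetical sketch of the pattern the test above
# relies on: a task whose on_failure_callback writes "Callback fired" to the
# path named by the AIRFLOW_CALLBACK_FILE environment variable, which the
# assertion then reads back. The operator choice and callable names here are
# assumptions for illustration only.
import os

from airflow.models import DAG
from airflow.operators.python import PythonOperator
from airflow.utils import timezone


def write_callback_marker(context):
    # Record that the failure callback ran so the test can assert on it.
    with open(os.environ["AIRFLOW_CALLBACK_FILE"], "w") as marker:
        marker.write("Callback fired")


def always_fail():
    raise RuntimeError("boom")


with DAG(dag_id="test_om_failure_callback_dag", start_date=timezone.datetime(2016, 1, 1)) as dag:
    PythonOperator(
        task_id="test_om_failure_callback_task",
        python_callable=always_fail,
        on_failure_callback=write_callback_marker,
    )
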
+ def test_should_mark_dummy_task_as_success(self):
+ dag_file = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)), '../dags/test_only_dummy_tasks.py'
+ )
+
+ # Write DAGs to dag and serialized_dag table
+ dagbag = DagBag(dag_folder=dag_file, include_examples=False, read_dags_from_db=False)
+ dagbag.sync_to_db()
+
+ self.scheduler_job_job = SchedulerJob(subdir=os.devnull)
+ self.scheduler_job_job.processor_agent = mock.MagicMock()
+ dag = self.scheduler_job_job.dagbag.get_dag("test_only_dummy_tasks")
+
+ # Create DagRun
+ session = settings.Session()
+ orm_dag = session.query(DagModel).get(dag.dag_id)
+ self.scheduler_job_job._create_dag_runs([orm_dag], session)
+
+ drs = DagRun.find(dag_id=dag.dag_id, session=session)
+ assert len(drs) == 1
+ dr = drs[0]
+
+ # Schedule TaskInstances
+ self.scheduler_job_job._schedule_dag_run(dr, {}, session)
+ with create_session() as session:
+ tis = session.query(TaskInstance).all()
+
+ dags = self.scheduler_job_job.dagbag.dags.values()
+ assert ['test_only_dummy_tasks'] == [dag.dag_id for dag in dags]
+ assert 5 == len(tis)
+ assert {
+ ('test_task_a', 'success'),
+ ('test_task_b', None),
+ ('test_task_c', 'success'),
+ ('test_task_on_execute', 'scheduled'),
+ ('test_task_on_success', 'scheduled'),
+ } == {(ti.task_id, ti.state) for ti in tis}
+ for state, start_date, end_date, duration in [
+ (ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis
+ ]:
+ if state == 'success':
+ assert start_date is not None
+ assert end_date is not None
+ assert 0.0 == duration
+ else:
+ assert start_date is None
+ assert end_date is None
+ assert duration is None
+
+ self.scheduler_job_job._schedule_dag_run(dr, {}, session)
+ with create_session() as session:
+ tis = session.query(TaskInstance).all()
+
+ assert 5 == len(tis)
+ assert {
+ ('test_task_a', 'success'),
+ ('test_task_b', 'success'),
+ ('test_task_c', 'success'),
+ ('test_task_on_execute', 'scheduled'),
+ ('test_task_on_success', 'scheduled'),
+ } == {(ti.task_id, ti.state) for ti in tis}
+ for state, start_date, end_date, duration in [
+ (ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis
+ ]:
+ if state == 'success':
+ assert start_date is not None
+ assert end_date is not None
+ assert 0.0 == duration
+ else:
+ assert start_date is None
+ assert end_date is None
+ assert duration is None
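
With this move, code that imported DagFileProcessor from airflow.jobs.scheduler_job or DagFileProcessorAgent from airflow.utils.dag_processing has to switch to the new packages, as the hunks below show for the remaining test modules. A small compatibility sketch (illustrative only, not part of this commit) that tolerates both layouts:

try:
    # Locations after this change
    from airflow.dag_processing.manager import DagFileProcessorAgent
    from airflow.dag_processing.processor import DagFileProcessor
except ImportError:
    # Pre-move locations
    from airflow.jobs.scheduler_job import DagFileProcessor
    from airflow.utils.dag_processing import DagFileProcessorAgent
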
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index fe0b257..9fe8517 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -22,7 +22,7 @@ import os
import shutil
import unittest
from datetime import timedelta
-from tempfile import NamedTemporaryFile, mkdtemp
+from tempfile import mkdtemp
from time import sleep
from unittest import mock
from unittest.mock import MagicMock, patch
@@ -37,22 +37,20 @@ from sqlalchemy import func
import airflow.example_dags
import airflow.smart_sensor_dags
from airflow import settings
-from airflow.configuration import conf
+from airflow.dag_processing.manager import DagFileProcessorAgent
from airflow.exceptions import AirflowException
from airflow.executors.base_executor import BaseExecutor
from airflow.jobs.backfill_job import BackfillJob
-from airflow.jobs.scheduler_job import DagFileProcessor, SchedulerJob
-from airflow.models import DAG, DagBag, DagModel, Pool, SlaMiss, TaskInstance, errors
+from airflow.jobs.scheduler_job import SchedulerJob
+from airflow.models import DAG, DagBag, DagModel, Pool, TaskInstance, errors
from airflow.models.dagrun import DagRun
from airflow.models.serialized_dag import SerializedDagModel
-from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKey
+from airflow.models.taskinstance import TaskInstanceKey
from airflow.operators.bash import BashOperator
from airflow.operators.dummy import DummyOperator
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.utils import timezone
-from airflow.utils.callback_requests import DagCallbackRequest, TaskCallbackRequest
-from airflow.utils.dag_processing import DagFileProcessorAgent
-from airflow.utils.dates import days_ago
+from airflow.utils.callback_requests import DagCallbackRequest
from airflow.utils.file import list_py_file_paths
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import State
@@ -101,688 +99,6 @@ def disable_load_example():
@pytest.mark.usefixtures("disable_load_example")
-class TestDagFileProcessor(unittest.TestCase):
- @staticmethod
- def clean_db():
- clear_db_runs()
- clear_db_pools()
- clear_db_dags()
- clear_db_sla_miss()
- clear_db_import_errors()
- clear_db_jobs()
- clear_db_serialized_dags()
-
- def setUp(self):
- self.clean_db()
-
- # Speed up some tests by not running the tasks, just look at what we
- # enqueue!
- self.null_exec = MockExecutor()
- self.scheduler_job = None
-
- def tearDown(self) -> None:
- if self.scheduler_job and self.scheduler_job.processor_agent:
- self.scheduler_job.processor_agent.end()
- self.scheduler_job = None
- self.clean_db()
-
- def create_test_dag(self, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timedelta(hours=1), **kwargs):
- dag = DAG(
- dag_id='test_scheduler_reschedule',
- start_date=start_date,
- # Make sure it only creates a single DAG Run
- end_date=end_date,
- )
- dag.clear()
- dag.is_subdag = False
- with create_session() as session:
- orm_dag = DagModel(dag_id=dag.dag_id, is_paused=False)
- session.merge(orm_dag)
- session.commit()
- return dag
-
- @classmethod
- def setUpClass(cls):
- # Ensure the DAGs we are looking at from the DB are up-to-date
- non_serialized_dagbag = DagBag(read_dags_from_db=False, include_examples=False)
- non_serialized_dagbag.sync_to_db()
- cls.dagbag = DagBag(read_dags_from_db=True)
-
- def test_dag_file_processor_sla_miss_callback(self):
- """
- Test that the dag file processor calls the sla miss callback
- """
- session = settings.Session()
-
- sla_callback = MagicMock()
-
- # Create dag with a start of 1 day ago, but an sla of 0
- # so we'll already have an sla_miss on the books.
- test_start_date = days_ago(1)
- dag = DAG(
- dag_id='test_sla_miss',
- sla_miss_callback=sla_callback,
- default_args={'start_date': test_start_date, 'sla': datetime.timedelta()},
- )
-
- task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
-
- session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))
-
- session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
-
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
- dag_file_processor.manage_slas(dag=dag, session=session)
-
- assert sla_callback.called
-
- def test_dag_file_processor_sla_miss_callback_invalid_sla(self):
- """
- Test that the dag file processor does not call the sla miss callback when
- given an invalid sla
- """
- session = settings.Session()
-
- sla_callback = MagicMock()
-
- # Create dag with a start of 1 day ago, but an sla of 0
- # so we'll already have an sla_miss on the books.
- # Pass anything besides a timedelta object to the sla argument.
- test_start_date = days_ago(1)
- dag = DAG(
- dag_id='test_sla_miss',
- sla_miss_callback=sla_callback,
- default_args={'start_date': test_start_date, 'sla': None},
- )
-
- task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
-
- session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))
-
- session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
-
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
- dag_file_processor.manage_slas(dag=dag, session=session)
- sla_callback.assert_not_called()
-
- def test_dag_file_processor_sla_miss_callback_sent_notification(self):
- """
- Test that the dag file processor does not call the sla_miss_callback when a
- notification has already been sent
- """
- session = settings.Session()
-
- # Mock the callback function so we can verify that it was not called
- sla_callback = MagicMock()
-
- # Create dag with a start of 2 days ago, but an sla of 1 day
- # ago so we'll already have an sla_miss on the books
- test_start_date = days_ago(2)
- dag = DAG(
- dag_id='test_sla_miss',
- sla_miss_callback=sla_callback,
- default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
- )
-
- task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
-
- # Create a TaskInstance for two days ago
- session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))
-
- # Create an SlaMiss where notification was sent, but email was not
- session.merge(
- SlaMiss(
- task_id='dummy',
- dag_id='test_sla_miss',
- execution_date=test_start_date,
- email_sent=False,
- notification_sent=True,
- )
- )
-
- # Now call manage_slas and see if the sla_miss callback gets called
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
- dag_file_processor.manage_slas(dag=dag, session=session)
-
- sla_callback.assert_not_called()
-
- def test_dag_file_processor_sla_miss_callback_exception(self):
- """
- Test that the dag file processor gracefully logs an exception if there is a problem
- calling the sla_miss_callback
- """
- session = settings.Session()
-
- sla_callback = MagicMock(side_effect=RuntimeError('Could not call function'))
-
- test_start_date = days_ago(2)
- dag = DAG(
- dag_id='test_sla_miss',
- sla_miss_callback=sla_callback,
- default_args={'start_date': test_start_date},
- )
-
- task = DummyOperator(task_id='dummy', dag=dag, owner='airflow', sla=datetime.timedelta(hours=1))
-
- session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
-
- # Create an SlaMiss where notification was sent, but email was not
- session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
-
- # Now call manage_slas and see if the sla_miss callback gets called
- mock_log = mock.MagicMock()
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log)
- dag_file_processor.manage_slas(dag=dag, session=session)
- assert sla_callback.called
- mock_log.exception.assert_called_once_with(
- 'Could not call sla_miss_callback for DAG %s', 'test_sla_miss'
- )
-
- @mock.patch('airflow.jobs.scheduler_job.send_email')
- def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks(self, mock_send_email):
- session = settings.Session()
-
- test_start_date = days_ago(2)
- dag = DAG(
- dag_id='test_sla_miss',
- default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
- )
-
- email1 = 'test1@test.com'
- task = DummyOperator(
- task_id='sla_missed', dag=dag, owner='airflow', email=email1, sla=datetime.timedelta(hours=1)
- )
-
- session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
-
- email2 = 'test2@test.com'
- DummyOperator(task_id='sla_not_missed', dag=dag, owner='airflow', email=email2)
-
- session.merge(SlaMiss(task_id='sla_missed', dag_id='test_sla_miss', execution_date=test_start_date))
-
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
-
- dag_file_processor.manage_slas(dag=dag, session=session)
-
- assert len(mock_send_email.call_args_list) == 1
-
- send_email_to = mock_send_email.call_args_list[0][0][0]
- assert email1 in send_email_to
- assert email2 not in send_email_to
-
- @mock.patch('airflow.jobs.scheduler_job.Stats.incr')
- @mock.patch("airflow.utils.email.send_email")
- def test_dag_file_processor_sla_miss_email_exception(self, mock_send_email, mock_stats_incr):
- """
- Test that the dag file processor gracefully logs an exception if there is a problem
- sending an email
- """
- session = settings.Session()
-
- # Mock the callback function so we can verify that it was not called
- mock_send_email.side_effect = RuntimeError('Could not send an email')
-
- test_start_date = days_ago(2)
- dag = DAG(
- dag_id='test_sla_miss',
- default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
- )
-
- task = DummyOperator(
- task_id='dummy', dag=dag, owner='airflow', email='test@test.com', sla=datetime.timedelta(hours=1)
- )
-
- session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
-
- # Create an SlaMiss where notification was sent, but email was not
- session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
-
- mock_log = mock.MagicMock()
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log)
-
- dag_file_processor.manage_slas(dag=dag, session=session)
- mock_log.exception.assert_called_once_with(
- 'Could not send SLA Miss email notification for DAG %s', 'test_sla_miss'
- )
- mock_stats_incr.assert_called_once_with('sla_email_notification_failure')
-
- def test_dag_file_processor_sla_miss_deleted_task(self):
- """
- Test that the dag file processor will not crash when trying to send
- sla miss notification for a deleted task
- """
- session = settings.Session()
-
- test_start_date = days_ago(2)
- dag = DAG(
- dag_id='test_sla_miss',
- default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
- )
-
- task = DummyOperator(
- task_id='dummy', dag=dag, owner='airflow', email='test@test.com', sla=datetime.timedelta(hours=1)
- )
-
- session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
-
- # Create an SlaMiss where notification was sent, but email was not
- session.merge(
- SlaMiss(task_id='dummy_deleted', dag_id='test_sla_miss', execution_date=test_start_date)
- )
-
- mock_log = mock.MagicMock()
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log)
- dag_file_processor.manage_slas(dag=dag, session=session)
-
- @parameterized.expand(
- [
- [State.NONE, None, None],
- [
- State.UP_FOR_RETRY,
- timezone.utcnow() - datetime.timedelta(minutes=30),
- timezone.utcnow() - datetime.timedelta(minutes=15),
- ],
- [
- State.UP_FOR_RESCHEDULE,
- timezone.utcnow() - datetime.timedelta(minutes=30),
- timezone.utcnow() - datetime.timedelta(minutes=15),
- ],
- ]
- )
- def test_dag_file_processor_process_task_instances(self, state, start_date, end_date):
- """
- Test if _process_task_instances puts the right task instances into the
- mock_list.
- """
- dag = DAG(dag_id='test_scheduler_process_execute_task', start_date=DEFAULT_DATE)
- BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo hi')
-
- with create_session() as session:
- orm_dag = DagModel(dag_id=dag.dag_id)
- session.merge(orm_dag)
-
- dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
- self.scheduler_job = SchedulerJob(subdir=os.devnull)
- self.scheduler_job.processor_agent = mock.MagicMock()
- self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
- dag.clear()
- dr = dag.create_dagrun(
- run_type=DagRunType.SCHEDULED,
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
- assert dr is not None
-
- with create_session() as session:
- ti = dr.get_task_instances(session=session)[0]
- ti.state = state
- ti.start_date = start_date
- ti.end_date = end_date
-
- count = self.scheduler_job._schedule_dag_run(dr, set(), session)
- assert count == 1
-
- session.refresh(ti)
- assert ti.state == State.SCHEDULED
-
- @parameterized.expand(
- [
- [State.NONE, None, None],
- [
- State.UP_FOR_RETRY,
- timezone.utcnow() - datetime.timedelta(minutes=30),
- timezone.utcnow() - datetime.timedelta(minutes=15),
- ],
- [
- State.UP_FOR_RESCHEDULE,
- timezone.utcnow() - datetime.timedelta(minutes=30),
- timezone.utcnow() - datetime.timedelta(minutes=15),
- ],
- ]
- )
- def test_dag_file_processor_process_task_instances_with_task_concurrency(
- self,
- state,
- start_date,
- end_date,
- ):
- """
- Test if _process_task_instances puts the right task instances into the
- mock_list.
- """
- dag = DAG(dag_id='test_scheduler_process_execute_task_with_task_concurrency', start_date=DEFAULT_DATE)
- BashOperator(task_id='dummy', task_concurrency=2, dag=dag, owner='airflow', bash_command='echo Hi')
-
- with create_session() as session:
- orm_dag = DagModel(dag_id=dag.dag_id)
- session.merge(orm_dag)
-
- dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
- self.scheduler_job = SchedulerJob(subdir=os.devnull)
- self.scheduler_job.processor_agent = mock.MagicMock()
- self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
- dag.clear()
- dr = dag.create_dagrun(
- run_type=DagRunType.SCHEDULED,
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
- assert dr is not None
-
- with create_session() as session:
- ti = dr.get_task_instances(session=session)[0]
- ti.state = state
- ti.start_date = start_date
- ti.end_date = end_date
-
- count = self.scheduler_job._schedule_dag_run(dr, set(), session)
- assert count == 1
-
- session.refresh(ti)
- assert ti.state == State.SCHEDULED
-
- @parameterized.expand(
- [
- [State.NONE, None, None],
- [
- State.UP_FOR_RETRY,
- timezone.utcnow() - datetime.timedelta(minutes=30),
- timezone.utcnow() - datetime.timedelta(minutes=15),
- ],
- [
- State.UP_FOR_RESCHEDULE,
- timezone.utcnow() - datetime.timedelta(minutes=30),
- timezone.utcnow() - datetime.timedelta(minutes=15),
- ],
- ]
- )
- def test_dag_file_processor_process_task_instances_depends_on_past(self, state, start_date, end_date):
- """
- Test if _process_task_instances puts the right task instances into the
- mock_list.
- """
- dag = DAG(
- dag_id='test_scheduler_process_execute_task_depends_on_past',
- start_date=DEFAULT_DATE,
- default_args={
- 'depends_on_past': True,
- },
- )
- BashOperator(task_id='dummy1', dag=dag, owner='airflow', bash_command='echo hi')
- BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo hi')
-
- with create_session() as session:
- orm_dag = DagModel(dag_id=dag.dag_id)
- session.merge(orm_dag)
-
- dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
- self.scheduler_job = SchedulerJob(subdir=os.devnull)
- self.scheduler_job.processor_agent = mock.MagicMock()
- self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
- dag.clear()
- dr = dag.create_dagrun(
- run_type=DagRunType.SCHEDULED,
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
- assert dr is not None
-
- with create_session() as session:
- tis = dr.get_task_instances(session=session)
- for ti in tis:
- ti.state = state
- ti.start_date = start_date
- ti.end_date = end_date
-
- count = self.scheduler_job._schedule_dag_run(dr, set(), session)
- assert count == 2
-
- session.refresh(tis[0])
- session.refresh(tis[1])
- assert tis[0].state == State.SCHEDULED
- assert tis[1].state == State.SCHEDULED
-
- def test_scheduler_job_add_new_task(self):
- """
- Test if a task instance will be added if the dag is updated
- """
- dag = DAG(dag_id='test_scheduler_add_new_task', start_date=DEFAULT_DATE)
- BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo test')
-
- self.scheduler_job = SchedulerJob(subdir=os.devnull)
- self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
-
- # Since we don't want to store the code for the DAG defined in this file
- with mock.patch.object(settings, "STORE_DAG_CODE", False):
- self.scheduler_job.dagbag.sync_to_db()
-
- session = settings.Session()
- orm_dag = session.query(DagModel).get(dag.dag_id)
- assert orm_dag is not None
-
- if self.scheduler_job.processor_agent:
- self.scheduler_job.processor_agent.end()
- self.scheduler_job = SchedulerJob(subdir=os.devnull)
- self.scheduler_job.processor_agent = mock.MagicMock()
- dag = self.scheduler_job.dagbag.get_dag('test_scheduler_add_new_task', session=session)
- self.scheduler_job._create_dag_runs([orm_dag], session)
-
- drs = DagRun.find(dag_id=dag.dag_id, session=session)
- assert len(drs) == 1
- dr = drs[0]
-
- tis = dr.get_task_instances()
- assert len(tis) == 1
-
- BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo test')
- SerializedDagModel.write_dag(dag=dag)
-
- scheduled_tis = self.scheduler_job._schedule_dag_run(dr, set(), session)
- session.flush()
- assert scheduled_tis == 2
-
- drs = DagRun.find(dag_id=dag.dag_id, session=session)
- assert len(drs) == 1
- dr = drs[0]
-
- tis = dr.get_task_instances()
- assert len(tis) == 2
-
- def test_runs_respected_after_clear(self):
- """
- Test if _process_task_instances only schedules ti's up to max_active_runs
- (related to issue AIRFLOW-137)
- """
- dag = DAG(dag_id='test_scheduler_max_active_runs_respected_after_clear', start_date=DEFAULT_DATE)
- dag.max_active_runs = 3
-
- BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo Hi')
-
- session = settings.Session()
- orm_dag = DagModel(dag_id=dag.dag_id)
- session.merge(orm_dag)
- session.commit()
- session.close()
- dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
- self.scheduler_job = SchedulerJob(subdir=os.devnull)
- self.scheduler_job.processor_agent = mock.MagicMock()
- self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
- dag.clear()
-
- date = DEFAULT_DATE
- dr1 = dag.create_dagrun(
- run_type=DagRunType.SCHEDULED,
- execution_date=date,
- state=State.RUNNING,
- )
- date = dag.following_schedule(date)
- dr2 = dag.create_dagrun(
- run_type=DagRunType.SCHEDULED,
- execution_date=date,
- state=State.RUNNING,
- )
- date = dag.following_schedule(date)
- dr3 = dag.create_dagrun(
- run_type=DagRunType.SCHEDULED,
- execution_date=date,
- state=State.RUNNING,
- )
-
- # First create up to 3 dagruns in RUNNING state.
- assert dr1 is not None
- assert dr2 is not None
- assert dr3 is not None
- assert len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING, session=session)) == 3
-
- # Reduce max_active_runs to 1
- dag.max_active_runs = 1
-
- # and schedule them in, so we can check how many
- # tasks are put on the task_instances_list (should be one, not 3)
- with create_session() as session:
- num_scheduled = self.scheduler_job._schedule_dag_run(dr1, set(), session)
- assert num_scheduled == 1
- num_scheduled = self.scheduler_job._schedule_dag_run(dr2, {dr1.execution_date}, session)
- assert num_scheduled == 0
- num_scheduled = self.scheduler_job._schedule_dag_run(dr3, {dr1.execution_date}, session)
- assert num_scheduled == 0
-
- @patch.object(TaskInstance, 'handle_failure_with_callback')
- def test_execute_on_failure_callbacks(self, mock_ti_handle_failure):
- dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False)
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
- with create_session() as session:
- session.query(TaskInstance).delete()
- dag = dagbag.get_dag('example_branch_operator')
- task = dag.get_task(task_id='run_this_first')
-
- ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING)
-
- session.add(ti)
- session.commit()
-
- requests = [
- TaskCallbackRequest(
- full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message"
- )
- ]
- dag_file_processor.execute_callbacks(dagbag, requests)
- mock_ti_handle_failure.assert_called_once_with(
- error="Message",
- test_mode=conf.getboolean('core', 'unit_test_mode'),
- )
-
- def test_process_file_should_failure_callback(self):
- dag_file = os.path.join(
- os.path.dirname(os.path.realpath(__file__)), '../dags/test_on_failure_callback.py'
- )
- dagbag = DagBag(dag_folder=dag_file, include_examples=False)
- dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
- with create_session() as session, NamedTemporaryFile(delete=False) as callback_file:
- session.query(TaskInstance).delete()
- dag = dagbag.get_dag('test_om_failure_callback_dag')
- task = dag.get_task(task_id='test_om_failure_callback_task')
-
- ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING)
-
- session.add(ti)
- session.commit()
-
- requests = [
- TaskCallbackRequest(
- full_filepath=dag.full_filepath,
- simple_task_instance=SimpleTaskInstance(ti),
- msg="Message",
- )
- ]
- callback_file.close()
-
- with mock.patch.dict("os.environ", {"AIRFLOW_CALLBACK_FILE": callback_file.name}):
- dag_file_processor.process_file(dag_file, requests)
- with open(callback_file.name) as callback_file2:
- content = callback_file2.read()
- assert "Callback fired" == content
- os.remove(callback_file.name)
-
- def test_should_mark_dummy_task_as_success(self):
- dag_file = os.path.join(
- os.path.dirname(os.path.realpath(__file__)), '../dags/test_only_dummy_tasks.py'
- )
-
- # Write DAGs to dag and serialized_dag table
- dagbag = DagBag(dag_folder=dag_file, include_examples=False, read_dags_from_db=False)
- dagbag.sync_to_db()
-
- self.scheduler_job_job = SchedulerJob(subdir=os.devnull)
- self.scheduler_job_job.processor_agent = mock.MagicMock()
- dag = self.scheduler_job_job.dagbag.get_dag("test_only_dummy_tasks")
-
- # Create DagRun
- session = settings.Session()
- orm_dag = session.query(DagModel).get(dag.dag_id)
- self.scheduler_job_job._create_dag_runs([orm_dag], session)
-
- drs = DagRun.find(dag_id=dag.dag_id, session=session)
- assert len(drs) == 1
- dr = drs[0]
-
- # Schedule TaskInstances
- self.scheduler_job_job._schedule_dag_run(dr, {}, session)
- with create_session() as session:
- tis = session.query(TaskInstance).all()
-
- dags = self.scheduler_job_job.dagbag.dags.values()
- assert ['test_only_dummy_tasks'] == [dag.dag_id for dag in dags]
- assert 5 == len(tis)
- assert {
- ('test_task_a', 'success'),
- ('test_task_b', None),
- ('test_task_c', 'success'),
- ('test_task_on_execute', 'scheduled'),
- ('test_task_on_success', 'scheduled'),
- } == {(ti.task_id, ti.state) for ti in tis}
- for state, start_date, end_date, duration in [
- (ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis
- ]:
- if state == 'success':
- assert start_date is not None
- assert end_date is not None
- assert 0.0 == duration
- else:
- assert start_date is None
- assert end_date is None
- assert duration is None
-
- self.scheduler_job_job._schedule_dag_run(dr, {}, session)
- with create_session() as session:
- tis = session.query(TaskInstance).all()
-
- assert 5 == len(tis)
- assert {
- ('test_task_a', 'success'),
- ('test_task_b', 'success'),
- ('test_task_c', 'success'),
- ('test_task_on_execute', 'scheduled'),
- ('test_task_on_success', 'scheduled'),
- } == {(ti.task_id, ti.state) for ti in tis}
- for state, start_date, end_date, duration in [
- (ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis
- ]:
- if state == 'success':
- assert start_date is not None
- assert end_date is not None
- assert 0.0 == duration
- else:
- assert start_date is None
- assert end_date is None
- assert duration is None
-
-
-@pytest.mark.usefixtures("disable_load_example")
class TestSchedulerJob(unittest.TestCase):
@staticmethod
def clean_db():
@@ -802,7 +118,7 @@ class TestSchedulerJob(unittest.TestCase):
# enqueue!
self.null_exec = MockExecutor()
- self.patcher = patch('airflow.utils.dag_processing.SerializedDagModel.remove_deleted_dags')
+ self.patcher = patch('airflow.dag_processing.manager.SerializedDagModel.remove_deleted_dags')
# Since we don't want to store the code for the DAG defined in this file
self.patcher_dag_code = patch.object(settings, "STORE_DAG_CODE", False)
self.patcher.start()
@@ -3213,7 +2529,7 @@ class TestSchedulerJob(unittest.TestCase):
dagbag.bag_dag(dag=dag, root_dag=dag)
dagbag.sync_to_db()
- @mock.patch('airflow.jobs.scheduler_job.DagBag', return_value=dagbag)
+ @mock.patch('airflow.dag_processing.processor.DagBag', return_value=dagbag)
def do_schedule(mock_dagbag):
# Use an empty file since the above mock will return the
# expected DAGs. Also specify only a single file so that it doesn't
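
Note that mock.patch targets name the module where a symbol is looked up, so after the relocation they must point at the new packages, as in the two hunks above. A minimal sketch (illustrative only) of the updated targets:

from unittest import mock

with mock.patch("airflow.dag_processing.manager.SerializedDagModel.remove_deleted_dags"):
    ...  # exercise manager code without touching serialized DAGs

with mock.patch("airflow.dag_processing.processor.DagBag") as mock_dagbag:
    ...  # DagFileProcessor now sees the mocked DagBag
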
diff --git a/tests/test_utils/perf/perf_kit/python.py b/tests/test_utils/perf/perf_kit/python.py
index 7d92a49..596f4f6 100644
--- a/tests/test_utils/perf/perf_kit/python.py
+++ b/tests/test_utils/perf/perf_kit/python.py
@@ -91,7 +91,7 @@ if __name__ == "__main__":
import logging
import airflow
- from airflow.jobs.scheduler_job import DagFileProcessor
+ from airflow.dag_processing.processor import DagFileProcessor
log = logging.getLogger(__name__)
processor = DagFileProcessor(dag_ids=[], log=log)
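
For completeness, a minimal sketch of driving the relocated DagFileProcessor by hand, in the spirit of the perf_kit snippet above. The DAG file path is a placeholder and an initialized Airflow metadata database is assumed:

import logging

from airflow.dag_processing.processor import DagFileProcessor

log = logging.getLogger(__name__)
processor = DagFileProcessor(dag_ids=[], log=log)
# Parse a single DAG file; the second argument is the list of callback requests to run.
processor.process_file("/path/to/dags/example_dag.py", [])
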
diff --git a/tests/test_utils/perf/perf_kit/sqlalchemy.py b/tests/test_utils/perf/perf_kit/sqlalchemy.py
index e60ad51..37cf0fe 100644
--- a/tests/test_utils/perf/perf_kit/sqlalchemy.py
+++ b/tests/test_utils/perf/perf_kit/sqlalchemy.py
@@ -218,7 +218,7 @@ if __name__ == "__main__":
import logging
from unittest import mock
- from airflow.jobs.scheduler_job import DagFileProcessor
+ from airflow.dag_processing.processor import DagFileProcessor
with mock.patch.dict(
"os.environ",