Posted to commits@airflow.apache.org by jh...@apache.org on 2021/08/11 21:57:32 UTC
[airflow] 01/03: Add 'queued' state to DagRun (#16401)
This is an automated email from the ASF dual-hosted git repository.
jhtimmins pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit f7bece3da14879712a895e897dad339944170ce3
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Tue Jul 6 15:03:27 2021 +0100
Add 'queued' state to DagRun (#16401)
This change adds a queued state to DagRun. Newly created DagRuns
start in the queued state and are then moved to the running state,
subject to the DAG's max_active_runs limit. If the DAG doesn't set
max_active_runs, DagRuns are moved to the running state immediately.
Clearing a DagRun now sets its state back to queued.
Closes: #9975, #16366
(cherry picked from commit 6611ffd399dce0474d8329720de7e83f568df598)
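In outline, the lifecycle described above can be pictured with a minimal
sketch; the names below (Run, promote) are illustrative stand-ins, not the
scheduler's actual code, which lives in SchedulerJob._start_queued_dagruns
further down in the diff:

    from dataclasses import dataclass

    @dataclass
    class Run:
        state: str = "queued"  # newly created runs now start queued, not running

    def promote(queued_runs, running_count, max_active_runs):
        """Move queued runs to running until max_active_runs is reached."""
        for run in queued_runs:
            if max_active_runs and running_count >= max_active_runs:
                break  # the rest stay queued until a later scheduler loop
            run.state = "running"
            running_count += 1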
---
airflow/api_connexion/openapi/v1.yaml | 4 +-
airflow/jobs/scheduler_job.py | 149 ++--
...93827b8_add_queued_at_column_to_dagrun_table.py | 49 ++
airflow/models/dag.py | 4 +-
airflow/models/dagrun.py | 17 +-
airflow/models/taskinstance.py | 6 +-
airflow/www/static/js/tree.js | 4 +-
airflow/www/views.py | 5 +-
docs/apache-airflow/migrations-ref.rst | 4 +-
tests/api/common/experimental/test_mark_tasks.py | 4 +-
.../endpoints/test_dag_run_endpoint.py | 14 +-
tests/api_connexion/schemas/test_dag_run_schema.py | 3 +
tests/dag_processing/test_processor.py | 746 +++++++++++++++++++++
tests/jobs/test_scheduler_job.py | 206 +++---
tests/models/test_cleartasks.py | 37 +
tests/models/test_dagrun.py | 25 +-
tests/sensors/test_external_task_sensor.py | 8 +-
tests/utils/test_dag_processing.py | 11 +-
18 files changed, 1037 insertions(+), 259 deletions(-)
diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml
index 182f356..47fa3dc 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -1853,6 +1853,7 @@ components:
description: |
The start time. The time when DAG run was actually created.
readOnly: true
+ nullable: true
end_date:
type: string
format: date-time
@@ -3025,8 +3026,9 @@ components:
description: DAG State.
type: string
enum:
- - success
+ - queued
- running
+ - success
- failed
TriggerRule:
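With 'queued' added to the enum, REST API clients can now see a fourth DagRun
state. A hedged example of reading one run whose "state" field may now be
"queued" (the host, credentials, DAG id and run id here are placeholders, not
values from this commit):

    curl -s --user "admin:admin" \
      "http://localhost:8080/api/v1/dags/example_dag/dagRuns/manual__2021-07-06T15:00:00+00:00"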
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index fe8e0b0..b7506b5 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -804,7 +804,7 @@ class SchedulerJob(BaseJob):
"""
For all DAG IDs in the DagBag, look for task instances in the
old_states and set them to new_state if the corresponding DagRun
- does not exist or exists but is not in the running state. This
+ does not exist or exists but is not in the running or queued state. This
normally should not happen, but it can if the state of DagRuns are
changed manually.
@@ -821,7 +821,7 @@ class SchedulerJob(BaseJob):
.filter(models.TaskInstance.state.in_(old_states))
.filter(
or_(
- models.DagRun.state != State.RUNNING,
+ models.DagRun.state.notin_([State.RUNNING, State.QUEUED]),
models.DagRun.state.is_(None),
)
)
@@ -1489,39 +1489,12 @@ class SchedulerJob(BaseJob):
if settings.USE_JOB_SCHEDULE:
self._create_dagruns_for_dags(guard, session)
- dag_runs = self._get_next_dagruns_to_examine(session)
+ self._start_queued_dagruns(session)
+ guard.commit()
+ dag_runs = self._get_next_dagruns_to_examine(State.RUNNING, session)
# Bulk fetch the currently active dag runs for the dags we are
# examining, rather than making one query per DagRun
- # TODO: This query is probably horribly inefficient (though there is an
- # index on (dag_id,state)). It is to deal with the case when a user
- # clears more than max_active_runs older tasks -- we don't want the
- # scheduler to suddenly go and start running tasks from all of the
- # runs. (AIRFLOW-137/GH #1442)
- #
- # The longer term fix would be to have `clear` do this, and put DagRuns
- # in to the queued state, then take DRs out of queued before creating
- # any new ones
-
- # Build up a set of execution_dates that are "active" for a given
- # dag_id -- only tasks from those runs will be scheduled.
- active_runs_by_dag_id = defaultdict(set)
-
- query = (
- session.query(
- TI.dag_id,
- TI.execution_date,
- )
- .filter(
- TI.dag_id.in_(list({dag_run.dag_id for dag_run in dag_runs})),
- TI.state.notin_(list(State.finished) + [State.REMOVED]),
- )
- .group_by(TI.dag_id, TI.execution_date)
- )
-
- for dag_id, execution_date in query:
- active_runs_by_dag_id[dag_id].add(execution_date)
-
for dag_run in dag_runs:
# Use try_except to not stop the Scheduler when a Serialized DAG is not found
# This takes care of Dynamic DAGs especially
@@ -1530,7 +1503,7 @@ class SchedulerJob(BaseJob):
# But this would take care of the scenario when the Scheduler is restarted after DagRun is
# created and the DAG is deleted / renamed
try:
- self._schedule_dag_run(dag_run, active_runs_by_dag_id.get(dag_run.dag_id, set()), session)
+ self._schedule_dag_run(dag_run, session)
except SerializedDagNotFound:
self.log.exception("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
continue
@@ -1570,9 +1543,9 @@ class SchedulerJob(BaseJob):
return num_queued_tis
@retry_db_transaction
- def _get_next_dagruns_to_examine(self, session):
+ def _get_next_dagruns_to_examine(self, state, session):
"""Get Next DagRuns to Examine with retries"""
- return DagRun.next_dagruns_to_examine(session)
+ return DagRun.next_dagruns_to_examine(state, session)
@retry_db_transaction
def _create_dagruns_for_dags(self, guard, session):
@@ -1593,7 +1566,7 @@ class SchedulerJob(BaseJob):
# as DagModel.dag_id and DagModel.next_dagrun
# This list is used to verify if the DagRun already exist so that we don't attempt to create
# duplicate dag runs
- active_dagruns = (
+ existing_dagruns = (
session.query(DagRun.dag_id, DagRun.execution_date)
.filter(
tuple_(DagRun.dag_id, DagRun.execution_date).in_(
@@ -1616,89 +1589,83 @@ class SchedulerJob(BaseJob):
# are not updated.
# We opted to check DagRun existence instead
# of catching an Integrity error and rolling back the session i.e
- # we need to run self._update_dag_next_dagruns if the Dag Run already exists or if we
+ # we need to set dag.next_dagrun_info if the Dag Run already exists or if we
# create a new one. This is so that in the next Scheduling loop we try to create new runs
# instead of falling in a loop of Integrity Error.
- if (dag.dag_id, dag_model.next_dagrun) not in active_dagruns:
- run = dag.create_dagrun(
+ if (dag.dag_id, dag_model.next_dagrun) not in existing_dagruns:
+ dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
execution_date=dag_model.next_dagrun,
- start_date=timezone.utcnow(),
- state=State.RUNNING,
+ state=State.QUEUED,
external_trigger=False,
session=session,
dag_hash=dag_hash,
creating_job_id=self.id,
)
-
- expected_start_date = dag.following_schedule(run.execution_date)
- if expected_start_date:
- schedule_delay = run.start_date - expected_start_date
- Stats.timing(
- f'dagrun.schedule_delay.{dag.dag_id}',
- schedule_delay,
- )
-
- self._update_dag_next_dagruns(dag_models, session)
+ dag_model.next_dagrun, dag_model.next_dagrun_create_after = dag.next_dagrun_info(
+ dag_model.next_dagrun
+ )
# TODO[HA]: Should we do a session.flush() so we don't have to keep lots of state/object in
# memory for larger dags? or expunge_all()
- def _update_dag_next_dagruns(self, dag_models: Iterable[DagModel], session: Session) -> None:
- """
- Bulk update the next_dagrun and next_dagrun_create_after for all the dags.
+ def _start_queued_dagruns(
+ self,
+ session: Session,
+ ) -> int:
+ """Find DagRuns in queued state and decide moving them to running state"""
+ dag_runs = self._get_next_dagruns_to_examine(State.QUEUED, session)
- We batch the select queries to get info about all the dags at once
- """
- # Check max_active_runs, to see if we are _now_ at the limit for any of
- # these dag? (we've just created a DagRun for them after all)
- active_runs_of_dags = dict(
+ active_runs_of_dags = defaultdict(
+ lambda: 0,
session.query(DagRun.dag_id, func.count('*'))
- .filter(
- DagRun.dag_id.in_([o.dag_id for o in dag_models]),
+ .filter( # We use `list` here because SQLA doesn't accept a set
+ # We use set to avoid duplicate dag_ids
+ DagRun.dag_id.in_(list({dr.dag_id for dr in dag_runs})),
DagRun.state == State.RUNNING,
- DagRun.external_trigger.is_(False),
)
.group_by(DagRun.dag_id)
- .all()
+ .all(),
)
- for dag_model in dag_models:
- # Get the DAG in a try_except to not stop the Scheduler when a Serialized DAG is not found
- # This takes care of Dynamic DAGs especially
+ def _update_state(dag_run):
+ dag_run.state = State.RUNNING
+ dag_run.start_date = timezone.utcnow()
+ expected_start_date = dag.following_schedule(dag_run.execution_date)
+ if expected_start_date:
+ schedule_delay = dag_run.start_date - expected_start_date
+ Stats.timing(
+ f'dagrun.schedule_delay.{dag.dag_id}',
+ schedule_delay,
+ )
+
+ for dag_run in dag_runs:
try:
- dag = self.dagbag.get_dag(dag_model.dag_id, session=session)
+ dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
except SerializedDagNotFound:
- self.log.exception("DAG '%s' not found in serialized_dag table", dag_model.dag_id)
+ self.log.exception("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
continue
- active_runs_of_dag = active_runs_of_dags.get(dag.dag_id, 0)
- if dag.max_active_runs and active_runs_of_dag >= dag.max_active_runs:
- self.log.info(
- "DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs",
+ active_runs = active_runs_of_dags[dag_run.dag_id]
+ if dag.max_active_runs and active_runs >= dag.max_active_runs:
+ self.log.debug(
+ "DAG %s already has %d active runs, not moving any more runs to RUNNING state %s",
dag.dag_id,
- active_runs_of_dag,
- dag.max_active_runs,
+ active_runs,
+ dag_run.execution_date,
)
- dag_model.next_dagrun_create_after = None
else:
- dag_model.next_dagrun, dag_model.next_dagrun_create_after = dag.next_dagrun_info(
- dag_model.next_dagrun
- )
+ active_runs_of_dags[dag_run.dag_id] += 1
+ _update_state(dag_run)
def _schedule_dag_run(
self,
dag_run: DagRun,
- currently_active_runs: Set[datetime.datetime],
session: Session,
) -> int:
"""
Make scheduling decisions about an individual dag run
- ``currently_active_runs`` is passed in so that a batch query can be
- used to ask this for all dag runs in the batch, to avoid an n+1 query.
-
:param dag_run: The DagRun to schedule
- :param currently_active_runs: Number of currently active runs of this DAG
:return: Number of tasks scheduled
"""
dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
@@ -1725,9 +1692,6 @@ class SchedulerJob(BaseJob):
session.flush()
self.log.info("Run %s of %s has timed-out", dag_run.run_id, dag_run.dag_id)
- # Work out if we should allow creating a new DagRun now?
- self._update_dag_next_dagruns([session.query(DagModel).get(dag_run.dag_id)], session)
-
callback_to_execute = DagCallbackRequest(
full_filepath=dag.fileloc,
dag_id=dag.dag_id,
@@ -1745,19 +1709,6 @@ class SchedulerJob(BaseJob):
self.log.error("Execution date is in future: %s", dag_run.execution_date)
return 0
- if dag.max_active_runs:
- if (
- len(currently_active_runs) >= dag.max_active_runs
- and dag_run.execution_date not in currently_active_runs
- ):
- self.log.info(
- "DAG %s already has %d active runs, not queuing any tasks for run %s",
- dag.dag_id,
- len(currently_active_runs),
- dag_run.execution_date,
- )
- return 0
-
self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session)
# TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else?
schedulable_tis, callback_to_run = dag_run.update_state(session=session, execute_callbacks=False)
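Taken together, the scheduler_job.py changes split run creation from run
starting. A condensed restatement of the new flow (paraphrased pseudocode of
the methods above, not the literal code):

    def _do_scheduling_sketch(self, guard, session):
        # Phase 1: create new runs directly in the QUEUED state
        self._create_dagruns_for_dags(guard, session)
        # Phase 2: promote queued runs to RUNNING, respecting max_active_runs
        self._start_queued_dagruns(session)
        guard.commit()
        # Phase 3: only RUNNING runs are examined for task scheduling
        for dag_run in self._get_next_dagruns_to_examine(State.RUNNING, session):
            self._schedule_dag_run(dag_run, session)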
diff --git a/airflow/migrations/versions/97cdd93827b8_add_queued_at_column_to_dagrun_table.py b/airflow/migrations/versions/97cdd93827b8_add_queued_at_column_to_dagrun_table.py
new file mode 100644
index 0000000..03caebc
--- /dev/null
+++ b/airflow/migrations/versions/97cdd93827b8_add_queued_at_column_to_dagrun_table.py
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Add queued_at column to dagrun table
+
+Revision ID: 97cdd93827b8
+Revises: a13f7613ad25
+Create Date: 2021-06-29 21:53:48.059438
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+from sqlalchemy.dialects import mssql
+
+# revision identifiers, used by Alembic.
+revision = '97cdd93827b8'
+down_revision = 'a13f7613ad25'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ """Apply Add queued_at column to dagrun table"""
+ conn = op.get_bind()
+ if conn.dialect.name == "mssql":
+ op.add_column('dag_run', sa.Column('queued_at', mssql.DATETIME2(precision=6), nullable=True))
+ else:
+ op.add_column('dag_run', sa.Column('queued_at', sa.DateTime(), nullable=True))
+
+
+def downgrade():
+ """Unapply Add queued_at column to dagrun table"""
+ op.drop_column('dag_run', "queued_at")
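For completeness: on an existing install this migration would be applied with
the standard upgrade command (assuming a configured Airflow environment), or
pinned to this revision via Alembic:

    airflow db upgrade
    # or, equivalently, from a checkout with alembic configured:
    alembic upgrade 97cdd93827b8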
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 13d69c2..a3d06db 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -1153,7 +1153,7 @@ class DAG(LoggingMixin):
confirm_prompt=False,
include_subdags=True,
include_parentdag=True,
- dag_run_state: str = State.RUNNING,
+ dag_run_state: str = State.QUEUED,
dry_run=False,
session=None,
get_tis=False,
@@ -1369,7 +1369,7 @@ class DAG(LoggingMixin):
confirm_prompt=False,
include_subdags=True,
include_parentdag=False,
- dag_run_state=State.RUNNING,
+ dag_run_state=State.QUEUED,
dry_run=False,
):
all_tis = []
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 5b8ac0c..c503ac4 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -61,12 +61,15 @@ class DagRun(Base, LoggingMixin):
__tablename__ = "dag_run"
+ __NO_VALUE = object()
+
id = Column(Integer, primary_key=True)
dag_id = Column(String(ID_LEN))
+ queued_at = Column(UtcDateTime)
execution_date = Column(UtcDateTime, default=timezone.utcnow)
- start_date = Column(UtcDateTime, default=timezone.utcnow)
+ start_date = Column(UtcDateTime)
end_date = Column(UtcDateTime)
- _state = Column('state', String(50), default=State.RUNNING)
+ _state = Column('state', String(50), default=State.QUEUED)
run_id = Column(String(ID_LEN))
creating_job_id = Column(Integer)
external_trigger = Column(Boolean, default=True)
@@ -102,6 +105,7 @@ class DagRun(Base, LoggingMixin):
self,
dag_id: Optional[str] = None,
run_id: Optional[str] = None,
+ queued_at: Optional[datetime] = __NO_VALUE,
execution_date: Optional[datetime] = None,
start_date: Optional[datetime] = None,
external_trigger: Optional[bool] = None,
@@ -118,6 +122,10 @@ class DagRun(Base, LoggingMixin):
self.external_trigger = external_trigger
self.conf = conf or {}
self.state = state
+ if queued_at is self.__NO_VALUE:
+ self.queued_at = timezone.utcnow() if state == State.QUEUED else None
+ else:
+ self.queued_at = queued_at
self.run_type = run_type
self.dag_hash = dag_hash
self.creating_job_id = creating_job_id
@@ -140,6 +148,8 @@ class DagRun(Base, LoggingMixin):
if self._state != state:
self._state = state
self.end_date = timezone.utcnow() if self._state in State.finished else None
+ if state == State.QUEUED:
+ self.queued_at = timezone.utcnow()
@declared_attr
def state(self):
@@ -160,6 +170,7 @@ class DagRun(Base, LoggingMixin):
@classmethod
def next_dagruns_to_examine(
cls,
+ state: str,
session: Session,
max_number: Optional[int] = None,
):
@@ -180,7 +191,7 @@ class DagRun(Base, LoggingMixin):
# TODO: Bake this query, it is run _A lot_
query = (
session.query(cls)
- .filter(cls.state == State.RUNNING, cls.run_type != DagRunType.BACKFILL_JOB)
+ .filter(cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB)
.join(
DagModel,
DagModel.dag_id == cls.dag_id,
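The __NO_VALUE sentinel introduced above lets the constructor distinguish
"caller passed queued_at=None" from "caller passed nothing at all". A
standalone illustration of the same pattern, with hypothetical names:

    from datetime import datetime, timezone

    _NO_VALUE = object()  # unique sentinel; no caller-supplied value can equal it

    class Record:
        def __init__(self, state=None, queued_at=_NO_VALUE):
            if queued_at is _NO_VALUE:
                # nothing was passed: derive the default from state
                self.queued_at = datetime.now(timezone.utc) if state == "queued" else None
            else:
                # an explicit value, including an explicit None, is kept as-is
                self.queued_at = queued_at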
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index b99fa34..8d2578f 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -138,7 +138,7 @@ def clear_task_instances(
session,
activate_dag_runs=None,
dag=None,
- dag_run_state: Union[str, Literal[False]] = State.RUNNING,
+ dag_run_state: Union[str, Literal[False]] = State.QUEUED,
):
"""
Clears a set of task instances, but makes sure the running ones
@@ -240,7 +240,9 @@ def clear_task_instances(
)
for dr in drs:
dr.state = dag_run_state
- dr.start_date = timezone.utcnow()
+ dr.start_date = None
+ if dag_run_state == State.QUEUED:
+ dr.last_scheduling_decision = None
class TaskInstanceKey(NamedTuple):
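As the views.py and test changes below exercise, clearing now parks runs in
QUEUED (with start_date wiped) and leaves restarting them to the scheduler. A
hedged usage sketch, assuming an existing DAG object named dag and arbitrary
example dates:

    from airflow.utils import timezone

    # Cleared runs come back as QUEUED (dag_run_state defaults to State.QUEUED)
    # and are re-promoted to RUNNING within the DAG's max_active_runs.
    dag.clear(
        start_date=timezone.datetime(2021, 7, 1),
        end_date=timezone.datetime(2021, 7, 6),
    )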
diff --git a/airflow/www/static/js/tree.js b/airflow/www/static/js/tree.js
index 4bf366a..d45c880 100644
--- a/airflow/www/static/js/tree.js
+++ b/airflow/www/static/js/tree.js
@@ -58,7 +58,9 @@ document.addEventListener('DOMContentLoaded', () => {
const tree = d3.layout.tree().nodeSize([0, 25]);
let nodes = tree.nodes(data);
const nodeobj = {};
- const getActiveRuns = () => data.instances.filter((run) => run.state === 'running').length > 0;
+ const runActiveStates = ['queued', 'running'];
+ const getActiveRuns = () => data.instances
+ .filter((run) => runActiveStates.includes(run.state)).length > 0;
const now = Date.now() / 1000;
const devicePixelRatio = window.devicePixelRatio || 1;
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 09d27e0..9af5e1b 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1526,7 +1526,7 @@ class Airflow(AirflowBaseView):
dag.create_dagrun(
run_type=DagRunType.MANUAL,
execution_date=execution_date,
- state=State.RUNNING,
+ state=State.QUEUED,
conf=run_conf,
external_trigger=True,
dag_hash=current_app.dag_bag.dags_hash.get(dag_id),
@@ -3451,6 +3451,7 @@ class DagRunModelView(AirflowModelView):
'execution_date',
'run_id',
'run_type',
+ 'queued_at',
'start_date',
'end_date',
'external_trigger',
@@ -3786,7 +3787,7 @@ class TaskInstanceModelView(AirflowModelView):
lazy_gettext('Clear'),
lazy_gettext(
'Are you sure you want to clear the state of the selected task'
- ' instance(s) and set their dagruns to the running state?'
+ ' instance(s) and set their dagruns to the QUEUED state?'
),
single=False,
)
diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst
index 0c663da..0af143f 100644
--- a/docs/apache-airflow/migrations-ref.rst
+++ b/docs/apache-airflow/migrations-ref.rst
@@ -23,9 +23,7 @@ Here's the list of all the Database Migrations that are executed via when you ru
+--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+
| Revision ID | Revises ID | Airflow Version | Description |
+--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+
-| ``e9304a3141f0`` (head) | ``83f031fd9f1c`` | | Make XCom primary key columns non-nullable |
-+--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+
-| ``83f031fd9f1c`` | ``a13f7613ad25`` | | Improve MSSQL compatibility |
+| ``97cdd93827b8`` (head) | ``a13f7613ad25`` | ``2.1.3`` | Add ``queued_at`` column in ``dag_run`` table |
+--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+
| ``a13f7613ad25`` | ``e165e7455d70`` | ``2.1.0`` | Resource based permissions for default ``Flask-AppBuilder`` views |
+--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+
diff --git a/tests/api/common/experimental/test_mark_tasks.py b/tests/api/common/experimental/test_mark_tasks.py
index 4dab57e..49008d3 100644
--- a/tests/api/common/experimental/test_mark_tasks.py
+++ b/tests/api/common/experimental/test_mark_tasks.py
@@ -414,7 +414,9 @@ class TestMarkDAGRun(unittest.TestCase):
assert ti.state == state
def _create_test_dag_run(self, state, date):
- return self.dag1.create_dagrun(run_type=DagRunType.MANUAL, state=state, execution_date=date)
+ return self.dag1.create_dagrun(
+ run_type=DagRunType.MANUAL, state=state, start_date=date, execution_date=date
+ )
def _verify_dag_run_state(self, dag, date, state):
drs = models.DagRun.find(dag_id=dag.dag_id, execution_date=date)
diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
index e51eca8..0aa13b2 100644
--- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
@@ -90,7 +90,7 @@ class TestDagRunEndpoint:
def teardown_method(self) -> None:
clear_db_runs()
- # clear_db_dags()
+ clear_db_dags()
def _create_dag(self, dag_id):
dag_instance = DagModel(dag_id=dag_id)
@@ -118,6 +118,7 @@ class TestDagRunEndpoint:
execution_date=timezone.parse(self.default_time_2),
start_date=timezone.parse(self.default_time),
external_trigger=True,
+ state=state,
)
dag_runs.append(dagrun_model_2)
if extra_dag:
@@ -131,6 +132,7 @@ class TestDagRunEndpoint:
execution_date=timezone.parse(self.default_time_2),
start_date=timezone.parse(self.default_time),
external_trigger=True,
+ state=state,
)
)
if commit:
@@ -193,6 +195,7 @@ class TestGetDagRun(TestDagRunEndpoint):
execution_date=timezone.parse(self.default_time),
start_date=timezone.parse(self.default_time),
external_trigger=True,
+ state='running',
)
session.add(dagrun_model)
session.commit()
@@ -532,7 +535,7 @@ class TestGetDagRunsEndDateFilters(TestDagRunEndpoint):
(
f"api/v1/dags/TEST_DAG_ID/dagRuns?end_date_lte="
f"{(timezone.utcnow() + timedelta(days=1)).isoformat()}",
- ["TEST_DAG_RUN_ID_1"],
+ ["TEST_DAG_RUN_ID_1", "TEST_DAG_RUN_ID_2"],
),
]
)
@@ -750,6 +753,7 @@ class TestGetDagRunBatchPagination(TestDagRunEndpoint):
DagRun(
dag_id="TEST_DAG_ID",
run_id="TEST_DAG_RUN_ID" + str(i),
+ state='running',
run_type=DagRunType.MANUAL,
execution_date=timezone.parse(self.default_time) + timedelta(minutes=i),
start_date=timezone.parse(self.default_time),
@@ -884,7 +888,7 @@ class TestGetDagRunBatchDateFilters(TestDagRunEndpoint):
),
(
{"end_date_lte": f"{(timezone.utcnow() + timedelta(days=1)).isoformat()}"},
- ["TEST_DAG_RUN_ID_1"],
+ ["TEST_DAG_RUN_ID_1", "TEST_DAG_RUN_ID_2"],
),
]
)
@@ -927,8 +931,8 @@ class TestPostDagRun(TestDagRunEndpoint):
"end_date": None,
"execution_date": response.json["execution_date"],
"external_trigger": True,
- "start_date": response.json["start_date"],
- "state": "running",
+ "start_date": None,
+ "state": "queued",
} == response.json
@parameterized.expand(
diff --git a/tests/api_connexion/schemas/test_dag_run_schema.py b/tests/api_connexion/schemas/test_dag_run_schema.py
index 3e6bf2e..9e4a9e8 100644
--- a/tests/api_connexion/schemas/test_dag_run_schema.py
+++ b/tests/api_connexion/schemas/test_dag_run_schema.py
@@ -49,6 +49,7 @@ class TestDAGRunSchema(TestDAGRunBase):
def test_serialize(self, session):
dagrun_model = DagRun(
run_id="my-dag-run",
+ state='running',
run_type=DagRunType.MANUAL.value,
execution_date=timezone.parse(self.default_time),
start_date=timezone.parse(self.default_time),
@@ -124,6 +125,7 @@ class TestDagRunCollection(TestDAGRunBase):
def test_serialize(self, session):
dagrun_model_1 = DagRun(
run_id="my-dag-run",
+ state='running',
execution_date=timezone.parse(self.default_time),
run_type=DagRunType.MANUAL.value,
start_date=timezone.parse(self.default_time),
@@ -131,6 +133,7 @@ class TestDagRunCollection(TestDAGRunBase):
)
dagrun_model_2 = DagRun(
run_id="my-dag-run-2",
+ state='running',
execution_date=timezone.parse(self.default_time),
start_date=timezone.parse(self.default_time),
run_type=DagRunType.MANUAL.value,
diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py
new file mode 100644
index 0000000..9425bbb
--- /dev/null
+++ b/tests/dag_processing/test_processor.py
@@ -0,0 +1,746 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import datetime
+import os
+import unittest
+from datetime import timedelta
+from tempfile import NamedTemporaryFile
+from unittest import mock
+from unittest.mock import MagicMock, patch
+
+import pytest
+from parameterized import parameterized
+
+from airflow import settings
+from airflow.configuration import conf
+from airflow.dag_processing.processor import DagFileProcessor
+from airflow.jobs.scheduler_job import SchedulerJob
+from airflow.models import DAG, DagBag, DagModel, SlaMiss, TaskInstance
+from airflow.models.dagrun import DagRun
+from airflow.models.serialized_dag import SerializedDagModel
+from airflow.models.taskinstance import SimpleTaskInstance
+from airflow.operators.bash import BashOperator
+from airflow.operators.dummy import DummyOperator
+from airflow.serialization.serialized_objects import SerializedDAG
+from airflow.utils import timezone
+from airflow.utils.callback_requests import TaskCallbackRequest
+from airflow.utils.dates import days_ago
+from airflow.utils.session import create_session
+from airflow.utils.state import State
+from airflow.utils.types import DagRunType
+from tests.test_utils.config import conf_vars, env_vars
+from tests.test_utils.db import (
+ clear_db_dags,
+ clear_db_import_errors,
+ clear_db_jobs,
+ clear_db_pools,
+ clear_db_runs,
+ clear_db_serialized_dags,
+ clear_db_sla_miss,
+)
+from tests.test_utils.mock_executor import MockExecutor
+
+DEFAULT_DATE = timezone.datetime(2016, 1, 1)
+
+
+@pytest.fixture(scope="class")
+def disable_load_example():
+ with conf_vars({('core', 'load_examples'): 'false'}):
+ with env_vars({('core', 'load_examples'): 'false'}):
+ yield
+
+
+@pytest.mark.usefixtures("disable_load_example")
+class TestDagFileProcessor(unittest.TestCase):
+ @staticmethod
+ def clean_db():
+ clear_db_runs()
+ clear_db_pools()
+ clear_db_dags()
+ clear_db_sla_miss()
+ clear_db_import_errors()
+ clear_db_jobs()
+ clear_db_serialized_dags()
+
+ def setUp(self):
+ self.clean_db()
+
+ # Speed up some tests by not running the tasks, just look at what we
+ # enqueue!
+ self.null_exec = MockExecutor()
+ self.scheduler_job = None
+
+ def tearDown(self) -> None:
+ if self.scheduler_job and self.scheduler_job.processor_agent:
+ self.scheduler_job.processor_agent.end()
+ self.scheduler_job = None
+ self.clean_db()
+
+ def create_test_dag(self, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timedelta(hours=1), **kwargs):
+ dag = DAG(
+ dag_id='test_scheduler_reschedule',
+ start_date=start_date,
+ # Make sure it only creates a single DAG Run
+ end_date=end_date,
+ )
+ dag.clear()
+ dag.is_subdag = False
+ with create_session() as session:
+ orm_dag = DagModel(dag_id=dag.dag_id, is_paused=False)
+ session.merge(orm_dag)
+ session.commit()
+ return dag
+
+ @classmethod
+ def setUpClass(cls):
+ # Ensure the DAGs we are looking at from the DB are up-to-date
+ non_serialized_dagbag = DagBag(read_dags_from_db=False, include_examples=False)
+ non_serialized_dagbag.sync_to_db()
+ cls.dagbag = DagBag(read_dags_from_db=True)
+
+ def test_dag_file_processor_sla_miss_callback(self):
+ """
+ Test that the dag file processor calls the sla miss callback
+ """
+ session = settings.Session()
+
+ sla_callback = MagicMock()
+
+ # Create dag with a start of 1 day ago, but an sla of 0
+ # so we'll already have an sla_miss on the books.
+ test_start_date = days_ago(1)
+ dag = DAG(
+ dag_id='test_sla_miss',
+ sla_miss_callback=sla_callback,
+ default_args={'start_date': test_start_date, 'sla': datetime.timedelta()},
+ )
+
+ task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
+
+ session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))
+
+ session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
+
+ dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ dag_file_processor.manage_slas(dag=dag, session=session)
+
+ assert sla_callback.called
+
+ def test_dag_file_processor_sla_miss_callback_invalid_sla(self):
+ """
+ Test that the dag file processor does not call the sla miss callback when
+ given an invalid sla
+ """
+ session = settings.Session()
+
+ sla_callback = MagicMock()
+
+ # Create dag with a start of 1 day ago, but an sla of 0
+ # so we'll already have an sla_miss on the books.
+ # Pass anything besides a timedelta object to the sla argument.
+ test_start_date = days_ago(1)
+ dag = DAG(
+ dag_id='test_sla_miss',
+ sla_miss_callback=sla_callback,
+ default_args={'start_date': test_start_date, 'sla': None},
+ )
+
+ task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
+
+ session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))
+
+ session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
+
+ dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ dag_file_processor.manage_slas(dag=dag, session=session)
+ sla_callback.assert_not_called()
+
+ def test_dag_file_processor_sla_miss_callback_sent_notification(self):
+ """
+ Test that the dag file processor does not call the sla_miss_callback when a
+ notification has already been sent
+ """
+ session = settings.Session()
+
+ # Mock the callback function so we can verify that it was not called
+ sla_callback = MagicMock()
+
+ # Create dag with a start of 2 days ago, but an sla of 1 day
+ # ago so we'll already have an sla_miss on the books
+ test_start_date = days_ago(2)
+ dag = DAG(
+ dag_id='test_sla_miss',
+ sla_miss_callback=sla_callback,
+ default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
+ )
+
+ task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
+
+ # Create a TaskInstance for two days ago
+ session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))
+
+ # Create an SlaMiss where notification was sent, but email was not
+ session.merge(
+ SlaMiss(
+ task_id='dummy',
+ dag_id='test_sla_miss',
+ execution_date=test_start_date,
+ email_sent=False,
+ notification_sent=True,
+ )
+ )
+
+ # Now call manage_slas and see if the sla_miss callback gets called
+ dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ dag_file_processor.manage_slas(dag=dag, session=session)
+
+ sla_callback.assert_not_called()
+
+ def test_dag_file_processor_sla_miss_callback_exception(self):
+ """
+ Test that the dag file processor gracefully logs an exception if there is a problem
+ calling the sla_miss_callback
+ """
+ session = settings.Session()
+
+ sla_callback = MagicMock(side_effect=RuntimeError('Could not call function'))
+
+ test_start_date = days_ago(2)
+ dag = DAG(
+ dag_id='test_sla_miss',
+ sla_miss_callback=sla_callback,
+ default_args={'start_date': test_start_date},
+ )
+
+ task = DummyOperator(task_id='dummy', dag=dag, owner='airflow', sla=datetime.timedelta(hours=1))
+
+ session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
+
+ # Create an SlaMiss where notification was sent, but email was not
+ session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
+
+ # Now call manage_slas and see if the sla_miss callback gets called
+ mock_log = mock.MagicMock()
+ dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log)
+ dag_file_processor.manage_slas(dag=dag, session=session)
+ assert sla_callback.called
+ mock_log.exception.assert_called_once_with(
+ 'Could not call sla_miss_callback for DAG %s', 'test_sla_miss'
+ )
+
+ @mock.patch('airflow.dag_processing.processor.send_email')
+ def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks(self, mock_send_email):
+ session = settings.Session()
+
+ test_start_date = days_ago(2)
+ dag = DAG(
+ dag_id='test_sla_miss',
+ default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
+ )
+
+ email1 = 'test1@test.com'
+ task = DummyOperator(
+ task_id='sla_missed', dag=dag, owner='airflow', email=email1, sla=datetime.timedelta(hours=1)
+ )
+
+ session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
+
+ email2 = 'test2@test.com'
+ DummyOperator(task_id='sla_not_missed', dag=dag, owner='airflow', email=email2)
+
+ session.merge(SlaMiss(task_id='sla_missed', dag_id='test_sla_miss', execution_date=test_start_date))
+
+ dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+
+ dag_file_processor.manage_slas(dag=dag, session=session)
+
+ assert len(mock_send_email.call_args_list) == 1
+
+ send_email_to = mock_send_email.call_args_list[0][0][0]
+ assert email1 in send_email_to
+ assert email2 not in send_email_to
+
+ @mock.patch('airflow.dag_processing.processor.Stats.incr')
+ @mock.patch("airflow.utils.email.send_email")
+ def test_dag_file_processor_sla_miss_email_exception(self, mock_send_email, mock_stats_incr):
+ """
+ Test that the dag file processor gracefully logs an exception if there is a problem
+ sending an email
+ """
+ session = settings.Session()
+
+ # Mock the callback function so we can verify that it was not called
+ mock_send_email.side_effect = RuntimeError('Could not send an email')
+
+ test_start_date = days_ago(2)
+ dag = DAG(
+ dag_id='test_sla_miss',
+ default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
+ )
+
+ task = DummyOperator(
+ task_id='dummy', dag=dag, owner='airflow', email='test@test.com', sla=datetime.timedelta(hours=1)
+ )
+
+ session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
+
+ # Create an SlaMiss where notification was sent, but email was not
+ session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
+
+ mock_log = mock.MagicMock()
+ dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log)
+
+ dag_file_processor.manage_slas(dag=dag, session=session)
+ mock_log.exception.assert_called_once_with(
+ 'Could not send SLA Miss email notification for DAG %s', 'test_sla_miss'
+ )
+ mock_stats_incr.assert_called_once_with('sla_email_notification_failure')
+
+ def test_dag_file_processor_sla_miss_deleted_task(self):
+ """
+ Test that the dag file processor will not crash when trying to send
+ sla miss notification for a deleted task
+ """
+ session = settings.Session()
+
+ test_start_date = days_ago(2)
+ dag = DAG(
+ dag_id='test_sla_miss',
+ default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
+ )
+
+ task = DummyOperator(
+ task_id='dummy', dag=dag, owner='airflow', email='test@test.com', sla=datetime.timedelta(hours=1)
+ )
+
+ session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
+
+ # Create an SlaMiss where notification was sent, but email was not
+ session.merge(
+ SlaMiss(task_id='dummy_deleted', dag_id='test_sla_miss', execution_date=test_start_date)
+ )
+
+ mock_log = mock.MagicMock()
+ dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log)
+ dag_file_processor.manage_slas(dag=dag, session=session)
+
+ @parameterized.expand(
+ [
+ [State.NONE, None, None],
+ [
+ State.UP_FOR_RETRY,
+ timezone.utcnow() - datetime.timedelta(minutes=30),
+ timezone.utcnow() - datetime.timedelta(minutes=15),
+ ],
+ [
+ State.UP_FOR_RESCHEDULE,
+ timezone.utcnow() - datetime.timedelta(minutes=30),
+ timezone.utcnow() - datetime.timedelta(minutes=15),
+ ],
+ ]
+ )
+ def test_dag_file_processor_process_task_instances(self, state, start_date, end_date):
+ """
+ Test if _process_task_instances puts the right task instances into the
+ mock_list.
+ """
+ dag = DAG(dag_id='test_scheduler_process_execute_task', start_date=DEFAULT_DATE)
+ BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo hi')
+
+ with create_session() as session:
+ orm_dag = DagModel(dag_id=dag.dag_id)
+ session.merge(orm_dag)
+
+ dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+ self.scheduler_job = SchedulerJob(subdir=os.devnull)
+ self.scheduler_job.processor_agent = mock.MagicMock()
+ self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
+ dag.clear()
+ dr = dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+ assert dr is not None
+
+ with create_session() as session:
+ ti = dr.get_task_instances(session=session)[0]
+ ti.state = state
+ ti.start_date = start_date
+ ti.end_date = end_date
+
+ count = self.scheduler_job._schedule_dag_run(dr, session)
+ assert count == 1
+
+ session.refresh(ti)
+ assert ti.state == State.SCHEDULED
+
+ @parameterized.expand(
+ [
+ [State.NONE, None, None],
+ [
+ State.UP_FOR_RETRY,
+ timezone.utcnow() - datetime.timedelta(minutes=30),
+ timezone.utcnow() - datetime.timedelta(minutes=15),
+ ],
+ [
+ State.UP_FOR_RESCHEDULE,
+ timezone.utcnow() - datetime.timedelta(minutes=30),
+ timezone.utcnow() - datetime.timedelta(minutes=15),
+ ],
+ ]
+ )
+ def test_dag_file_processor_process_task_instances_with_task_concurrency(
+ self,
+ state,
+ start_date,
+ end_date,
+ ):
+ """
+ Test if _process_task_instances puts the right task instances into the
+ mock_list.
+ """
+ dag = DAG(dag_id='test_scheduler_process_execute_task_with_task_concurrency', start_date=DEFAULT_DATE)
+ BashOperator(task_id='dummy', task_concurrency=2, dag=dag, owner='airflow', bash_command='echo Hi')
+
+ with create_session() as session:
+ orm_dag = DagModel(dag_id=dag.dag_id)
+ session.merge(orm_dag)
+
+ dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+ self.scheduler_job = SchedulerJob(subdir=os.devnull)
+ self.scheduler_job.processor_agent = mock.MagicMock()
+ self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
+ dag.clear()
+ dr = dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+ assert dr is not None
+
+ with create_session() as session:
+ ti = dr.get_task_instances(session=session)[0]
+ ti.state = state
+ ti.start_date = start_date
+ ti.end_date = end_date
+
+ count = self.scheduler_job._schedule_dag_run(dr, session)
+ assert count == 1
+
+ session.refresh(ti)
+ assert ti.state == State.SCHEDULED
+
+ @parameterized.expand(
+ [
+ [State.NONE, None, None],
+ [
+ State.UP_FOR_RETRY,
+ timezone.utcnow() - datetime.timedelta(minutes=30),
+ timezone.utcnow() - datetime.timedelta(minutes=15),
+ ],
+ [
+ State.UP_FOR_RESCHEDULE,
+ timezone.utcnow() - datetime.timedelta(minutes=30),
+ timezone.utcnow() - datetime.timedelta(minutes=15),
+ ],
+ ]
+ )
+ def test_dag_file_processor_process_task_instances_depends_on_past(self, state, start_date, end_date):
+ """
+ Test if _process_task_instances puts the right task instances into the
+ mock_list.
+ """
+ dag = DAG(
+ dag_id='test_scheduler_process_execute_task_depends_on_past',
+ start_date=DEFAULT_DATE,
+ default_args={
+ 'depends_on_past': True,
+ },
+ )
+ BashOperator(task_id='dummy1', dag=dag, owner='airflow', bash_command='echo hi')
+ BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo hi')
+
+ with create_session() as session:
+ orm_dag = DagModel(dag_id=dag.dag_id)
+ session.merge(orm_dag)
+
+ dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+ self.scheduler_job = SchedulerJob(subdir=os.devnull)
+ self.scheduler_job.processor_agent = mock.MagicMock()
+ self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
+ dag.clear()
+ dr = dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ )
+ assert dr is not None
+
+ with create_session() as session:
+ tis = dr.get_task_instances(session=session)
+ for ti in tis:
+ ti.state = state
+ ti.start_date = start_date
+ ti.end_date = end_date
+
+ count = self.scheduler_job._schedule_dag_run(dr, session)
+ assert count == 2
+
+ session.refresh(tis[0])
+ session.refresh(tis[1])
+ assert tis[0].state == State.SCHEDULED
+ assert tis[1].state == State.SCHEDULED
+
+ def test_scheduler_job_add_new_task(self):
+ """
+ Test if a task instance will be added if the dag is updated
+ """
+ dag = DAG(dag_id='test_scheduler_add_new_task', start_date=DEFAULT_DATE)
+ BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo test')
+
+ self.scheduler_job = SchedulerJob(subdir=os.devnull)
+ self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
+
+ # Since we don't want to store the code for the DAG defined in this file
+ with mock.patch.object(settings, "STORE_DAG_CODE", False):
+ self.scheduler_job.dagbag.sync_to_db()
+
+ session = settings.Session()
+ orm_dag = session.query(DagModel).get(dag.dag_id)
+ assert orm_dag is not None
+
+ if self.scheduler_job.processor_agent:
+ self.scheduler_job.processor_agent.end()
+ self.scheduler_job = SchedulerJob(subdir=os.devnull)
+ self.scheduler_job.processor_agent = mock.MagicMock()
+ dag = self.scheduler_job.dagbag.get_dag('test_scheduler_add_new_task', session=session)
+ self.scheduler_job._create_dag_runs([orm_dag], session)
+
+ drs = DagRun.find(dag_id=dag.dag_id, session=session)
+ assert len(drs) == 1
+ dr = drs[0]
+
+ tis = dr.get_task_instances()
+ assert len(tis) == 1
+
+ BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo test')
+ SerializedDagModel.write_dag(dag=dag)
+
+ scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session)
+ session.flush()
+ assert scheduled_tis == 2
+
+ drs = DagRun.find(dag_id=dag.dag_id, session=session)
+ assert len(drs) == 1
+ dr = drs[0]
+
+ tis = dr.get_task_instances()
+ assert len(tis) == 2
+
+ def test_runs_respected_after_clear(self):
+ """
+ Test dag after dag.clear, max_active_runs is respected
+ """
+ dag = DAG(dag_id='test_scheduler_max_active_runs_respected_after_clear', start_date=DEFAULT_DATE)
+ dag.max_active_runs = 1
+
+ BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo Hi')
+
+ session = settings.Session()
+ orm_dag = DagModel(dag_id=dag.dag_id)
+ session.merge(orm_dag)
+ session.commit()
+ session.close()
+ dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+ # Write Dag to DB
+ dagbag = DagBag(dag_folder="/dev/null", include_examples=False, read_dags_from_db=False)
+ dagbag.bag_dag(dag, root_dag=dag)
+ dagbag.sync_to_db()
+
+ dag = DagBag(read_dags_from_db=True, include_examples=False).get_dag(dag.dag_id)
+
+ self.scheduler_job = SchedulerJob(subdir=os.devnull)
+ self.scheduler_job.processor_agent = mock.MagicMock()
+ self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
+
+ date = DEFAULT_DATE
+ dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ execution_date=date,
+ state=State.QUEUED,
+ )
+ date = dag.following_schedule(date)
+ dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ execution_date=date,
+ state=State.QUEUED,
+ )
+ date = dag.following_schedule(date)
+ dag.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ execution_date=date,
+ state=State.QUEUED,
+ )
+ dag.clear()
+
+ assert len(DagRun.find(dag_id=dag.dag_id, state=State.QUEUED, session=session)) == 3
+
+ session = settings.Session()
+ self.scheduler_job._start_queued_dagruns(session)
+ session.commit()
+ # Assert that only 1 dagrun is active
+ assert len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING, session=session)) == 1
+ # Assert that the other two are queued
+ assert len(DagRun.find(dag_id=dag.dag_id, state=State.QUEUED, session=session)) == 2
+
+ @patch.object(TaskInstance, 'handle_failure_with_callback')
+ def test_execute_on_failure_callbacks(self, mock_ti_handle_failure):
+ dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False)
+ dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ with create_session() as session:
+ session.query(TaskInstance).delete()
+ dag = dagbag.get_dag('example_branch_operator')
+ task = dag.get_task(task_id='run_this_first')
+
+ ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING)
+
+ session.add(ti)
+ session.commit()
+
+ requests = [
+ TaskCallbackRequest(
+ full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message"
+ )
+ ]
+ dag_file_processor.execute_callbacks(dagbag, requests)
+ mock_ti_handle_failure.assert_called_once_with(
+ error="Message",
+ test_mode=conf.getboolean('core', 'unit_test_mode'),
+ )
+
+ def test_process_file_should_failure_callback(self):
+ dag_file = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)), '../dags/test_on_failure_callback.py'
+ )
+ dagbag = DagBag(dag_folder=dag_file, include_examples=False)
+ dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+ with create_session() as session, NamedTemporaryFile(delete=False) as callback_file:
+ session.query(TaskInstance).delete()
+ dag = dagbag.get_dag('test_om_failure_callback_dag')
+ task = dag.get_task(task_id='test_om_failure_callback_task')
+
+ ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING)
+
+ session.add(ti)
+ session.commit()
+
+ requests = [
+ TaskCallbackRequest(
+ full_filepath=dag.full_filepath,
+ simple_task_instance=SimpleTaskInstance(ti),
+ msg="Message",
+ )
+ ]
+ callback_file.close()
+
+ with mock.patch.dict("os.environ", {"AIRFLOW_CALLBACK_FILE": callback_file.name}):
+ dag_file_processor.process_file(dag_file, requests)
+ with open(callback_file.name) as callback_file2:
+ content = callback_file2.read()
+ assert "Callback fired" == content
+ os.remove(callback_file.name)
+
+ def test_should_mark_dummy_task_as_success(self):
+ dag_file = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)), '../dags/test_only_dummy_tasks.py'
+ )
+
+ # Write DAGs to dag and serialized_dag table
+ dagbag = DagBag(dag_folder=dag_file, include_examples=False, read_dags_from_db=False)
+ dagbag.sync_to_db()
+
+ self.scheduler_job_job = SchedulerJob(subdir=os.devnull)
+ self.scheduler_job_job.processor_agent = mock.MagicMock()
+ dag = self.scheduler_job_job.dagbag.get_dag("test_only_dummy_tasks")
+
+ # Create DagRun
+ session = settings.Session()
+ orm_dag = session.query(DagModel).get(dag.dag_id)
+ self.scheduler_job_job._create_dag_runs([orm_dag], session)
+
+ drs = DagRun.find(dag_id=dag.dag_id, session=session)
+ assert len(drs) == 1
+ dr = drs[0]
+
+ # Schedule TaskInstances
+ self.scheduler_job_job._schedule_dag_run(dr, session)
+ with create_session() as session:
+ tis = session.query(TaskInstance).all()
+
+ dags = self.scheduler_job_job.dagbag.dags.values()
+ assert ['test_only_dummy_tasks'] == [dag.dag_id for dag in dags]
+ assert 5 == len(tis)
+ assert {
+ ('test_task_a', 'success'),
+ ('test_task_b', None),
+ ('test_task_c', 'success'),
+ ('test_task_on_execute', 'scheduled'),
+ ('test_task_on_success', 'scheduled'),
+ } == {(ti.task_id, ti.state) for ti in tis}
+ for state, start_date, end_date, duration in [
+ (ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis
+ ]:
+ if state == 'success':
+ assert start_date is not None
+ assert end_date is not None
+ assert 0.0 == duration
+ else:
+ assert start_date is None
+ assert end_date is None
+ assert duration is None
+
+ self.scheduler_job_job._schedule_dag_run(dr, session)
+ with create_session() as session:
+ tis = session.query(TaskInstance).all()
+
+ assert 5 == len(tis)
+ assert {
+ ('test_task_a', 'success'),
+ ('test_task_b', 'success'),
+ ('test_task_c', 'success'),
+ ('test_task_on_execute', 'scheduled'),
+ ('test_task_on_success', 'scheduled'),
+ } == {(ti.task_id, ti.state) for ti in tis}
+ for state, start_date, end_date, duration in [
+ (ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis
+ ]:
+ if state == 'success':
+ assert start_date is not None
+ assert end_date is not None
+ assert 0.0 == duration
+ else:
+ assert start_date is None
+ assert end_date is None
+ assert duration is None
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index fe0b257..33c82d7 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -23,7 +23,6 @@ import shutil
import unittest
from datetime import timedelta
from tempfile import NamedTemporaryFile, mkdtemp
-from time import sleep
from unittest import mock
from unittest.mock import MagicMock, patch
from zipfile import ZipFile
@@ -1114,7 +1113,6 @@ class TestSchedulerJob(unittest.TestCase):
session.flush()
res = self.scheduler_job._executable_task_instances_to_queued(max_tis=32, session=session)
-
assert 2 == len(res)
res_keys = map(lambda x: x.key, res)
assert ti_no_dagrun.key in res_keys
@@ -2259,15 +2257,16 @@ class TestSchedulerJob(unittest.TestCase):
dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
- self.scheduler_job = SchedulerJob(subdir=os.devnull)
self.scheduler_job._create_dag_runs([orm_dag], session)
+ self.scheduler_job._start_queued_dagruns(session)
drs = DagRun.find(dag_id=dag.dag_id, session=session)
assert len(drs) == 1
dr = drs[0]
- # Should not be able to create a new dag run, as we are at max active runs
- assert orm_dag.next_dagrun_create_after is None
+ # This should have a value since we control max_active_runs
+ # by DagRun State.
+ assert orm_dag.next_dagrun_create_after
# But we should record the date of _what run_ it would be
assert isinstance(orm_dag.next_dagrun, datetime.datetime)
@@ -2279,7 +2278,7 @@ class TestSchedulerJob(unittest.TestCase):
self.scheduler_job.processor_agent = mock.Mock()
self.scheduler_job.processor_agent.send_callback_to_execute = mock.Mock()
- self.scheduler_job._schedule_dag_run(dr, {}, session)
+ self.scheduler_job._schedule_dag_run(dr, session)
session.flush()
session.refresh(dr)
@@ -2336,7 +2335,7 @@ class TestSchedulerJob(unittest.TestCase):
self.scheduler_job.processor_agent = mock.Mock()
self.scheduler_job.processor_agent.send_callback_to_execute = mock.Mock()
- self.scheduler_job._schedule_dag_run(dr, {}, session)
+ self.scheduler_job._schedule_dag_run(dr, session)
session.flush()
session.refresh(dr)
@@ -2395,7 +2394,7 @@ class TestSchedulerJob(unittest.TestCase):
ti = dr.get_task_instance('dummy')
ti.set_state(state, session)
- self.scheduler_job._schedule_dag_run(dr, {}, session)
+ self.scheduler_job._schedule_dag_run(dr, session)
expected_callback = DagCallbackRequest(
full_filepath=dr.dag.fileloc,
@@ -2450,7 +2449,7 @@ class TestSchedulerJob(unittest.TestCase):
ti = dr.get_task_instance('test_task')
ti.set_state(state, session)
- self.scheduler_job._schedule_dag_run(dr, set(), session)
+ self.scheduler_job._schedule_dag_run(dr, session)
# Verify Callback is not set (i.e is None) when no callbacks are set on DAG
self.scheduler_job._send_dag_callbacks_to_processor.assert_called_once_with(dr, None)
@@ -2830,13 +2829,13 @@ class TestSchedulerJob(unittest.TestCase):
execution_date=DEFAULT_DATE,
state=State.RUNNING,
)
- self.scheduler_job._schedule_dag_run(dr, {}, session)
+ self.scheduler_job._schedule_dag_run(dr, session)
dr = dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
execution_date=dag.following_schedule(dr.execution_date),
state=State.RUNNING,
)
- self.scheduler_job._schedule_dag_run(dr, {}, session)
+ self.scheduler_job._schedule_dag_run(dr, session)
task_instances_list = self.scheduler_job._executable_task_instances_to_queued(
max_tis=32, session=session
)
@@ -2887,7 +2886,7 @@ class TestSchedulerJob(unittest.TestCase):
execution_date=date,
state=State.RUNNING,
)
- self.scheduler_job._schedule_dag_run(dr, {}, session)
+ self.scheduler_job._schedule_dag_run(dr, session)
date = dag.following_schedule(date)
task_instances_list = self.scheduler_job._executable_task_instances_to_queued(
@@ -2950,7 +2949,7 @@ class TestSchedulerJob(unittest.TestCase):
execution_date=date,
state=State.RUNNING,
)
- scheduler._schedule_dag_run(dr, {}, session)
+ scheduler._schedule_dag_run(dr, session)
date = dag_d1.following_schedule(date)
date = DEFAULT_DATE
@@ -2960,7 +2959,7 @@ class TestSchedulerJob(unittest.TestCase):
execution_date=date,
state=State.RUNNING,
)
- scheduler._schedule_dag_run(dr, {}, session)
+ scheduler._schedule_dag_run(dr, session)
date = dag_d2.following_schedule(date)
scheduler._executable_task_instances_to_queued(max_tis=2, session=session)
@@ -3037,7 +3036,7 @@ class TestSchedulerJob(unittest.TestCase):
execution_date=DEFAULT_DATE,
state=State.RUNNING,
)
- self.scheduler_job._schedule_dag_run(dr, {}, session)
+ self.scheduler_job._schedule_dag_run(dr, session)
task_instances_list = self.scheduler_job._executable_task_instances_to_queued(
max_tis=32, session=session
@@ -3096,7 +3095,7 @@ class TestSchedulerJob(unittest.TestCase):
# Verify that DagRun.verify_integrity is not called
with mock.patch('airflow.jobs.scheduler_job.DagRun.verify_integrity') as mock_verify_integrity:
- scheduled_tis = self.scheduler_job._schedule_dag_run(dr, {}, session)
+ scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session)
mock_verify_integrity.assert_not_called()
session.flush()
@@ -3159,7 +3158,7 @@ class TestSchedulerJob(unittest.TestCase):
dag_version_2 = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session)
assert dag_version_2 != dag_version_1
- scheduled_tis = self.scheduler_job._schedule_dag_run(dr, {}, session)
+ scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session)
session.flush()
assert scheduled_tis == 2
@@ -3871,14 +3870,13 @@ class TestSchedulerJob(unittest.TestCase):
full_filepath=dag.fileloc, dag_id=dag_id
)
- @freeze_time(DEFAULT_DATE + datetime.timedelta(days=1, seconds=9))
- @mock.patch('airflow.jobs.scheduler_job.Stats.timing')
- def test_create_dag_runs(self, stats_timing):
+ def test_create_dag_runs(self):
"""
Test various invariants of _create_dag_runs.
- That the run created has the creating_job_id set
- - That we emit the right DagRun metrics
+ - That the run created is on QUEUED State
+ - That dag_model has next_dagrun
"""
dag = DAG(dag_id='test_create_dag_runs', start_date=DEFAULT_DATE)
@@ -3902,8 +3900,51 @@ class TestSchedulerJob(unittest.TestCase):
with create_session() as session:
self.scheduler_job._create_dag_runs([dag_model], session)
+ dr = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).first()
+ # Assert dr state is queued
+ assert dr.state == State.QUEUED
+ assert dr.start_date is None
+
+ assert dag.get_last_dagrun().creating_job_id == self.scheduler_job.id
+
+ @freeze_time(DEFAULT_DATE + datetime.timedelta(days=1, seconds=9))
+ @mock.patch('airflow.jobs.scheduler_job.Stats.timing')
+ def test_start_dagruns(self, stats_timing):
+ """
+ Test that _start_dagrun:
+
+ - moves runs to RUNNING State
+ - emit the right DagRun metrics
+ """
+ dag = DAG(dag_id='test_start_dag_runs', start_date=DEFAULT_DATE)
+
+ DummyOperator(
+ task_id='dummy',
+ dag=dag,
+ )
+
+ dagbag = DagBag(
+ dag_folder=os.devnull,
+ include_examples=False,
+ read_dags_from_db=True,
+ )
+ dagbag.bag_dag(dag=dag, root_dag=dag)
+ dagbag.sync_to_db()
+ dag_model = DagModel.get_dagmodel(dag.dag_id)
+
+ self.scheduler_job = SchedulerJob(executor=self.null_exec)
+ self.scheduler_job.processor_agent = mock.MagicMock()
+
+ with create_session() as session:
+ self.scheduler_job._create_dag_runs([dag_model], session)
+ self.scheduler_job._start_queued_dagruns(session)
+
+ dr = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).first()
+ # Assert dr state is running
+ assert dr.state == State.RUNNING
+
stats_timing.assert_called_once_with(
- "dagrun.schedule_delay.test_create_dag_runs", datetime.timedelta(seconds=9)
+ "dagrun.schedule_delay.test_start_dag_runs", datetime.timedelta(seconds=9)
)
assert dag.get_last_dagrun().creating_job_id == self.scheduler_job.id
@@ -4102,61 +4143,7 @@ class TestSchedulerJob(unittest.TestCase):
assert dag_model.next_dagrun == DEFAULT_DATE + timedelta(days=1)
session.rollback()
- def test_do_schedule_max_active_runs_upstream_failed(self):
- """
- Test that tasks in upstream failed don't count as actively running.
-
- This test can be removed when adding a queued state to DagRuns.
- """
-
- with DAG(
- dag_id='test_max_active_run_with_upstream_failed',
- start_date=DEFAULT_DATE,
- schedule_interval='@once',
- max_active_runs=1,
- ) as dag:
- # Can't use DummyOperator as that goes straight to success
- task1 = BashOperator(task_id='dummy1', bash_command='true')
-
- session = settings.Session()
- dagbag = DagBag(
- dag_folder=os.devnull,
- include_examples=False,
- read_dags_from_db=True,
- )
-
- dagbag.bag_dag(dag=dag, root_dag=dag)
- dagbag.sync_to_db(session=session)
-
- run1 = dag.create_dagrun(
- run_type=DagRunType.SCHEDULED,
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- session=session,
- )
-
- ti = run1.get_task_instance(task1.task_id, session)
- ti.state = State.UPSTREAM_FAILED
-
- run2 = dag.create_dagrun(
- run_type=DagRunType.SCHEDULED,
- execution_date=DEFAULT_DATE + timedelta(hours=1),
- state=State.RUNNING,
- session=session,
- )
-
- dag.sync_to_db(session=session) # Update the date fields
-
- self.scheduler_job = SchedulerJob(subdir=os.devnull)
- self.scheduler_job.executor = MockExecutor(do_update=False)
- self.scheduler_job.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent)
-
- num_queued = self.scheduler_job._do_scheduling(session)
-
- assert num_queued == 1
- ti = run2.get_task_instance(task1.task_id, session)
- assert ti.state == State.QUEUED
-
+ @conf_vars({('scheduler', 'use_job_schedule'): "false"})
def test_do_schedule_max_active_runs_dag_timed_out(self):
"""Test that tasks are set to a finished state when their DAG times out"""
@@ -4189,33 +4176,36 @@ class TestSchedulerJob(unittest.TestCase):
run_type=DagRunType.SCHEDULED,
execution_date=DEFAULT_DATE,
state=State.RUNNING,
+ start_date=timezone.utcnow() - timedelta(seconds=2),
session=session,
)
+
run1_ti = run1.get_task_instance(task1.task_id, session)
run1_ti.state = State.RUNNING
- sleep(1)
-
run2 = dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
execution_date=DEFAULT_DATE + timedelta(seconds=10),
- state=State.RUNNING,
+ state=State.QUEUED,
session=session,
)
dag.sync_to_db(session=session)
-
self.scheduler_job = SchedulerJob(subdir=os.devnull)
self.scheduler_job.executor = MockExecutor()
self.scheduler_job.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent)
- _ = self.scheduler_job._do_scheduling(session)
-
+ self.scheduler_job._do_scheduling(session)
+ session.add(run1)
+ session.refresh(run1)
assert run1.state == State.FAILED
assert run1_ti.state == State.SKIPPED
- assert run2.state == State.RUNNING
- _ = self.scheduler_job._do_scheduling(session)
+ # Run scheduling again to verify that run2 has been started
+ self.scheduler_job._do_scheduling(session)
+ session.add(run2)
+ session.refresh(run2)
+ assert run2.state == State.RUNNING
run2_ti = run2.get_task_instance(task1.task_id, session)
assert run2_ti.state == State.QUEUED
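This test leans on the DAG's dagrun_timeout (set in context elided from this hunk): a RUNNING run whose start_date is older than the timeout is failed and its tasks skipped, freeing the max_active_runs slot so the QUEUED run can be started on the next scheduling pass. A hedged sketch of such a DAG, with an illustrative dag_id and timeout value:

    from datetime import timedelta
    from airflow import DAG
    from airflow.operators.bash import BashOperator

    with DAG(
        dag_id='example_timed_out_dag',       # illustrative name
        start_date=DEFAULT_DATE,              # DEFAULT_DATE as used throughout these tests
        schedule_interval='@once',
        max_active_runs=1,
        dagrun_timeout=timedelta(seconds=1),  # running runs older than this get failed
    ) as dag:
        BashOperator(task_id='work', bash_command='true')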
@@ -4265,8 +4255,8 @@ class TestSchedulerJob(unittest.TestCase):
def test_do_schedule_max_active_runs_and_manual_trigger(self):
"""
- Make sure that when a DAG is already at max_active_runs, that manually triggering a run doesn't cause
- the dag to "stall".
+ Make sure that when a DAG is already at max_active_runs, manually triggered
+ dagruns don't start running.
"""
with DAG(
@@ -4281,7 +4271,7 @@ class TestSchedulerJob(unittest.TestCase):
task1 >> task2
- task3 = BashOperator(task_id='dummy3', bash_command='true')
+ BashOperator(task_id='dummy3', bash_command='true')
session = settings.Session()
dagbag = DagBag(
@@ -4296,7 +4286,7 @@ class TestSchedulerJob(unittest.TestCase):
dag_run = dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
execution_date=DEFAULT_DATE,
- state=State.RUNNING,
+ state=State.QUEUED,
session=session,
)
@@ -4314,47 +4304,23 @@ class TestSchedulerJob(unittest.TestCase):
assert num_queued == 2
assert dag_run.state == State.RUNNING
- ti1 = dag_run.get_task_instance(task1.task_id, session)
- assert ti1.state == State.QUEUED
-
- # Set task1 to success (so task2 can run) but keep task3 as "running"
- ti1.state = State.SUCCESS
-
- ti3 = dag_run.get_task_instance(task3.task_id, session)
- ti3.state = State.RUNNING
-
- session.flush()
-
- # At this point, ti2 and ti3 of the scheduled dag run should be running
- num_queued = self.scheduler_job._do_scheduling(session)
-
- assert num_queued == 1
- # Should have queued task2
- ti2 = dag_run.get_task_instance(task2.task_id, session)
- assert ti2.state == State.QUEUED
-
- ti2.state = None
- session.flush()
# Now that this one is running, manually trigger a dag.
- manual_run = dag.create_dagrun(
+ dag.create_dagrun(
run_type=DagRunType.MANUAL,
execution_date=DEFAULT_DATE + timedelta(hours=1),
- state=State.RUNNING,
+ state=State.QUEUED,
session=session,
)
session.flush()
- num_queued = self.scheduler_job._do_scheduling(session)
+ self.scheduler_job._do_scheduling(session)
- assert num_queued == 1
- # Should have queued task2 again.
- ti2 = dag_run.get_task_instance(task2.task_id, session)
- assert ti2.state == State.QUEUED
- # Manual run shouldn't have been started, because we're at max_active_runs with DR1
- ti1 = manual_run.get_task_instance(task1.task_id, session)
- assert ti1.state is None
+ # Assert that only 1 dagrun is active
+ assert len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING, session=session)) == 1
+ # Assert that the other one is queued
+ assert len(DagRun.find(dag_id=dag.dag_id, state=State.QUEUED, session=session)) == 1
@pytest.mark.xfail(reason="Work out where this goes")
diff --git a/tests/models/test_cleartasks.py b/tests/models/test_cleartasks.py
index 9b8fbd0..4f64347 100644
--- a/tests/models/test_cleartasks.py
+++ b/tests/models/test_cleartasks.py
@@ -19,6 +19,8 @@
import datetime
import unittest
+from parameterized import parameterized
+
from airflow import settings
from airflow.models import DAG, TaskInstance as TI, TaskReschedule, clear_task_instances
from airflow.operators.dummy import DummyOperator
@@ -92,6 +94,41 @@ class TestClearTasks(unittest.TestCase):
assert ti0.state is None
assert ti0.external_executor_id is None
+ @parameterized.expand([(State.QUEUED, None), (State.RUNNING, DEFAULT_DATE)])
+ def test_clear_task_instances_dr_state(self, state, last_scheduling):
+ """Test that DR state is set to None after clear.
+ And that DR.last_scheduling_decision is handled OK.
+ start_date is also set to None
+ """
+ dag = DAG(
+ 'test_clear_task_instances',
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+ )
+ task0 = DummyOperator(task_id='0', owner='test', dag=dag)
+ task1 = DummyOperator(task_id='1', owner='test', dag=dag, retries=2)
+ ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
+ ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
+ session = settings.Session()
+ dr = dag.create_dagrun(
+ execution_date=ti0.execution_date,
+ state=State.RUNNING,
+ run_type=DagRunType.SCHEDULED,
+ )
+ dr.last_scheduling_decision = DEFAULT_DATE
+ session.add(dr)
+ session.commit()
+
+ ti0.run()
+ ti1.run()
+ qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
+ clear_task_instances(qry, session, dag_run_state=state, dag=dag)
+
+ dr = ti0.get_dagrun()
+ assert dr.state == state
+ assert dr.start_date is None
+ assert dr.last_scheduling_decision == last_scheduling
+
def test_clear_task_instances_without_task(self):
dag = DAG(
'test_clear_task_instances_without_task',
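The new test reflects that clear_task_instances now takes an explicit dag_run_state argument instead of unconditionally putting cleared runs back into RUNNING. A minimal usage sketch, assuming session and dag are set up as in the test above:

    from airflow.models import TaskInstance as TI, clear_task_instances
    from airflow.utils.state import State

    tis = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
    # Cleared runs go back to QUEUED, so the scheduler re-admits them
    # under max_active_runs rather than resuming them as RUNNING.
    clear_task_instances(tis, session, dag_run_state=State.QUEUED, dag=dag)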
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 7899199..fac38bf 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -39,7 +39,7 @@ from airflow.utils.state import State
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import DagRunType
from tests.models import DEFAULT_DATE
-from tests.test_utils.db import clear_db_jobs, clear_db_pools, clear_db_runs
+from tests.test_utils.db import clear_db_dags, clear_db_jobs, clear_db_pools, clear_db_runs
class TestDagRun(unittest.TestCase):
@@ -50,6 +50,12 @@ class TestDagRun(unittest.TestCase):
def setUp(self):
clear_db_runs()
clear_db_pools()
+ clear_db_dags()
+
+ def tearDown(self) -> None:
+ clear_db_runs()
+ clear_db_pools()
+ clear_db_dags()
def create_dag_run(
self,
@@ -102,7 +108,7 @@ class TestDagRun(unittest.TestCase):
session.commit()
ti0.refresh_from_db()
dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.execution_date == now).first()
- assert dr0.state == State.RUNNING
+ assert dr0.state == State.QUEUED
def test_dagrun_find(self):
session = settings.Session()
@@ -692,9 +698,11 @@ class TestDagRun(unittest.TestCase):
ti.run()
assert (ti.state == State.SUCCESS) == is_ti_success
- def test_next_dagruns_to_examine_only_unpaused(self):
+ @parameterized.expand([(State.QUEUED,), (State.RUNNING,)])
+ def test_next_dagruns_to_examine_only_unpaused(self, state):
"""
Check that "next_dagruns_to_examine" ignores runs from paused/inactive DAGs
+ and only returns dagruns in the requested state (queued or running)
"""
dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE)
@@ -712,25 +720,22 @@ class TestDagRun(unittest.TestCase):
session.flush()
dr = dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
- state=State.RUNNING,
+ state=state,
execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE if state == State.RUNNING else None,
session=session,
)
- runs = DagRun.next_dagruns_to_examine(session).all()
+ runs = DagRun.next_dagruns_to_examine(state, session).all()
assert runs == [dr]
orm_dag.is_paused = True
session.flush()
- runs = DagRun.next_dagruns_to_examine(session).all()
+ runs = DagRun.next_dagruns_to_examine(state, session).all()
assert runs == []
- session.rollback()
- session.close()
-
@mock.patch.object(Stats, 'timing')
def test_no_scheduling_delay_for_nonscheduled_runs(self, stats_mock):
"""
diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py
index 187fdb2..274766c 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -545,16 +545,16 @@ def test_external_task_marker_clear_activate(dag_bag_parent_child):
task_0 = dag_0.get_task("task_0")
clear_tasks(dag_bag, dag_0, task_0, start_date=day_1, end_date=day_2)
- # Assert that dagruns of all the affected dags are set to RUNNING after tasks are cleared.
+ # Assert that dagruns of all the affected dags are set to QUEUED after tasks are cleared.
# Unaffected dagruns should be left as SUCCESS.
dagrun_0_1 = dag_bag.get_dag('parent_dag_0').get_dagrun(execution_date=day_1)
dagrun_0_2 = dag_bag.get_dag('parent_dag_0').get_dagrun(execution_date=day_2)
dagrun_1_1 = dag_bag.get_dag('child_dag_1').get_dagrun(execution_date=day_1)
dagrun_1_2 = dag_bag.get_dag('child_dag_1').get_dagrun(execution_date=day_2)
- assert dagrun_0_1.state == State.RUNNING
- assert dagrun_0_2.state == State.RUNNING
- assert dagrun_1_1.state == State.RUNNING
+ assert dagrun_0_1.state == State.QUEUED
+ assert dagrun_0_2.state == State.QUEUED
+ assert dagrun_1_1.state == State.QUEUED
assert dagrun_1_2.state == State.SUCCESS
diff --git a/tests/utils/test_dag_processing.py b/tests/utils/test_dag_processing.py
index 58ad010..e38c184 100644
--- a/tests/utils/test_dag_processing.py
+++ b/tests/utils/test_dag_processing.py
@@ -35,7 +35,7 @@ from freezegun import freeze_time
from airflow.configuration import conf
from airflow.jobs.local_task_job import LocalTaskJob as LJ
-from airflow.jobs.scheduler_job import DagFileProcessorProcess
+from airflow.jobs.scheduler_job import DagFileProcessorProcess, SchedulerJob
from airflow.models import DagBag, DagModel, TaskInstance as TI
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import SimpleTaskInstance
@@ -508,8 +508,8 @@ class TestDagFileProcessorManager(unittest.TestCase):
child_pipe.close()
parent_pipe.close()
- @mock.patch("airflow.jobs.scheduler_job.DagFileProcessorProcess.pid", new_callable=PropertyMock)
- @mock.patch("airflow.jobs.scheduler_job.DagFileProcessorProcess.kill")
+ @mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess.pid", new_callable=PropertyMock)
+ @mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess.kill")
def test_kill_timed_out_processors_kill(self, mock_kill, mock_pid):
mock_pid.return_value = 1234
manager = DagFileProcessorManager(
@@ -529,8 +529,8 @@ class TestDagFileProcessorManager(unittest.TestCase):
manager._kill_timed_out_processors()
mock_kill.assert_called_once_with()
- @mock.patch("airflow.jobs.scheduler_job.DagFileProcessorProcess.pid", new_callable=PropertyMock)
- @mock.patch("airflow.jobs.scheduler_job.DagFileProcessorProcess")
+ @mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess.pid", new_callable=PropertyMock)
+ @mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess")
def test_kill_timed_out_processors_no_kill(self, mock_dag_file_processor, mock_pid):
mock_pid.return_value = 1234
manager = DagFileProcessorManager(
@@ -560,7 +560,6 @@ class TestDagFileProcessorManager(unittest.TestCase):
# We need to _actually_ parse the files here to test the behaviour.
# Right now the parsing code lives in SchedulerJob, even though it's
# called via utils.dag_processing.
- from airflow.jobs.scheduler_job import SchedulerJob
dag_id = 'exit_test_dag'
dag_directory = TEST_DAG_FOLDER.parent / 'dags_with_system_exit'
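The updated targets patch DagFileProcessorProcess in airflow.dag_processing.processor, presumably the module where the manager under test now looks the class up; mock.patch has to target the lookup site for the mock to take effect. A minimal sketch of the pattern, with manager constructed as in the tests above:

    from unittest import mock

    with mock.patch(
        "airflow.dag_processing.processor.DagFileProcessorProcess.kill"
    ) as mock_kill:
        manager._kill_timed_out_processors()
        mock_kill.assert_called_once_with()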