You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2021/08/16 15:23:41 UTC
[airflow] 02/04: Improve `dag_maker` fixture (#17324)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 06d0918e5f17e0da047771ab24936cc89feca33f
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Mon Aug 2 07:37:40 2021 +0100
Improve `dag_maker` fixture (#17324)
This PR improves the dag_maker fixture to enable creation of the dagrun, dag, and dag_model separately.
Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
(cherry picked from commit 5c1e09cafacea922b9281e901db7da7cadb3e9be)
---
tests/conftest.py | 53 +++++-----
tests/jobs/test_backfill_job.py | 205 ++++++++++++++++++++------------------
tests/jobs/test_local_task_job.py | 104 ++++++++-----------
3 files changed, 176 insertions(+), 186 deletions(-)
diff --git a/tests/conftest.py b/tests/conftest.py
index f2c5345..0873ac4 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -429,8 +429,9 @@ def app():
@pytest.fixture
def dag_maker(request):
- from airflow.models import DAG
+ from airflow.models import DAG, DagModel
from airflow.utils import timezone
+ from airflow.utils.session import provide_session
from airflow.utils.state import State
DEFAULT_DATE = timezone.datetime(2016, 1, 1)
@@ -445,33 +446,39 @@ def dag_maker(request):
dag.__exit__(type, value, traceback)
if type is None:
dag.clear()
- self.dag_run = dag.create_dagrun(
- run_id=self.kwargs.get("run_id", "test"),
- state=self.kwargs.get('state', State.RUNNING),
- execution_date=self.kwargs.get('execution_date', self.kwargs['start_date']),
- start_date=self.kwargs['start_date'],
- )
+
+ @provide_session
+ def make_dagmodel(self, session=None, **kwargs):
+ dag = self.dag
+ defaults = dict(dag_id=dag.dag_id, next_dagrun=dag.start_date, is_active=True)
+ kwargs = {**defaults, **kwargs}
+ dag_model = DagModel(**kwargs)
+ session.add(dag_model)
+ session.flush()
+ return dag_model
+
+ def create_dagrun(self, **kwargs):
+ dag = self.dag
+ defaults = dict(
+ run_id='test',
+ state=State.RUNNING,
+ execution_date=self.start_date,
+ start_date=self.start_date,
+ )
+ kwargs = {**defaults, **kwargs}
+ self.dag_run = dag.create_dagrun(**kwargs)
+ return self.dag_run
def __call__(self, dag_id='test_dag', **kwargs):
self.kwargs = kwargs
- if "start_date" not in kwargs:
+ self.start_date = self.kwargs.get('start_date', None)
+ if not self.start_date:
if hasattr(request.module, 'DEFAULT_DATE'):
- kwargs['start_date'] = getattr(request.module, 'DEFAULT_DATE')
+ self.start_date = getattr(request.module, 'DEFAULT_DATE')
else:
- kwargs['start_date'] = DEFAULT_DATE
- dagrun_fields_not_in_dag = [
- 'state',
- 'execution_date',
- 'run_type',
- 'queued_at',
- "run_id",
- "creating_job_id",
- "external_trigger",
- "last_scheduling_decision",
- "dag_hash",
- ]
- kwargs = {k: v for k, v in kwargs.items() if k not in dagrun_fields_not_in_dag}
- self.dag = DAG(dag_id, **kwargs)
+ self.start_date = DEFAULT_DATE
+ self.kwargs['start_date'] = self.start_date
+ self.dag = DAG(dag_id, **self.kwargs)
return self
return DagFactory()
diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py
index 9302911..ca69120 100644
--- a/tests/jobs/test_backfill_job.py
+++ b/tests/jobs/test_backfill_job.py
@@ -46,7 +46,7 @@ from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.timeout import timeout
from airflow.utils.types import DagRunType
-from tests.test_utils.db import clear_db_pools, clear_db_runs, set_default_pool_slots
+from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs, set_default_pool_slots
from tests.test_utils.mock_executor import MockExecutor
logger = logging.getLogger(__name__)
@@ -59,44 +59,10 @@ def dag_bag():
return DagBag(include_examples=True)
-@pytest.fixture
-def get_dummy_dag_and_run(dag_maker):
- def _get_dummy_dag_and_run(
- dag_id='test_dag', pool=Pool.DEFAULT_POOL_NAME, task_concurrency=None, task_id='op', **kwargs
- ):
- with dag_maker(dag_id=dag_id, schedule_interval='@daily', **kwargs) as dag:
- DummyOperator(task_id=task_id, pool=pool, task_concurrency=task_concurrency)
-
- return dag, dag_maker.dag_run
-
- return _get_dummy_dag_and_run
-
-
-@pytest.fixture
-def get_dag_test_max_active_limits(dag_maker):
- def _get_dag_test_max_active_limits(dag_id='test_dag', max_active_runs=1, **kwargs):
- with dag_maker(
- dag_id=dag_id,
- start_date=DEFAULT_DATE,
- schedule_interval="@hourly",
- max_active_runs=max_active_runs,
- **kwargs,
- ) as dag:
- op1 = DummyOperator(task_id='leave1')
- op2 = DummyOperator(task_id='leave2')
- op3 = DummyOperator(task_id='upstream_level_1')
- op4 = DummyOperator(task_id='upstream_level_2')
-
- op1 >> op2 >> op3
- op4 >> op3
- return dag, dag_maker.dag_run
-
- return _get_dag_test_max_active_limits
-
-
class TestBackfillJob:
@staticmethod
def clean_db():
+ clear_db_dags()
clear_db_runs()
clear_db_pools()
@@ -106,6 +72,20 @@ class TestBackfillJob:
self.parser = cli_parser.get_parser()
self.dagbag = dag_bag
+ def _get_dummy_dag(
+ self,
+ dag_maker_fixture,
+ dag_id='test_dag',
+ pool=Pool.DEFAULT_POOL_NAME,
+ task_concurrency=None,
+ task_id='op',
+ **kwargs,
+ ):
+ with dag_maker_fixture(dag_id=dag_id, schedule_interval='@daily', **kwargs) as dag:
+ DummyOperator(task_id=task_id, pool=pool, task_concurrency=task_concurrency)
+
+ return dag
+
def _times_called_with(self, method, class_):
count = 0
for args in method.call_args_list:
@@ -113,8 +93,9 @@ class TestBackfillJob:
count += 1
return count
- def test_unfinished_dag_runs_set_to_failed(self, get_dummy_dag_and_run):
- dag, dag_run = get_dummy_dag_and_run(dag_id='dummy_dag')
+ def test_unfinished_dag_runs_set_to_failed(self, dag_maker):
+ dag = self._get_dummy_dag(dag_maker)
+ dag_run = dag_maker.create_dagrun()
job = BackfillJob(
dag=dag,
@@ -129,8 +110,9 @@ class TestBackfillJob:
assert State.FAILED == dag_run.state
- def test_dag_run_with_finished_tasks_set_to_success(self, get_dummy_dag_and_run):
- dag, dag_run = get_dummy_dag_and_run(dag_id='dummy_dag')
+ def test_dag_run_with_finished_tasks_set_to_success(self, dag_maker):
+ dag = self._get_dummy_dag(dag_maker)
+ dag_run = dag_maker.create_dagrun()
for ti in dag_run.get_task_instances():
ti.set_state(State.SUCCESS)
@@ -289,8 +271,9 @@ class TestBackfillJob:
for task_id in expected_execution_order
] == executor.sorted_tasks
- def test_backfill_conf(self, get_dummy_dag_and_run):
- dag, _ = get_dummy_dag_and_run(dag_id='test_backfill_conf')
+ def test_backfill_conf(self, dag_maker):
+ dag = self._get_dummy_dag(dag_maker, dag_id='test_backfill_conf')
+ dag_maker.create_dagrun()
executor = MockExecutor()
@@ -312,12 +295,14 @@ class TestBackfillJob:
assert conf_ == dr[0].conf
@patch('airflow.jobs.backfill_job.BackfillJob.log')
- def test_backfill_respect_task_concurrency_limit(self, mock_log, get_dummy_dag_and_run):
+ def test_backfill_respect_task_concurrency_limit(self, mock_log, dag_maker):
task_concurrency = 2
- dag, _ = get_dummy_dag_and_run(
+ dag = self._get_dummy_dag(
+ dag_maker,
dag_id='test_backfill_respect_task_concurrency_limit',
task_concurrency=task_concurrency,
)
+ dag_maker.create_dagrun()
executor = MockExecutor()
@@ -364,9 +349,9 @@ class TestBackfillJob:
assert times_task_concurrency_limit_reached_in_debug > 0
@patch('airflow.jobs.backfill_job.BackfillJob.log')
- def test_backfill_respect_dag_concurrency_limit(self, mock_log, get_dummy_dag_and_run):
-
- dag, _ = get_dummy_dag_and_run(dag_id='test_backfill_respect_concurrency_limit')
+ def test_backfill_respect_dag_concurrency_limit(self, mock_log, dag_maker):
+ dag = self._get_dummy_dag(dag_maker, dag_id='test_backfill_respect_concurrency_limit')
+ dag_maker.create_dagrun()
dag.concurrency = 2
executor = MockExecutor()
@@ -415,11 +400,12 @@ class TestBackfillJob:
assert times_dag_concurrency_limit_reached_in_debug > 0
@patch('airflow.jobs.backfill_job.BackfillJob.log')
- def test_backfill_respect_default_pool_limit(self, mock_log, get_dummy_dag_and_run):
+ def test_backfill_respect_default_pool_limit(self, mock_log, dag_maker):
default_pool_slots = 2
set_default_pool_slots(default_pool_slots)
- dag, _ = get_dummy_dag_and_run(dag_id='test_backfill_with_no_pool_limit')
+ dag = self._get_dummy_dag(dag_maker, dag_id='test_backfill_with_no_pool_limit')
+ dag_maker.create_dagrun()
executor = MockExecutor()
@@ -469,11 +455,13 @@ class TestBackfillJob:
assert 0 == times_task_concurrency_limit_reached_in_debug
assert times_pool_limit_reached_in_debug > 0
- def test_backfill_pool_not_found(self, get_dummy_dag_and_run):
- dag, _ = get_dummy_dag_and_run(
+ def test_backfill_pool_not_found(self, dag_maker):
+ dag = self._get_dummy_dag(
+ dag_maker,
dag_id='test_backfill_pool_not_found',
pool='king_pool',
)
+ dag_maker.create_dagrun()
executor = MockExecutor()
@@ -490,7 +478,7 @@ class TestBackfillJob:
return
@patch('airflow.jobs.backfill_job.BackfillJob.log')
- def test_backfill_respect_pool_limit(self, mock_log, get_dummy_dag_and_run):
+ def test_backfill_respect_pool_limit(self, mock_log, dag_maker):
session = settings.Session()
slots = 2
@@ -501,10 +489,12 @@ class TestBackfillJob:
session.add(pool)
session.commit()
- dag, _ = get_dummy_dag_and_run(
+ dag = self._get_dummy_dag(
+ dag_maker,
dag_id='test_backfill_respect_pool_limit',
pool=pool.pool,
)
+ dag_maker.create_dagrun()
executor = MockExecutor()
@@ -550,10 +540,11 @@ class TestBackfillJob:
assert 0 == times_dag_concurrency_limit_reached_in_debug
assert times_pool_limit_reached_in_debug > 0
- def test_backfill_run_rescheduled(self, get_dummy_dag_and_run):
- dag, _ = get_dummy_dag_and_run(
- dag_id="test_backfill_run_rescheduled", task_id="test_backfill_run_rescheduled_task-1"
+ def test_backfill_run_rescheduled(self, dag_maker):
+ dag = self._get_dummy_dag(
+ dag_maker, dag_id="test_backfill_run_rescheduled", task_id="test_backfill_run_rescheduled_task-1"
)
+ dag_maker.create_dagrun()
executor = MockExecutor()
@@ -581,10 +572,11 @@ class TestBackfillJob:
ti.refresh_from_db()
assert ti.state == State.SUCCESS
- def test_backfill_rerun_failed_tasks(self, get_dummy_dag_and_run):
- dag, _ = get_dummy_dag_and_run(
- dag_id="test_backfill_rerun_failed", task_id="test_backfill_rerun_failed_task-1"
+ def test_backfill_rerun_failed_tasks(self, dag_maker):
+ dag = self._get_dummy_dag(
+ dag_maker, dag_id="test_backfill_rerun_failed", task_id="test_backfill_rerun_failed_task-1"
)
+ dag_maker.create_dagrun()
executor = MockExecutor()
@@ -614,12 +606,11 @@ class TestBackfillJob:
def test_backfill_rerun_upstream_failed_tasks(self, dag_maker):
- with dag_maker(
- dag_id='test_backfill_rerun_upstream_failed', start_date=DEFAULT_DATE, schedule_interval='@daily'
- ) as dag:
+ with dag_maker(dag_id='test_backfill_rerun_upstream_failed', schedule_interval='@daily') as dag:
op1 = DummyOperator(task_id='test_backfill_rerun_upstream_failed_task-1')
op2 = DummyOperator(task_id='test_backfill_rerun_upstream_failed_task-2')
op1.set_upstream(op2)
+ dag_maker.create_dagrun()
executor = MockExecutor()
@@ -647,10 +638,11 @@ class TestBackfillJob:
ti.refresh_from_db()
assert ti.state == State.SUCCESS
- def test_backfill_rerun_failed_tasks_without_flag(self, get_dummy_dag_and_run):
- dag, _ = get_dummy_dag_and_run(
- dag_id='test_backfill_rerun_failed', task_id='test_backfill_rerun_failed_task-1'
+ def test_backfill_rerun_failed_tasks_without_flag(self, dag_maker):
+ dag = self._get_dummy_dag(
+ dag_maker, dag_id='test_backfill_rerun_failed', task_id='test_backfill_rerun_failed_task-1'
)
+ dag_maker.create_dagrun()
executor = MockExecutor()
@@ -680,7 +672,6 @@ class TestBackfillJob:
def test_backfill_retry_intermittent_failed_task(self, dag_maker):
with dag_maker(
dag_id='test_intermittent_failure_job',
- start_date=DEFAULT_DATE,
schedule_interval="@daily",
default_args={
'retries': 2,
@@ -688,6 +679,7 @@ class TestBackfillJob:
},
) as dag:
task1 = DummyOperator(task_id="task1")
+ dag_maker.create_dagrun()
executor = MockExecutor(parallelism=16)
executor.mock_task_results[
@@ -707,7 +699,6 @@ class TestBackfillJob:
def test_backfill_retry_always_failed_task(self, dag_maker):
with dag_maker(
dag_id='test_always_failure_job',
- start_date=DEFAULT_DATE,
schedule_interval="@daily",
default_args={
'retries': 1,
@@ -715,6 +706,7 @@ class TestBackfillJob:
},
) as dag:
task1 = DummyOperator(task_id="task1")
+ dag_maker.create_dagrun()
executor = MockExecutor(parallelism=16)
executor.mock_task_results[
@@ -734,7 +726,6 @@ class TestBackfillJob:
with dag_maker(
dag_id='test_backfill_ordered_concurrent_execute',
- start_date=DEFAULT_DATE,
schedule_interval="@daily",
) as dag:
op1 = DummyOperator(task_id='leave1')
@@ -747,6 +738,7 @@ class TestBackfillJob:
op1.set_downstream(op3)
op4.set_downstream(op5)
op3.set_downstream(op4)
+ dag_maker.create_dagrun()
executor = MockExecutor(parallelism=16)
job = BackfillJob(
@@ -881,10 +873,29 @@ class TestBackfillJob:
parsed_args = self.parser.parse_args(args)
assert 0.5 == parsed_args.delay_on_limit
- def test_backfill_max_limit_check_within_limit(self, get_dag_test_max_active_limits):
- dag, _ = get_dag_test_max_active_limits(
- dag_id='test_backfill_max_limit_check_within_limit', max_active_runs=16
+ def _get_dag_test_max_active_limits(
+ self, dag_maker_fixture, dag_id='test_dag', max_active_runs=1, **kwargs
+ ):
+ with dag_maker_fixture(
+ dag_id=dag_id,
+ schedule_interval="@hourly",
+ max_active_runs=max_active_runs,
+ **kwargs,
+ ) as dag:
+ op1 = DummyOperator(task_id='leave1')
+ op2 = DummyOperator(task_id='leave2')
+ op3 = DummyOperator(task_id='upstream_level_1')
+ op4 = DummyOperator(task_id='upstream_level_2')
+
+ op1 >> op2 >> op3
+ op4 >> op3
+ return dag
+
+ def test_backfill_max_limit_check_within_limit(self, dag_maker):
+ dag = self._get_dag_test_max_active_limits(
+ dag_maker, dag_id='test_backfill_max_limit_check_within_limit', max_active_runs=16
)
+ dag_maker.create_dagrun()
start_date = DEFAULT_DATE - datetime.timedelta(hours=1)
end_date = DEFAULT_DATE
@@ -898,7 +909,7 @@ class TestBackfillJob:
assert 2 == len(dagruns)
assert all(run.state == State.SUCCESS for run in dagruns)
- def test_backfill_max_limit_check(self, get_dag_test_max_active_limits):
+ def test_backfill_max_limit_check(self, dag_maker):
dag_id = 'test_backfill_max_limit_check'
run_id = 'test_dag_run'
start_date = DEFAULT_DATE - datetime.timedelta(hours=1)
@@ -911,9 +922,12 @@ class TestBackfillJob:
# this session object is different than the one in the main thread
with create_session() as thread_session:
try:
- dag, _ = get_dag_test_max_active_limits(
- # Existing dagrun that is not within the backfill range
+ dag = self._get_dag_test_max_active_limits(
+ dag_maker,
dag_id=dag_id,
+ )
+ dag_maker.create_dagrun(
+ # Existing dagrun that is not within the backfill range
run_id=run_id,
execution_date=DEFAULT_DATE + datetime.timedelta(hours=1),
)
@@ -960,11 +974,14 @@ class TestBackfillJob:
finally:
dag_run_created_cond.release()
- def test_backfill_max_limit_check_no_count_existing(self, get_dag_test_max_active_limits):
+ def test_backfill_max_limit_check_no_count_existing(self, dag_maker):
start_date = DEFAULT_DATE
end_date = DEFAULT_DATE
# Existing dagrun that is within the backfill range
- dag, _ = get_dag_test_max_active_limits(dag_id='test_backfill_max_limit_check_no_count_existing')
+ dag = self._get_dag_test_max_active_limits(
+ dag_maker, dag_id='test_backfill_max_limit_check_no_count_existing'
+ )
+ dag_maker.create_dagrun()
executor = MockExecutor()
job = BackfillJob(
@@ -980,8 +997,11 @@ class TestBackfillJob:
assert 1 == len(dagruns)
assert State.SUCCESS == dagruns[0].state
- def test_backfill_max_limit_check_complete_loop(self, get_dag_test_max_active_limits):
- dag, _ = get_dag_test_max_active_limits(dag_id='test_backfill_max_limit_check_complete_loop')
+ def test_backfill_max_limit_check_complete_loop(self, dag_maker):
+ dag = self._get_dag_test_max_active_limits(
+ dag_maker, dag_id='test_backfill_max_limit_check_complete_loop'
+ )
+ dag_maker.create_dagrun()
start_date = DEFAULT_DATE - datetime.timedelta(hours=1)
end_date = DEFAULT_DATE
@@ -1003,9 +1023,6 @@ class TestBackfillJob:
with dag_maker(
'test_sub_set_subdag',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'},
- execution_date=DEFAULT_DATE,
) as dag:
op1 = DummyOperator(task_id='leave1')
op2 = DummyOperator(task_id='leave2')
@@ -1018,7 +1035,7 @@ class TestBackfillJob:
op4.set_downstream(op5)
op3.set_downstream(op4)
- dr = dag_maker.dag_run
+ dr = dag_maker.create_dagrun()
executor = MockExecutor()
sub_dag = dag.sub_dag(task_ids_or_regex="leave*", include_downstream=False, include_upstream=False)
@@ -1041,9 +1058,6 @@ class TestBackfillJob:
def test_backfill_fill_blanks(self, dag_maker):
with dag_maker(
'test_backfill_fill_blanks',
- start_date=DEFAULT_DATE,
- default_args={'owner': 'owner1'},
- execution_date=DEFAULT_DATE,
) as dag:
op1 = DummyOperator(task_id='op1')
op2 = DummyOperator(task_id='op2')
@@ -1052,7 +1066,7 @@ class TestBackfillJob:
op5 = DummyOperator(task_id='op5')
op6 = DummyOperator(task_id='op6')
- dr = dag_maker.dag_run
+ dr = dag_maker.create_dagrun()
executor = MockExecutor()
@@ -1229,11 +1243,9 @@ class TestBackfillJob:
dag.clear()
def test_update_counters(self, dag_maker):
- with dag_maker(
- dag_id='test_manage_executor_state', start_date=DEFAULT_DATE, execution_date=DEFAULT_DATE
- ) as dag:
- task1 = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
- dr = dag_maker.dag_run
+ with dag_maker(dag_id='test_manage_executor_state', start_date=DEFAULT_DATE) as dag:
+ task1 = DummyOperator(task_id='dummy', owner='airflow')
+ dr = dag_maker.create_dagrun()
job = BackfillJob(dag=dag)
session = settings.Session()
@@ -1328,7 +1340,6 @@ class TestBackfillJob:
assert [DEFAULT_DATE] == test_dag.get_run_dates(
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE,
- align=True,
)
assert [
DEFAULT_DATE - datetime.timedelta(hours=3),
@@ -1377,9 +1388,7 @@ class TestBackfillJob:
states_to_reset = [State.QUEUED, State.SCHEDULED, State.NONE]
tasks = []
- with dag_maker(
- dag_id=prefix, start_date=DEFAULT_DATE, schedule_interval="@daily", run_id='test1'
- ) as dag:
+ with dag_maker(dag_id=prefix, start_date=DEFAULT_DATE, schedule_interval="@daily") as dag:
for i in range(len(states)):
task_id = f"{prefix}_task_{i}"
task = DummyOperator(task_id=task_id)
@@ -1389,7 +1398,7 @@ class TestBackfillJob:
job = BackfillJob(dag=dag)
# create dagruns
- dr1 = dag_maker.dag_run
+ dr1 = dag_maker.create_dagrun()
dr2 = dag.create_dagrun(run_id='test2', state=State.SUCCESS)
# create taskinstances and set states
@@ -1442,15 +1451,13 @@ class TestBackfillJob:
dag_id=dag_id,
start_date=DEFAULT_DATE,
schedule_interval='@daily',
- state=State.SUCCESS,
- run_id='test1',
) as dag:
DummyOperator(task_id=task_id, dag=dag)
job = BackfillJob(dag=dag)
session = settings.Session()
# make two dagruns, only reset for one
- dr1 = dag_maker.dag_run # Already created in dag_maker with state=SUCCESS
+ dr1 = dag_maker.create_dagrun(state=State.SUCCESS)
dr2 = dag.create_dagrun(run_id='test2', state=State.RUNNING)
ti1 = dr1.get_task_instances(session=session)[0]
ti2 = dr2.get_task_instances(session=session)[0]
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index 94f894d..14c74ce 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -27,14 +27,12 @@ from unittest import mock
from unittest.mock import patch
import pytest
-from parameterized import parameterized
from airflow import settings
from airflow.exceptions import AirflowException, AirflowFailException
from airflow.executors.sequential_executor import SequentialExecutor
from airflow.jobs.local_task_job import LocalTaskJob
from airflow.jobs.scheduler_job import SchedulerJob
-from airflow.models.dag import DAG, DagModel
from airflow.models.dagbag import DagBag
from airflow.models.taskinstance import TaskInstance
from airflow.operators.dummy import DummyOperator
@@ -75,10 +73,19 @@ def clear_db_class():
db.clear_db_task_fail()
+@pytest.fixture(scope='module')
+def dagbag():
+ return DagBag(
+ dag_folder=TEST_DAG_FOLDER,
+ include_examples=False,
+ )
+
+
@pytest.mark.usefixtures('clear_db_class', 'clear_db')
class TestLocalTaskJob:
@pytest.fixture(autouse=True)
- def set_instance_attrs(self):
+ def set_instance_attrs(self, dagbag):
+ self.dagbag = dagbag
with patch('airflow.jobs.base_job.sleep') as self.mock_base_job_sleep:
yield
@@ -94,12 +101,10 @@ class TestLocalTaskJob:
of LocalTaskJob can be assigned with
proper values without intervention
"""
- with dag_maker(
- 'test_localtaskjob_essential_attr', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}
- ):
+ with dag_maker('test_localtaskjob_essential_attr'):
op1 = DummyOperator(task_id='op1')
- dr = dag_maker.dag_run
+ dr = dag_maker.create_dagrun()
ti = dr.get_task_instance(task_id=op1.task_id)
@@ -118,7 +123,7 @@ class TestLocalTaskJob:
with dag_maker('test_localtaskjob_heartbeat'):
op1 = DummyOperator(task_id='op1')
- dr = dag_maker.dag_run
+ dr = dag_maker.create_dagrun()
ti = dr.get_task_instance(task_id=op1.task_id, session=session)
ti.state = State.RUNNING
ti.hostname = "blablabla"
@@ -150,7 +155,7 @@ class TestLocalTaskJob:
session = settings.Session()
with dag_maker('test_localtaskjob_heartbeat'):
op1 = DummyOperator(task_id='op1', run_as_user='myuser')
- dr = dag_maker.dag_run
+ dr = dag_maker.create_dagrun()
ti = dr.get_task_instance(task_id=op1.task_id, session=session)
ti.state = State.RUNNING
ti.pid = 2
@@ -192,7 +197,7 @@ class TestLocalTaskJob:
session = settings.Session()
with dag_maker('test_localtaskjob_heartbeat'):
op1 = DummyOperator(task_id='op1')
- dr = dag_maker.dag_run
+ dr = dag_maker.create_dagrun()
ti = dr.get_task_instance(task_id=op1.task_id, session=session)
ti.state = State.RUNNING
ti.pid = 2
@@ -236,13 +241,10 @@ class TestLocalTaskJob:
dag_id = 'test_heartbeat_failed_fast'
task_id = 'test_heartbeat_failed_fast_op'
with create_session() as session:
- dagbag = DagBag(
- dag_folder=TEST_DAG_FOLDER,
- include_examples=False,
- )
+
dag_id = 'test_heartbeat_failed_fast'
task_id = 'test_heartbeat_failed_fast_op'
- dag = dagbag.get_dag(dag_id)
+ dag = self.dagbag.get_dag(dag_id)
task = dag.get_task(task_id)
dag.create_dagrun(
@@ -278,11 +280,7 @@ class TestLocalTaskJob:
Test that ensures that mark_success in the UI doesn't cause
the task to fail, and that the task exits
"""
- dagbag = DagBag(
- dag_folder=TEST_DAG_FOLDER,
- include_examples=False,
- )
- dag = dagbag.dags.get('test_mark_success')
+ dag = self.dagbag.dags.get('test_mark_success')
task = dag.get_task('task1')
session = settings.Session()
@@ -317,11 +315,7 @@ class TestLocalTaskJob:
def test_localtaskjob_double_trigger(self):
- dagbag = DagBag(
- dag_folder=TEST_DAG_FOLDER,
- include_examples=False,
- )
- dag = dagbag.dags.get('test_localtaskjob_double_trigger')
+ dag = self.dagbag.dags.get('test_localtaskjob_double_trigger')
task = dag.get_task('test_localtaskjob_double_trigger_task')
session = settings.Session()
@@ -357,11 +351,8 @@ class TestLocalTaskJob:
@pytest.mark.quarantined
def test_localtaskjob_maintain_heart_rate(self):
- dagbag = DagBag(
- dag_folder=TEST_DAG_FOLDER,
- include_examples=False,
- )
- dag = dagbag.dags.get('test_localtaskjob_double_trigger')
+
+ dag = self.dagbag.dags.get('test_localtaskjob_double_trigger')
task = dag.get_task('test_localtaskjob_double_trigger_task')
session = settings.Session()
@@ -440,6 +431,7 @@ class TestLocalTaskJob:
python_callable=task_function,
on_failure_callback=check_failure,
)
+ dag_maker.create_dagrun()
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.refresh_from_db()
@@ -481,6 +473,7 @@ class TestLocalTaskJob:
python_callable=task_function,
on_failure_callback=failure_callback,
)
+ dag_maker.create_dagrun()
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.refresh_from_db()
@@ -652,7 +645,8 @@ class TestLocalTaskJob:
assert task_terminated_externally.value == 1
assert not process.is_alive()
- @parameterized.expand(
+ @pytest.mark.parametrize(
+ "conf, dependencies, init_state, first_run_state, second_run_state, error_message",
[
(
{('scheduler', 'schedule_after_task_execution'): 'True'},
@@ -686,27 +680,17 @@ class TestLocalTaskJob:
None,
"A -> C & B -> C, when A is QUEUED but B has FAILED, C is marked UPSTREAM_FAILED.",
),
- ]
+ ],
)
def test_fast_follow(
- self, conf, dependencies, init_state, first_run_state, second_run_state, error_message
+ self, conf, dependencies, init_state, first_run_state, second_run_state, error_message, dag_maker
):
# pylint: disable=too-many-locals
with conf_vars(conf):
session = settings.Session()
- dag = DAG('test_dagrun_fast_follow', start_date=DEFAULT_DATE)
-
- dag_model = DagModel(
- dag_id=dag.dag_id,
- next_dagrun=dag.start_date,
- is_active=True,
- )
- session.add(dag_model)
- session.flush()
-
python_callable = lambda: True
- with dag:
+ with dag_maker('test_dagrun_fast_follow') as dag:
task_a = PythonOperator(task_id='A', python_callable=python_callable)
task_b = PythonOperator(task_id='B', python_callable=python_callable)
task_c = PythonOperator(task_id='C', python_callable=python_callable)
@@ -715,6 +699,8 @@ class TestLocalTaskJob:
for upstream, downstream in dependencies.items():
dag.set_dependency(upstream, downstream)
+ dag_maker.make_dagmodel()
+
scheduler_job = SchedulerJob(subdir=os.devnull)
scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
@@ -850,34 +836,25 @@ class TestLocalTaskJob:
assert retry_callback_called.value == 1
assert task_terminated_externally.value == 1
- def test_task_exit_should_update_state_of_finished_dagruns_with_dag_paused(self):
+ def test_task_exit_should_update_state_of_finished_dagruns_with_dag_paused(self, dag_maker):
"""Test that with DAG paused, DagRun state will update when the tasks finishes the run"""
- dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE)
- op1 = PythonOperator(task_id='dummy', dag=dag, owner='airflow', python_callable=lambda: True)
+ with dag_maker(dag_id='test_dags') as dag:
+ op1 = PythonOperator(task_id='dummy', python_callable=lambda: True)
session = settings.Session()
- orm_dag = DagModel(
- dag_id=dag.dag_id,
+ dag_maker.make_dagmodel(
has_task_concurrency_limits=False,
- next_dagrun=dag.start_date,
next_dagrun_create_after=dag.following_schedule(DEFAULT_DATE),
is_active=True,
is_paused=True,
)
- session.add(orm_dag)
- session.flush()
# Write Dag to DB
dagbag = DagBag(dag_folder="/dev/null", include_examples=False, read_dags_from_db=False)
dagbag.bag_dag(dag, root_dag=dag)
dagbag.sync_to_db()
- dr = dag.create_dagrun(
- run_type=DagRunType.SCHEDULED,
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session,
- )
+ dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+
assert dr.state == State.RUNNING
ti = TaskInstance(op1, dr.execution_date)
job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
@@ -899,13 +876,12 @@ def clean_db_helper():
class TestLocalTaskJobPerformance:
@pytest.mark.parametrize("return_codes", [[0], 9 * [None] + [0]]) # type: ignore
@mock.patch("airflow.jobs.local_task_job.get_task_runner")
- def test_number_of_queries_single_loop(self, mock_get_task_runner, return_codes):
+ def test_number_of_queries_single_loop(self, mock_get_task_runner, return_codes, dag_maker):
unique_prefix = str(uuid.uuid4())
- dag = DAG(dag_id=f'{unique_prefix}_test_number_of_queries', start_date=DEFAULT_DATE)
- task = DummyOperator(task_id='test_state_succeeded1', dag=dag)
+ with dag_maker(dag_id=f'{unique_prefix}_test_number_of_queries'):
+ task = DummyOperator(task_id='test_state_succeeded1')
- dag.clear()
- dag.create_dagrun(run_id=unique_prefix, execution_date=DEFAULT_DATE, state=State.NONE)
+ dag_maker.create_dagrun(run_id=unique_prefix, state=State.NONE)
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)