Posted to commits@airflow.apache.org by ka...@apache.org on 2021/08/14 01:15:02 UTC
[airflow] 06/09: Add Pytest fixture to create dag and dagrun and use it on local task job tests (#16889)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 343beb65685adc4b87107a18a43c509731985499
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Tue Jul 20 18:48:35 2021 +0100
Add Pytest fixture to create dag and dagrun and use it on local task job tests (#16889)
This change adds a pytest fixture that creates a DAG and DagRun, and uses it in the local task job tests.
Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
(cherry picked from commit 7c0d8a2f83cc6db25bdddcf6cecb6fb56f05f02f)
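For context, the dag_maker fixture introduced in tests/conftest.py below is called with a dag_id plus any DAG keyword arguments, used as a context manager so tasks register against the DAG, and on a clean exit it clears the DAG and creates a DagRun exposed as dag_maker.dag_run. The following is a minimal, illustrative sketch of a test consuming the fixture; the test name, dag_id, and task_id are hypothetical and not part of this commit.

from airflow.operators.dummy import DummyOperator
from airflow.utils import timezone
from airflow.utils.state import State

DEFAULT_DATE = timezone.datetime(2016, 1, 1)


def test_uses_dag_maker(dag_maker):
    # Calling the fixture stores the DAG kwargs; entering the context manager
    # activates the DAG so operators created inside bind to it automatically.
    with dag_maker('example_dag', start_date=DEFAULT_DATE):
        task = DummyOperator(task_id='example_task')

    # On clean exit the fixture clears the DAG and creates a DagRun
    # (run_id defaults to "test", state to State.RUNNING).
    dr = dag_maker.dag_run
    assert dr.state == State.RUNNING
    ti = dr.get_task_instance(task_id=task.task_id)
    assert ti is not None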
---
tests/conftest.py | 50 +++++++++++
tests/jobs/test_local_task_job.py | 178 +++++++++++++-------------------------
2 files changed, 111 insertions(+), 117 deletions(-)
diff --git a/tests/conftest.py b/tests/conftest.py
index 55e1593..f2c5345 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -425,3 +425,53 @@ def app():
    from airflow.www import app

    return app.create_app(testing=True)
+
+
+@pytest.fixture
+def dag_maker(request):
+    from airflow.models import DAG
+    from airflow.utils import timezone
+    from airflow.utils.state import State
+
+    DEFAULT_DATE = timezone.datetime(2016, 1, 1)
+
+    class DagFactory:
+        def __enter__(self):
+            self.dag.__enter__()
+            return self.dag
+
+        def __exit__(self, type, value, traceback):
+            dag = self.dag
+            dag.__exit__(type, value, traceback)
+            if type is None:
+                dag.clear()
+                self.dag_run = dag.create_dagrun(
+                    run_id=self.kwargs.get("run_id", "test"),
+                    state=self.kwargs.get('state', State.RUNNING),
+                    execution_date=self.kwargs.get('execution_date', self.kwargs['start_date']),
+                    start_date=self.kwargs['start_date'],
+                )
+
+        def __call__(self, dag_id='test_dag', **kwargs):
+            self.kwargs = kwargs
+            if "start_date" not in kwargs:
+                if hasattr(request.module, 'DEFAULT_DATE'):
+                    kwargs['start_date'] = getattr(request.module, 'DEFAULT_DATE')
+                else:
+                    kwargs['start_date'] = DEFAULT_DATE
+            dagrun_fields_not_in_dag = [
+                'state',
+                'execution_date',
+                'run_type',
+                'queued_at',
+                "run_id",
+                "creating_job_id",
+                "external_trigger",
+                "last_scheduling_decision",
+                "dag_hash",
+            ]
+            kwargs = {k: v for k, v in kwargs.items() if k not in dagrun_fields_not_in_dag}
+            self.dag = DAG(dag_id, **kwargs)
+            return self
+
+    return DagFactory()
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index 11e9adf..d9f1398 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -20,7 +20,6 @@ import multiprocessing
import os
import signal
import time
-import unittest
import uuid
from multiprocessing import Lock, Value
from unittest import mock
@@ -57,21 +56,30 @@ DEFAULT_DATE = timezone.datetime(2016, 1, 1)
TEST_DAG_FOLDER = os.environ['AIRFLOW__CORE__DAGS_FOLDER']
-class TestLocalTaskJob(unittest.TestCase):
-    def setUp(self):
-        db.clear_db_dags()
-        db.clear_db_jobs()
-        db.clear_db_runs()
-        db.clear_db_task_fail()
-        patcher = patch('airflow.jobs.base_job.sleep')
-        self.addCleanup(patcher.stop)
-        self.mock_base_job_sleep = patcher.start()
+@pytest.fixture
+def clear_db():
+    db.clear_db_dags()
+    db.clear_db_jobs()
+    db.clear_db_runs()
+    db.clear_db_task_fail()
+    yield
+
+
+@pytest.fixture(scope='class')
+def clear_db_class():
+    yield
+    db.clear_db_dags()
+    db.clear_db_jobs()
+    db.clear_db_runs()
+    db.clear_db_task_fail()
+
-    def tearDown(self) -> None:
-        db.clear_db_dags()
-        db.clear_db_jobs()
-        db.clear_db_runs()
-        db.clear_db_task_fail()
+@pytest.mark.usefixtures('clear_db_class', 'clear_db')
+class TestLocalTaskJob:
+    @pytest.fixture(autouse=True)
+    def set_instance_attrs(self):
+        with patch('airflow.jobs.base_job.sleep') as self.mock_base_job_sleep:
+            yield
def validate_ti_states(self, dag_run, ti_state_mapping, error_message):
for task_id, expected_state in ti_state_mapping.items():
@@ -79,23 +87,19 @@ class TestLocalTaskJob(unittest.TestCase):
task_instance.refresh_from_db()
assert task_instance.state == expected_state, error_message
- def test_localtaskjob_essential_attr(self):
+ def test_localtaskjob_essential_attr(self, dag_maker):
"""
Check whether essential attributes
of LocalTaskJob can be assigned with
proper values without intervention
"""
- dag = DAG(
+ with dag_maker(
'test_localtaskjob_essential_attr', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}
- )
-
- with dag:
+ ):
op1 = DummyOperator(task_id='op1')
- dag.clear()
- dr = dag.create_dagrun(
- run_id="test", state=State.SUCCESS, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE
- )
+ dr = dag_maker.dag_run
+
ti = dr.get_task_instance(task_id=op1.task_id)
job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
@@ -108,21 +112,12 @@ class TestLocalTaskJob(unittest.TestCase):
check_result_2 = [getattr(job1, attr) is not None for attr in essential_attr]
assert all(check_result_2)
- def test_localtaskjob_heartbeat(self):
+ def test_localtaskjob_heartbeat(self, dag_maker):
session = settings.Session()
- dag = DAG('test_localtaskjob_heartbeat', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
-
- with dag:
+ with dag_maker('test_localtaskjob_heartbeat'):
op1 = DummyOperator(task_id='op1')
- dag.clear()
- dr = dag.create_dagrun(
- run_id="test",
- state=State.SUCCESS,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session,
- )
+ dr = dag_maker.dag_run
ti = dr.get_task_instance(task_id=op1.task_id, session=session)
ti.state = State.RUNNING
ti.hostname = "blablabla"
@@ -150,22 +145,11 @@ class TestLocalTaskJob(unittest.TestCase):
job1.heartbeat_callback()
@mock.patch('airflow.jobs.local_task_job.psutil')
- def test_localtaskjob_heartbeat_with_run_as_user(self, psutil_mock):
+ def test_localtaskjob_heartbeat_with_run_as_user(self, psutil_mock, dag_maker):
session = settings.Session()
- dag = DAG('test_localtaskjob_heartbeat', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
-
- with dag:
+ with dag_maker('test_localtaskjob_heartbeat'):
op1 = DummyOperator(task_id='op1', run_as_user='myuser')
-
- dag.clear()
- dr = dag.create_dagrun(
- run_id="test",
- state=State.SUCCESS,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session,
- )
-
+ dr = dag_maker.dag_run
ti = dr.get_task_instance(task_id=op1.task_id, session=session)
ti.state = State.RUNNING
ti.pid = 2
@@ -248,7 +232,8 @@ class TestLocalTaskJob(unittest.TestCase):
Test that task heartbeat will sleep when it fails fast
"""
self.mock_base_job_sleep.side_effect = time.sleep
-
+ dag_id = 'test_heartbeat_failed_fast'
+ task_id = 'test_heartbeat_failed_fast_op'
with create_session() as session:
dagbag = DagBag(
dag_folder=TEST_DAG_FOLDER,
@@ -266,6 +251,7 @@ class TestLocalTaskJob(unittest.TestCase):
start_date=DEFAULT_DATE,
session=session,
)
+
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.refresh_from_db()
ti.state = State.RUNNING
@@ -331,6 +317,7 @@ class TestLocalTaskJob(unittest.TestCase):
assert State.SUCCESS == ti.state
def test_localtaskjob_double_trigger(self):
+
dagbag = DagBag(
dag_folder=TEST_DAG_FOLDER,
include_examples=False,
@@ -348,6 +335,7 @@ class TestLocalTaskJob(unittest.TestCase):
start_date=DEFAULT_DATE,
session=session,
)
+
ti = dr.get_task_instance(task_id=task.task_id, session=session)
ti.state = State.RUNNING
ti.hostname = get_hostname()
@@ -418,7 +406,7 @@ class TestLocalTaskJob(unittest.TestCase):
assert time_end - time_start < job1.heartrate
session.close()
- def test_mark_failure_on_failure_callback(self):
+ def test_mark_failure_on_failure_callback(self, dag_maker):
"""
Test that ensures that mark_failure in the UI fails
the task, and executes on_failure_callback
@@ -447,22 +435,12 @@ class TestLocalTaskJob(unittest.TestCase):
with task_terminated_externally.get_lock():
task_terminated_externally.value = 0
- with DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE) as dag:
+ with dag_maker("test_mark_failure", start_date=DEFAULT_DATE):
task = PythonOperator(
task_id='test_state_succeeded1',
python_callable=task_function,
on_failure_callback=check_failure,
)
-
- dag.clear()
- with create_session() as session:
- dag.create_dagrun(
- run_id="test",
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session,
- )
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.refresh_from_db()
@@ -479,7 +457,7 @@ class TestLocalTaskJob(unittest.TestCase):
@patch('airflow.utils.process_utils.subprocess.check_call')
@patch.object(StandardTaskRunner, 'return_code')
- def test_failure_callback_only_called_once(self, mock_return_code, _check_call):
+ def test_failure_callback_only_called_once(self, mock_return_code, _check_call, dag_maker):
"""
Test that ensures that when a task exits with failure by itself,
failure callback is only called once
@@ -498,22 +476,11 @@ class TestLocalTaskJob(unittest.TestCase):
def task_function(ti):
raise AirflowFailException()
- dag = DAG(dag_id='test_failure_callback_race', start_date=DEFAULT_DATE)
- task = PythonOperator(
- task_id='test_exit_on_failure',
- python_callable=task_function,
- on_failure_callback=failure_callback,
- dag=dag,
- )
-
- dag.clear()
- with create_session() as session:
- dag.create_dagrun(
- run_id="test",
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session,
+ with dag_maker("test_failure_callback_race"):
+ task = PythonOperator(
+ task_id='test_exit_on_failure',
+ python_callable=task_function,
+ on_failure_callback=failure_callback,
)
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.refresh_from_db()
@@ -544,7 +511,7 @@ class TestLocalTaskJob(unittest.TestCase):
assert failure_callback_called.value == 1
@pytest.mark.quarantined
- def test_mark_success_on_success_callback(self):
+ def test_mark_success_on_success_callback(self, dag_maker):
"""
Test that ensures that where a task is marked success in the UI
on_success_callback gets executed
@@ -560,8 +527,6 @@ class TestLocalTaskJob(unittest.TestCase):
success_callback_called.value += 1
assert context['dag_run'].dag_id == 'test_mark_success'
- dag = DAG(dag_id='test_mark_success', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
-
def task_function(ti):
time.sleep(60)
@@ -569,23 +534,15 @@ class TestLocalTaskJob(unittest.TestCase):
with shared_mem_lock:
task_terminated_externally.value = 0
- task = PythonOperator(
- task_id='test_state_succeeded1',
- python_callable=task_function,
- on_success_callback=success_callback,
- dag=dag,
- )
+ with dag_maker(dag_id='test_mark_success', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}):
+ task = PythonOperator(
+ task_id='test_state_succeeded1',
+ python_callable=task_function,
+ on_success_callback=success_callback,
+ )
session = settings.Session()
- dag.clear()
- dag.create_dagrun(
- run_id="test",
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session,
- )
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.refresh_from_db()
job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
@@ -616,7 +573,7 @@ class TestLocalTaskJob(unittest.TestCase):
(signal.SIGKILL,),
]
)
- def test_process_kill_calls_on_failure_callback(self, signal_type):
+ def test_process_kill_calls_on_failure_callback(self, signal_type, dag_maker):
"""
Test that ensures that when a task is killed with sigterm or sigkill
on_failure_callback gets executed
@@ -632,8 +589,6 @@ class TestLocalTaskJob(unittest.TestCase):
failure_callback_called.value += 1
assert context['dag_run'].dag_id == 'test_mark_failure'
- dag = DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
-
def task_function(ti):
time.sleep(60)
@@ -641,23 +596,12 @@ class TestLocalTaskJob(unittest.TestCase):
with shared_mem_lock:
task_terminated_externally.value = 0
- task = PythonOperator(
- task_id='test_on_failure',
- python_callable=task_function,
- on_failure_callback=failure_callback,
- dag=dag,
- )
-
- session = settings.Session()
-
- dag.clear()
- dag.create_dagrun(
- run_id="test",
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session,
- )
+ with dag_maker(dag_id='test_mark_failure', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}):
+ task = PythonOperator(
+ task_id='test_on_failure',
+ python_callable=task_function,
+ on_failure_callback=failure_callback,
+ )
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.refresh_from_db()
job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())