Posted to commits@airflow.apache.org by ka...@apache.org on 2021/08/13 19:49:16 UTC

[airflow] 07/08: Add Pytest fixture to create dag and dagrun and use it on local task job tests (#16889)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit a3879b2e1bc6c8b30443fd41f02ff580a3001966
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Tue Jul 20 18:48:35 2021 +0100

    Add Pytest fixture to create dag and dagrun and use it on local task job tests (#16889)
    
    This change adds a pytest fixture that creates a DAG and a DagRun, and uses it in the local task job tests
    
    Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
    (cherry picked from commit 7c0d8a2f83cc6db25bdddcf6cecb6fb56f05f02f)
---
 tests/conftest.py                 |  50 +++++++++
 tests/jobs/test_local_task_job.py | 215 +++++++++++++++++---------------------
 2 files changed, 148 insertions(+), 117 deletions(-)
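
A minimal sketch of how the new dag_maker fixture is meant to be used, based on the patch below. It assumes an initialized Airflow test database; the test name, dag_id, and task_id are illustrative and not part of the commit:

    from airflow.operators.dummy import DummyOperator  # import path as of Airflow 2.1
    from airflow.utils.state import State


    def test_dag_maker_sketch(dag_maker):
        # Entering the context manager returns the DAG; operators declared
        # inside the block attach to it. On a clean exit, the factory clears
        # the DAG and creates a DagRun from the stored keyword arguments.
        with dag_maker('example_dag', state=State.SUCCESS):
            DummyOperator(task_id='op1')

        dr = dag_maker.dag_run  # the DagRun created by __exit__
        ti = dr.get_task_instance(task_id='op1')
        assert ti is not None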

diff --git a/tests/conftest.py b/tests/conftest.py
index 55e1593..f2c5345 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -425,3 +425,53 @@ def app():
     from airflow.www import app
 
     return app.create_app(testing=True)
+
+
+@pytest.fixture
+def dag_maker(request):
+    from airflow.models import DAG
+    from airflow.utils import timezone
+    from airflow.utils.state import State
+
+    DEFAULT_DATE = timezone.datetime(2016, 1, 1)
+
+    class DagFactory:
+        def __enter__(self):
+            self.dag.__enter__()
+            return self.dag
+
+        def __exit__(self, type, value, traceback):
+            dag = self.dag
+            dag.__exit__(type, value, traceback)
+            if type is None:
+                dag.clear()
+                self.dag_run = dag.create_dagrun(
+                    run_id=self.kwargs.get("run_id", "test"),
+                    state=self.kwargs.get('state', State.RUNNING),
+                    execution_date=self.kwargs.get('execution_date', self.kwargs['start_date']),
+                    start_date=self.kwargs['start_date'],
+                )
+
+        def __call__(self, dag_id='test_dag', **kwargs):
+            self.kwargs = kwargs
+            if "start_date" not in kwargs:
+                if hasattr(request.module, 'DEFAULT_DATE'):
+                    kwargs['start_date'] = getattr(request.module, 'DEFAULT_DATE')
+                else:
+                    kwargs['start_date'] = DEFAULT_DATE
+            dagrun_fields_not_in_dag = [
+                'state',
+                'execution_date',
+                'run_type',
+                'queued_at',
+                "run_id",
+                "creating_job_id",
+                "external_trigger",
+                "last_scheduling_decision",
+                "dag_hash",
+            ]
+            kwargs = {k: v for k, v in kwargs.items() if k not in dagrun_fields_not_in_dag}
+            self.dag = DAG(dag_id, **kwargs)
+            return self
+
+    return DagFactory()
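
One subtlety of the fixture above: when start_date is not passed, it looks for a module-level DEFAULT_DATE on the requesting test module (via request.module) before falling back to its own 2016-01-01 default. A short sketch of that behavior, again assuming an initialized test database, with illustrative names:

    from airflow.operators.dummy import DummyOperator
    from airflow.utils import timezone

    # Picked up by dag_maker through request.module when start_date is omitted.
    DEFAULT_DATE = timezone.datetime(2021, 1, 1)


    def test_uses_module_default_date(dag_maker):
        with dag_maker('module_default_date_dag'):
            DummyOperator(task_id='op1')

        # Both start_date and execution_date default to the module's DEFAULT_DATE.
        assert dag_maker.dag_run.execution_date == DEFAULT_DATE
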
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index 9d80647..3475ef1 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -20,7 +20,6 @@ import multiprocessing
 import os
 import signal
 import time
-import unittest
 import uuid
 from multiprocessing import Lock, Value
 from unittest import mock
@@ -55,21 +54,30 @@ DEFAULT_DATE = timezone.datetime(2016, 1, 1)
 TEST_DAG_FOLDER = os.environ['AIRFLOW__CORE__DAGS_FOLDER']
 
 
-class TestLocalTaskJob(unittest.TestCase):
-    def setUp(self):
-        db.clear_db_dags()
-        db.clear_db_jobs()
-        db.clear_db_runs()
-        db.clear_db_task_fail()
-        patcher = patch('airflow.jobs.base_job.sleep')
-        self.addCleanup(patcher.stop)
-        self.mock_base_job_sleep = patcher.start()
+@pytest.fixture
+def clear_db():
+    db.clear_db_dags()
+    db.clear_db_jobs()
+    db.clear_db_runs()
+    db.clear_db_task_fail()
+    yield
+
+
+@pytest.fixture(scope='class')
+def clear_db_class():
+    yield
+    db.clear_db_dags()
+    db.clear_db_jobs()
+    db.clear_db_runs()
+    db.clear_db_task_fail()
+
 
-    def tearDown(self) -> None:
-        db.clear_db_dags()
-        db.clear_db_jobs()
-        db.clear_db_runs()
-        db.clear_db_task_fail()
+@pytest.mark.usefixtures('clear_db_class', 'clear_db')
+class TestLocalTaskJob:
+    @pytest.fixture(autouse=True)
+    def set_instance_attrs(self):
+        with patch('airflow.jobs.base_job.sleep') as self.mock_base_job_sleep:
+            yield
 
     def validate_ti_states(self, dag_run, ti_state_mapping, error_message):
         for task_id, expected_state in ti_state_mapping.items():
@@ -77,23 +85,19 @@ class TestLocalTaskJob(unittest.TestCase):
             task_instance.refresh_from_db()
             assert task_instance.state == expected_state, error_message
 
-    def test_localtaskjob_essential_attr(self):
+    def test_localtaskjob_essential_attr(self, dag_maker):
         """
         Check whether essential attributes
         of LocalTaskJob can be assigned with
         proper values without intervention
         """
-        dag = DAG(
+        with dag_maker(
             'test_localtaskjob_essential_attr', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}
-        )
-
-        with dag:
+        ):
             op1 = DummyOperator(task_id='op1')
 
-        dag.clear()
-        dr = dag.create_dagrun(
-            run_id="test", state=State.SUCCESS, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE
-        )
+        dr = dag_maker.dag_run
+
         ti = dr.get_task_instance(task_id=op1.task_id)
 
         job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
@@ -106,21 +110,12 @@ class TestLocalTaskJob(unittest.TestCase):
         check_result_2 = [getattr(job1, attr) is not None for attr in essential_attr]
         assert all(check_result_2)
 
-    def test_localtaskjob_heartbeat(self):
+    def test_localtaskjob_heartbeat(self, dag_maker):
         session = settings.Session()
-        dag = DAG('test_localtaskjob_heartbeat', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
-
-        with dag:
+        with dag_maker('test_localtaskjob_heartbeat'):
             op1 = DummyOperator(task_id='op1')
 
-        dag.clear()
-        dr = dag.create_dagrun(
-            run_id="test",
-            state=State.SUCCESS,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            session=session,
-        )
+        dr = dag_maker.dag_run
         ti = dr.get_task_instance(task_id=op1.task_id, session=session)
         ti.state = State.RUNNING
         ti.hostname = "blablabla"
@@ -148,22 +143,11 @@ class TestLocalTaskJob(unittest.TestCase):
             job1.heartbeat_callback()
 
     @mock.patch('airflow.jobs.local_task_job.psutil')
-    def test_localtaskjob_heartbeat_with_run_as_user(self, psutil_mock):
+    def test_localtaskjob_heartbeat_with_run_as_user(self, psutil_mock, dag_maker):
         session = settings.Session()
-        dag = DAG('test_localtaskjob_heartbeat', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
-
-        with dag:
+        with dag_maker('test_localtaskjob_heartbeat'):
             op1 = DummyOperator(task_id='op1', run_as_user='myuser')
-
-        dag.clear()
-        dr = dag.create_dagrun(
-            run_id="test",
-            state=State.SUCCESS,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            session=session,
-        )
-
+        dr = dag_maker.dag_run
         ti = dr.get_task_instance(task_id=op1.task_id, session=session)
         ti.state = State.RUNNING
         ti.pid = 2
@@ -246,7 +230,8 @@ class TestLocalTaskJob(unittest.TestCase):
         Test that task heartbeat will sleep when it fails fast
         """
         self.mock_base_job_sleep.side_effect = time.sleep
-
+        dag_id = 'test_heartbeat_failed_fast'
+        task_id = 'test_heartbeat_failed_fast_op'
         with create_session() as session:
             dagbag = DagBag(
                 dag_folder=TEST_DAG_FOLDER,
@@ -264,6 +249,7 @@ class TestLocalTaskJob(unittest.TestCase):
                 start_date=DEFAULT_DATE,
                 session=session,
             )
+
             ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
             ti.refresh_from_db()
             ti.state = State.RUNNING
@@ -329,6 +315,7 @@ class TestLocalTaskJob(unittest.TestCase):
         assert State.SUCCESS == ti.state
 
     def test_localtaskjob_double_trigger(self):
+
         dagbag = DagBag(
             dag_folder=TEST_DAG_FOLDER,
             include_examples=False,
@@ -346,6 +333,7 @@ class TestLocalTaskJob(unittest.TestCase):
             start_date=DEFAULT_DATE,
             session=session,
         )
+
         ti = dr.get_task_instance(task_id=task.task_id, session=session)
         ti.state = State.RUNNING
         ti.hostname = get_hostname()
@@ -416,7 +404,7 @@ class TestLocalTaskJob(unittest.TestCase):
         assert time_end - time_start < job1.heartrate
         session.close()
 
-    def test_mark_failure_on_failure_callback(self):
+    def test_mark_failure_on_failure_callback(self, dag_maker):
         """
         Test that ensures that mark_failure in the UI fails
         the task, and executes on_failure_callback
@@ -445,22 +433,12 @@ class TestLocalTaskJob(unittest.TestCase):
             with task_terminated_externally.get_lock():
                 task_terminated_externally.value = 0
 
-        with DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE) as dag:
+        with dag_maker("test_mark_failure", start_date=DEFAULT_DATE):
             task = PythonOperator(
                 task_id='test_state_succeeded1',
                 python_callable=task_function,
                 on_failure_callback=check_failure,
             )
-
-        dag.clear()
-        with create_session() as session:
-            dag.create_dagrun(
-                run_id="test",
-                state=State.RUNNING,
-                execution_date=DEFAULT_DATE,
-                start_date=DEFAULT_DATE,
-                session=session,
-            )
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
 
@@ -477,7 +455,7 @@ class TestLocalTaskJob(unittest.TestCase):
 
     @patch('airflow.utils.process_utils.subprocess.check_call')
     @patch.object(StandardTaskRunner, 'return_code')
-    def test_failure_callback_only_called_once(self, mock_return_code, _check_call):
+    def test_failure_callback_only_called_once(self, mock_return_code, _check_call, dag_maker):
         """
         Test that ensures that when a task exits with failure by itself,
         failure callback is only called once
@@ -496,22 +474,11 @@ class TestLocalTaskJob(unittest.TestCase):
         def task_function(ti):
             raise AirflowFailException()
 
-        dag = DAG(dag_id='test_failure_callback_race', start_date=DEFAULT_DATE)
-        task = PythonOperator(
-            task_id='test_exit_on_failure',
-            python_callable=task_function,
-            on_failure_callback=failure_callback,
-            dag=dag,
-        )
-
-        dag.clear()
-        with create_session() as session:
-            dag.create_dagrun(
-                run_id="test",
-                state=State.RUNNING,
-                execution_date=DEFAULT_DATE,
-                start_date=DEFAULT_DATE,
-                session=session,
+        with dag_maker("test_failure_callback_race"):
+            task = PythonOperator(
+                task_id='test_exit_on_failure',
+                python_callable=task_function,
+                on_failure_callback=failure_callback,
             )
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
@@ -542,7 +509,7 @@ class TestLocalTaskJob(unittest.TestCase):
         assert failure_callback_called.value == 1
 
     @pytest.mark.quarantined
-    def test_mark_success_on_success_callback(self):
+    def test_mark_success_on_success_callback(self, dag_maker):
         """
         Test that ensures that where a task is marked success in the UI
         on_success_callback gets executed
@@ -558,8 +525,6 @@ class TestLocalTaskJob(unittest.TestCase):
                 success_callback_called.value += 1
             assert context['dag_run'].dag_id == 'test_mark_success'
 
-        dag = DAG(dag_id='test_mark_success', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
-
         def task_function(ti):
 
             time.sleep(60)
@@ -567,23 +532,15 @@ class TestLocalTaskJob(unittest.TestCase):
             with shared_mem_lock:
                 task_terminated_externally.value = 0
 
-        task = PythonOperator(
-            task_id='test_state_succeeded1',
-            python_callable=task_function,
-            on_success_callback=success_callback,
-            dag=dag,
-        )
+        with dag_maker(dag_id='test_mark_success', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}):
+            task = PythonOperator(
+                task_id='test_state_succeeded1',
+                python_callable=task_function,
+                on_success_callback=success_callback,
+            )
 
         session = settings.Session()
 
-        dag.clear()
-        dag.create_dagrun(
-            run_id="test",
-            state=State.RUNNING,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            session=session,
-        )
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
         job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
@@ -614,7 +571,7 @@ class TestLocalTaskJob(unittest.TestCase):
             (signal.SIGKILL,),
         ]
     )
-    def test_process_kill_calls_on_failure_callback(self, signal_type):
+    def test_process_kill_calls_on_failure_callback(self, signal_type, dag_maker):
         """
         Test that ensures that when a task is killed with sigterm or sigkill
         on_failure_callback gets executed
@@ -630,8 +587,6 @@ class TestLocalTaskJob(unittest.TestCase):
                 failure_callback_called.value += 1
             assert context['dag_run'].dag_id == 'test_mark_failure'
 
-        dag = DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
-
         def task_function(ti):
 
             time.sleep(60)
@@ -639,23 +594,12 @@ class TestLocalTaskJob(unittest.TestCase):
             with shared_mem_lock:
                 task_terminated_externally.value = 0
 
-        task = PythonOperator(
-            task_id='test_on_failure',
-            python_callable=task_function,
-            on_failure_callback=failure_callback,
-            dag=dag,
-        )
-
-        session = settings.Session()
-
-        dag.clear()
-        dag.create_dagrun(
-            run_id="test",
-            state=State.RUNNING,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            session=session,
-        )
+        with dag_maker(dag_id='test_mark_failure', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}):
+            task = PythonOperator(
+                task_id='test_on_failure',
+                python_callable=task_function,
+                on_failure_callback=failure_callback,
+            )
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
         job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
@@ -780,6 +724,43 @@ class TestLocalTaskJob(unittest.TestCase):
             if scheduler_job.processor_agent:
                 scheduler_job.processor_agent.end()
 
+    def test_task_exit_should_update_state_of_finished_dagruns_with_dag_paused(self):
+        """Test that with DAG paused, DagRun state will update when the tasks finishes the run"""
+        dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE)
+        op1 = PythonOperator(task_id='dummy', dag=dag, owner='airflow', python_callable=lambda: True)
+
+        session = settings.Session()
+        orm_dag = DagModel(
+            dag_id=dag.dag_id,
+            has_task_concurrency_limits=False,
+            next_dagrun=dag.start_date,
+            next_dagrun_create_after=dag.following_schedule(DEFAULT_DATE),
+            is_active=True,
+            is_paused=True,
+        )
+        session.add(orm_dag)
+        session.flush()
+        # Write Dag to DB
+        dagbag = DagBag(dag_folder="/dev/null", include_examples=False, read_dags_from_db=False)
+        dagbag.bag_dag(dag, root_dag=dag)
+        dagbag.sync_to_db()
+
+        dr = dag.create_dagrun(
+            run_type=DagRunType.SCHEDULED,
+            state=State.RUNNING,
+            execution_date=DEFAULT_DATE,
+            start_date=DEFAULT_DATE,
+            session=session,
+        )
+        assert dr.state == State.RUNNING
+        ti = TaskInstance(op1, dr.execution_date)
+        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
+        job1.task_runner = StandardTaskRunner(job1)
+        job1.run()
+        session.add(dr)
+        session.refresh(dr)
+        assert dr.state == State.SUCCESS
+
 
 @pytest.fixture()
 def clean_db_helper():