You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ur...@apache.org on 2021/08/09 06:13:17 UTC

[airflow] branch main updated: Use `dag_maker` in tests/core/test_core.py (#17462)

This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new daece96  Use `dag_maker` in tests/core/test_core.py (#17462)
daece96 is described below

commit daece96370d1087723d55c5b1711b1079182f572
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Mon Aug 9 07:13:01 2021 +0100

    Use `dag_maker` in tests/core/test_core.py (#17462)
    
    This PR applies the `dag_maker` fixture to the tests in the test_core.py module.
---
 tests/core/test_core.py | 311 +++++++++++++++++++++++-------------------------
 1 file changed, 152 insertions(+), 159 deletions(-)

diff --git a/tests/core/test_core.py b/tests/core/test_core.py
index c4a4bc5..fe31bd2 100644
--- a/tests/core/test_core.py
+++ b/tests/core/test_core.py
@@ -20,7 +20,6 @@ import logging
 import multiprocessing
 import os
 import signal
-import unittest
 from datetime import timedelta
 from time import sleep
 from unittest.mock import MagicMock
@@ -31,19 +30,17 @@ from airflow import settings
 from airflow.exceptions import AirflowException, AirflowTaskTimeout
 from airflow.hooks.base import BaseHook
 from airflow.jobs.local_task_job import LocalTaskJob
-from airflow.models import DagBag, DagRun, TaskFail, TaskInstance
+from airflow.models import DagBag, TaskFail, TaskInstance
 from airflow.models.baseoperator import BaseOperator
-from airflow.models.dag import DAG
 from airflow.operators.bash import BashOperator
 from airflow.operators.check_operator import CheckOperator, ValueCheckOperator
 from airflow.operators.dummy import DummyOperator
 from airflow.operators.python import PythonOperator
-from airflow.settings import Session
 from airflow.utils.state import State
 from airflow.utils.timezone import datetime
 from airflow.utils.types import DagRunType
 from tests.test_utils.config import conf_vars
-from tests.test_utils.db import clear_db_dags, clear_db_runs
+from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_task_fail
 
 DEV_NULL = '/dev/null'
 DEFAULT_DATE = datetime(2015, 1, 1)
@@ -65,43 +62,38 @@ class OperatorSubclass(BaseOperator):
         pass
 
 
-class TestCore(unittest.TestCase):
+class TestCore:
+    @staticmethod
+    def clean_db():
+        clear_db_task_fail()
+        clear_db_dags()
+        clear_db_runs()
+
     default_scheduler_args = {"num_runs": 1}
 
-    def setUp(self):
+    def setup_method(self):
+        self.clean_db()
         self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True, read_dags_from_db=False)
-        self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
-        self.dag = DAG(TEST_DAG_ID, default_args=self.args)
         self.dag_bash = self.dagbag.dags['example_bash_operator']
         self.runme_0 = self.dag_bash.get_task('runme_0')
         self.run_after_loop = self.dag_bash.get_task('run_after_loop')
         self.run_this_last = self.dag_bash.get_task('run_this_last')
 
-    def tearDown(self):
-        session = Session()
-        session.query(DagRun).filter(DagRun.dag_id == TEST_DAG_ID).delete(synchronize_session=False)
-        session.query(TaskInstance).filter(TaskInstance.dag_id == TEST_DAG_ID).delete(
-            synchronize_session=False
-        )
-        session.query(TaskFail).filter(TaskFail.dag_id == TEST_DAG_ID).delete(synchronize_session=False)
-        session.commit()
-        session.close()
-        clear_db_dags()
-        clear_db_runs()
+    def teardown_method(self):
+        self.clean_db()
 
-    def test_check_operators(self):
+    def test_check_operators(self, dag_maker):
 
         conn_id = "sqlite_default"
 
         captain_hook = BaseHook.get_hook(conn_id=conn_id)  # quite funny :D
         captain_hook.run("CREATE TABLE operator_test_table (a, b)")
         captain_hook.run("insert into operator_test_table values (1,2)")
-
-        self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE)
-        op = CheckOperator(
-            task_id='check', sql="select count(*) from operator_test_table", conn_id=conn_id, dag=self.dag
-        )
-
+        with dag_maker(TEST_DAG_ID) as dag:
+            op = CheckOperator(
+                task_id='check', sql="select count(*) from operator_test_table", conn_id=conn_id
+            )
+        dag_maker.create_dagrun(run_type=DagRunType.MANUAL)
         op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
         op = ValueCheckOperator(
@@ -110,7 +102,7 @@ class TestCore(unittest.TestCase):
             tolerance=0.1,
             conn_id=conn_id,
             sql="SELECT 100",
-            dag=self.dag,
+            dag=dag,
         )
         op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
@@ -122,61 +114,65 @@ class TestCore(unittest.TestCase):
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.are_dependents_done()
 
-    def test_illegal_args(self):
+    def test_illegal_args(self, dag_maker):
         """
         Tests that Operators reject illegal arguments
         """
         msg = 'Invalid arguments were passed to BashOperator (task_id: test_illegal_args).'
         with conf_vars({('operators', 'allow_illegal_arguments'): 'True'}):
             with pytest.warns(PendingDeprecationWarning) as warnings:
-                BashOperator(
-                    task_id='test_illegal_args',
-                    bash_command='echo success',
-                    dag=self.dag,
-                    illegal_argument_1234='hello?',
-                )
+                with dag_maker():
+                    BashOperator(
+                        task_id='test_illegal_args',
+                        bash_command='echo success',
+                        illegal_argument_1234='hello?',
+                    )
+                dag_maker.create_dagrun()
                 assert any(msg in str(w) for w in warnings)
 
-    def test_illegal_args_forbidden(self):
+    def test_illegal_args_forbidden(self, dag_maker):
         """
         Tests that operators raise exceptions on illegal arguments when
         illegal arguments are not allowed.
         """
         with pytest.raises(AirflowException) as ctx:
-            BashOperator(
-                task_id='test_illegal_args',
-                bash_command='echo success',
-                dag=self.dag,
-                illegal_argument_1234='hello?',
-            )
+            with dag_maker():
+                BashOperator(
+                    task_id='test_illegal_args',
+                    bash_command='echo success',
+                    illegal_argument_1234='hello?',
+                )
+            dag_maker.create_dagrun()
         assert 'Invalid arguments were passed to BashOperator (task_id: test_illegal_args).' in str(ctx.value)
 
-    def test_bash_operator(self):
-        op = BashOperator(task_id='test_bash_operator', bash_command="echo success", dag=self.dag)
-        self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE)
+    def test_bash_operator(self, dag_maker):
+        with dag_maker():
+            op = BashOperator(task_id='test_bash_operator', bash_command="echo success")
+        dag_maker.create_dagrun(run_type=DagRunType.MANUAL)
 
         op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
-    def test_bash_operator_multi_byte_output(self):
-        op = BashOperator(
-            task_id='test_multi_byte_bash_operator',
-            bash_command="echo \u2600",
-            dag=self.dag,
-            output_encoding='utf-8',
-        )
-        self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE)
+    def test_bash_operator_multi_byte_output(self, dag_maker):
+        with dag_maker():
+            op = BashOperator(
+                task_id='test_multi_byte_bash_operator',
+                bash_command="echo \u2600",
+                output_encoding='utf-8',
+            )
+        dag_maker.create_dagrun(run_type=DagRunType.MANUAL)
         op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
-    def test_bash_operator_kill(self):
+    def test_bash_operator_kill(self, dag_maker):
         import psutil
 
         sleep_time = "100%d" % os.getpid()
-        op = BashOperator(
-            task_id='test_bash_operator_kill',
-            execution_timeout=timedelta(seconds=1),
-            bash_command=f"/bin/bash -c 'sleep {sleep_time}'",
-            dag=self.dag,
-        )
+        with dag_maker():
+            op = BashOperator(
+                task_id='test_bash_operator_kill',
+                execution_timeout=timedelta(seconds=1),
+                bash_command=f"/bin/bash -c 'sleep {sleep_time}'",
+            )
+        dag_maker.create_dagrun()
         with pytest.raises(AirflowTaskTimeout):
             op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
         sleep(2)
@@ -188,67 +184,74 @@ class TestCore(unittest.TestCase):
             os.kill(pid, signal.SIGTERM)
             self.fail("BashOperator's subprocess still running after stopping on timeout!")
 
-    def test_on_failure_callback(self):
+    def test_on_failure_callback(self, dag_maker):
         mock_failure_callback = MagicMock()
 
-        op = BashOperator(
-            task_id='check_on_failure_callback',
-            bash_command="exit 1",
-            dag=self.dag,
-            on_failure_callback=mock_failure_callback,
-        )
+        with dag_maker():
+            op = BashOperator(
+                task_id='check_on_failure_callback',
+                bash_command="exit 1",
+                on_failure_callback=mock_failure_callback,
+            )
+        dag_maker.create_dagrun()
         with pytest.raises(AirflowException):
             op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
         mock_failure_callback.assert_called_once()
 
-    def test_dryrun(self):
-        op = BashOperator(task_id='test_dryrun', bash_command="echo success", dag=self.dag)
+    def test_dryrun(self, dag_maker):
+        with dag_maker():
+            op = BashOperator(task_id='test_dryrun', bash_command="echo success")
+        dag_maker.create_dagrun()
         op.dry_run()
 
-    def test_sqlite(self):
+    def test_sqlite(self, dag_maker):
         import airflow.providers.sqlite.operators.sqlite
 
-        op = airflow.providers.sqlite.operators.sqlite.SqliteOperator(
-            task_id='time_sqlite', sql="CREATE TABLE IF NOT EXISTS unitest (dummy VARCHAR(20))", dag=self.dag
-        )
-        self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE)
+        with dag_maker():
+            op = airflow.providers.sqlite.operators.sqlite.SqliteOperator(
+                task_id='time_sqlite',
+                sql="CREATE TABLE IF NOT EXISTS unitest (dummy VARCHAR(20))",
+            )
+        dag_maker.create_dagrun(run_type=DagRunType.MANUAL)
         op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
-    def test_timeout(self):
-        op = PythonOperator(
-            task_id='test_timeout',
-            execution_timeout=timedelta(seconds=1),
-            python_callable=lambda: sleep(5),
-            dag=self.dag,
-        )
+    def test_timeout(self, dag_maker):
+        with dag_maker():
+            op = PythonOperator(
+                task_id='test_timeout',
+                execution_timeout=timedelta(seconds=1),
+                python_callable=lambda: sleep(5),
+            )
+        dag_maker.create_dagrun()
         with pytest.raises(AirflowTaskTimeout):
             op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
-    def test_python_op(self):
+    def test_python_op(self, dag_maker):
         def test_py_op(templates_dict, ds, **kwargs):
             if not templates_dict['ds'] == ds:
                 raise Exception("failure")
 
-        op = PythonOperator(
-            task_id='test_py_op', python_callable=test_py_op, templates_dict={'ds': "{{ ds }}"}, dag=self.dag
-        )
-        self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE)
+        with dag_maker():
+            op = PythonOperator(
+                task_id='test_py_op', python_callable=test_py_op, templates_dict={'ds': "{{ ds }}"}
+            )
+        dag_maker.create_dagrun(run_type=DagRunType.MANUAL)
         op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
-    def test_complex_template(self):
+    def test_complex_template(self, dag_maker):
         def verify_templated_field(context):
             assert context['ti'].task.some_templated_field['bar'][1] == context['ds']
 
-        op = OperatorSubclass(
-            task_id='test_complex_template',
-            some_templated_field={'foo': '123', 'bar': ['baz', '{{ ds }}']},
-            dag=self.dag,
-        )
+        with dag_maker():
+            op = OperatorSubclass(
+                task_id='test_complex_template',
+                some_templated_field={'foo': '123', 'bar': ['baz', '{{ ds }}']},
+            )
         op.execute = verify_templated_field
-        self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE)
+        dag_maker.create_dagrun(run_type=DagRunType.MANUAL)
         op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
-    def test_template_non_bool(self):
+    def test_template_non_bool(self, dag_maker):
         """
         Test templates can handle objects with no sense of truthiness
         """
@@ -260,9 +263,9 @@ class TestCore(unittest.TestCase):
             def __bool__(self):
                 return NotImplemented
 
-        op = OperatorSubclass(
-            task_id='test_bad_template_obj', some_templated_field=NonBoolObject(), dag=self.dag
-        )
+        with dag_maker():
+            op = OperatorSubclass(task_id='test_bad_template_obj', some_templated_field=NonBoolObject())
+        dag_maker.create_dagrun()
         op.resolve_template_files()
 
     def test_task_get_template(self):
@@ -312,9 +315,11 @@ class TestCore(unittest.TestCase):
         )
         ti.run(ignore_ti_state=True)
 
-    def test_bad_trigger_rule(self):
+    def test_bad_trigger_rule(self, dag_maker):
         with pytest.raises(AirflowException):
-            DummyOperator(task_id='test_bad_trigger', trigger_rule="non_existent", dag=self.dag)
+            with dag_maker():
+                DummyOperator(task_id='test_bad_trigger', trigger_rule="non_existent")
+            dag_maker.create_dagrun()
 
     def test_terminate_task(self):
         """If a task instance's db state get deleted, it should fail"""
@@ -353,17 +358,16 @@ class TestCore(unittest.TestCase):
         assert State.FAILED == ti.state
         session.close()
 
-    def test_task_fail_duration(self):
+    def test_task_fail_duration(self, dag_maker):
         """If a task fails, the duration should be recorded in TaskFail"""
-
-        op1 = BashOperator(task_id='pass_sleepy', bash_command='sleep 3', dag=self.dag)
-        op2 = BashOperator(
-            task_id='fail_sleepy',
-            bash_command='sleep 5',
-            execution_timeout=timedelta(seconds=3),
-            retry_delay=timedelta(seconds=0),
-            dag=self.dag,
-        )
+        with dag_maker() as dag:
+            op1 = BashOperator(task_id='pass_sleepy', bash_command='sleep 3')
+            op2 = BashOperator(
+                task_id='fail_sleepy',
+                bash_command='sleep 5',
+                execution_timeout=timedelta(seconds=3),
+                retry_delay=timedelta(seconds=0),
+            )
         session = settings.Session()
         try:
             op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
@@ -375,12 +379,12 @@ class TestCore(unittest.TestCase):
             pass
         op1_fails = (
             session.query(TaskFail)
-            .filter_by(task_id='pass_sleepy', dag_id=self.dag.dag_id, execution_date=DEFAULT_DATE)
+            .filter_by(task_id='pass_sleepy', dag_id=dag.dag_id, execution_date=DEFAULT_DATE)
             .all()
         )
         op2_fails = (
             session.query(TaskFail)
-            .filter_by(task_id='fail_sleepy', dag_id=self.dag.dag_id, execution_date=DEFAULT_DATE)
+            .filter_by(task_id='fail_sleepy', dag_id=dag.dag_id, execution_date=DEFAULT_DATE)
             .all()
         )
 
@@ -388,7 +392,7 @@ class TestCore(unittest.TestCase):
         assert 1 == len(op2_fails)
         assert sum(f.duration for f in op2_fails) >= 3
 
-    def test_externally_triggered_dagrun(self):
+    def test_externally_triggered_dagrun(self, dag_maker):
         TI = TaskInstance
 
         # Create the dagrun between two "scheduled" execution dates of the DAG
@@ -396,14 +400,11 @@ class TestCore(unittest.TestCase):
         execution_ds = execution_date.strftime('%Y-%m-%d')
         execution_ds_nodash = execution_ds.replace('-', '')
 
-        dag = DAG(
-            TEST_DAG_ID, default_args=self.args, schedule_interval=timedelta(weeks=1), start_date=DEFAULT_DATE
-        )
-        task = DummyOperator(task_id='test_externally_triggered_dag_context', dag=dag)
-        dag.create_dagrun(
+        with dag_maker(schedule_interval=timedelta(weeks=1)):
+            task = DummyOperator(task_id='test_externally_triggered_dag_context')
+        dag_maker.create_dagrun(
             run_type=DagRunType.SCHEDULED,
             execution_date=execution_date,
-            state=State.RUNNING,
             external_trigger=True,
         )
         task.run(start_date=execution_date, end_date=execution_date)
@@ -418,7 +419,7 @@ class TestCore(unittest.TestCase):
         assert context['prev_ds'] == execution_ds
         assert context['prev_ds_nodash'] == execution_ds_nodash
 
-    def test_dag_params_and_task_params(self):
+    def test_dag_params_and_task_params(self, dag_maker):
         # This test case guards how params of DAG and Operator work together.
         # - If any key exists in either DAG's or Operator's params,
         #   it is guaranteed to be available eventually.
@@ -426,23 +427,17 @@ class TestCore(unittest.TestCase):
         #   the latter has precedence.
         TI = TaskInstance
 
-        dag = DAG(
-            TEST_DAG_ID,
-            default_args=self.args,
+        with dag_maker(
             schedule_interval=timedelta(weeks=1),
-            start_date=DEFAULT_DATE,
             params={'key_1': 'value_1', 'key_2': 'value_2_old'},
-        )
-        task1 = DummyOperator(
-            task_id='task1',
-            dag=dag,
-            params={'key_2': 'value_2_new', 'key_3': 'value_3'},
-        )
-        task2 = DummyOperator(task_id='task2', dag=dag)
-        dag.create_dagrun(
+        ):
+            task1 = DummyOperator(
+                task_id='task1',
+                params={'key_2': 'value_2_new', 'key_3': 'value_3'},
+            )
+            task2 = DummyOperator(task_id='task2')
+        dag_maker.create_dagrun(
             run_type=DagRunType.SCHEDULED,
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
             external_trigger=True,
         )
         task1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -456,30 +451,27 @@ class TestCore(unittest.TestCase):
         assert context2['params'] == {'key_1': 'value_1', 'key_2': 'value_2_old'}
 
 
-@pytest.fixture()
-def dag():
-    return DAG(TEST_DAG_ID, default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE})
-
-
-def test_operator_retries_invalid(dag):
+def test_operator_retries_invalid(dag_maker):
     with pytest.raises(AirflowException) as ctx:
-        BashOperator(
-            task_id='test_illegal_args',
-            bash_command='echo success',
-            dag=dag,
-            retries='foo',
-        )
+        with dag_maker():
+            BashOperator(
+                task_id='test_illegal_args',
+                bash_command='echo success',
+                retries='foo',
+            )
+        dag_maker.create_dagrun()
     assert str(ctx.value) == "'retries' type must be int, not str"
 
 
-def test_operator_retries_coerce(caplog, dag):
+def test_operator_retries_coerce(caplog, dag_maker):
     with caplog.at_level(logging.WARNING):
-        BashOperator(
-            task_id='test_illegal_args',
-            bash_command='echo success',
-            dag=dag,
-            retries='1',
-        )
+        with dag_maker():
+            BashOperator(
+                task_id='test_illegal_args',
+                bash_command='echo success',
+                retries='1',
+            )
+        dag_maker.create_dagrun()
     assert caplog.record_tuples == [
         (
             "airflow.operators.bash.BashOperator",
@@ -490,12 +482,13 @@ def test_operator_retries_coerce(caplog, dag):
 
 
 @pytest.mark.parametrize("retries", [None, 5])
-def test_operator_retries(caplog, dag, retries):
+def test_operator_retries(caplog, dag_maker, retries):
     with caplog.at_level(logging.WARNING):
-        BashOperator(
-            task_id='test_illegal_args',
-            bash_command='echo success',
-            dag=dag,
-            retries=retries,
-        )
+        with dag_maker(TEST_DAG_ID + str(retries)):
+            BashOperator(
+                task_id='test_illegal_args',
+                bash_command='echo success',
+                retries=retries,
+            )
+        dag_maker.create_dagrun()
     assert caplog.records == []