You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by pi...@apache.org on 2023/03/06 21:47:06 UTC
[airflow] 22/37: Refactor python operators/sensor tests (#28493)
This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch v2-5-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 061338fad1a9ec4bf12b1aad482b3a72f7d3551c
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Thu Dec 22 12:32:06 2022 +0400
Refactor python operators/sensor tests (#28493)
(cherry picked from commit 884fca8d114ce8e0c982747937a1014f3b5e7491)
---
tests/conftest.py | 8 +-
tests/decorators/test_python.py | 143 +---
tests/decorators/test_python_virtualenv.py | 13 -
tests/operators/test_python.py | 1003 ++++++++++------------------
tests/sensors/test_python.py | 124 +---
5 files changed, 426 insertions(+), 865 deletions(-)
diff --git a/tests/conftest.py b/tests/conftest.py
index 0d4d1170f0..d71d8eb0f0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -22,6 +22,7 @@ import subprocess
import sys
from contextlib import ExitStack, suppress
from datetime import datetime, timedelta
+from typing import TYPE_CHECKING
import freezegun
import pytest
@@ -46,6 +47,9 @@ from tests.test_utils.perf.perf_kit.sqlalchemy import ( # noqa isort:skip
trace_queries,
)
+if TYPE_CHECKING:
+ from airflow.models.taskinstance import TaskInstance
+
@pytest.fixture()
def reset_environment():
@@ -741,7 +745,7 @@ def create_task_instance(dag_maker, create_dummy_dag):
run_type=None,
data_interval=None,
**kwargs,
- ):
+ ) -> TaskInstance:
if execution_date is None:
from airflow.utils import timezone
@@ -775,7 +779,7 @@ def create_task_instance_of_operator(dag_maker):
execution_date=None,
session=None,
**operator_kwargs,
- ):
+ ) -> TaskInstance:
with dag_maker(dag_id=dag_id, session=session):
operator_class(**operator_kwargs)
if execution_date is None:
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index 47a908db77..1bbad51a0b 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -37,41 +37,20 @@ from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.task_group import TaskGroup
from airflow.utils.types import DagRunType
-from tests.operators.test_python import Call, assert_calls_equal, build_recording_function
-from tests.test_utils.db import clear_db_runs
+from tests.operators.test_python import BasePythonTest
DEFAULT_DATE = timezone.datetime(2016, 1, 1)
-END_DATE = timezone.datetime(2016, 1, 2)
-INTERVAL = timedelta(hours=12)
-FROZEN_NOW = timezone.datetime(2016, 1, 2, 12, 1, 1)
-TI_CONTEXT_ENV_VARS = [
- "AIRFLOW_CTX_DAG_ID",
- "AIRFLOW_CTX_TASK_ID",
- "AIRFLOW_CTX_EXECUTION_DATE",
- "AIRFLOW_CTX_DAG_RUN_ID",
-]
-
-class TestAirflowTaskDecorator:
- def setup_class(self):
- clear_db_runs()
-
- def setup_method(self):
- self.dag = DAG("test_dag", default_args={"owner": "airflow", "start_date": DEFAULT_DATE})
- self.run = False
-
- def teardown_method(self):
- self.dag.clear()
- self.run = False
- clear_db_runs()
+class TestAirflowTaskDecorator(BasePythonTest):
+ default_date = DEFAULT_DATE
def test_python_operator_python_callable_is_callable(self):
"""Tests that @task will only instantiate if
the python_callable argument is callable."""
not_callable = {}
with pytest.raises(TypeError):
- task_decorator(not_callable, dag=self.dag)
+ task_decorator(not_callable)
@pytest.mark.parametrize(
"resolve",
@@ -155,13 +134,7 @@ class TestAirflowTaskDecorator:
with self.dag:
res = identity2(8, 4)
- dr = self.dag.create_dagrun(
- run_id=DagRunType.MANUAL.value,
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
-
+ dr = self.create_dag_run()
res.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
ti = dr.get_task_instances()[0]
@@ -179,13 +152,7 @@ class TestAirflowTaskDecorator:
with self.dag:
ident = identity_tuple(35, 36)
- dr = self.dag.create_dagrun(
- run_id=DagRunType.MANUAL.value,
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
-
+ dr = self.create_dag_run()
ident.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
ti = dr.get_task_instances()[0]
@@ -227,15 +194,9 @@ class TestAirflowTaskDecorator:
with self.dag:
ret = add_number(2)
- self.dag.create_dagrun(
- run_id=DagRunType.MANUAL,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
+ self.create_dag_run()
with pytest.raises(AirflowException):
-
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
def test_fail_multiple_outputs_no_dict(self):
@@ -245,84 +206,53 @@ class TestAirflowTaskDecorator:
with self.dag:
ret = add_number(2)
- self.dag.create_dagrun(
- run_id=DagRunType.MANUAL,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
+ self.create_dag_run()
with pytest.raises(AirflowException):
-
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
def test_python_callable_arguments_are_templatized(self):
"""Test @task op_args are templatized"""
- recorded_calls = []
+
+ @task_decorator
+ def arg_task(*args):
+ raise RuntimeError("Should not executed")
# Create a named tuple and ensure it is still preserved
# after the rendering is done
Named = namedtuple("Named", ["var1", "var2"])
named_tuple = Named("{{ ds }}", "unchanged")
- task = task_decorator(
- # a Mock instance cannot be used as a callable function or test fails with a
- # TypeError: Object of type Mock is not JSON serializable
- build_recording_function(recorded_calls),
- dag=self.dag,
- )
- ret = task(4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple)
-
- self.dag.create_dagrun(
- run_id=DagRunType.MANUAL,
- execution_date=DEFAULT_DATE,
- data_interval=(DEFAULT_DATE, DEFAULT_DATE),
- start_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
- ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ with self.dag:
+ ret = arg_task(4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple)
- ds_templated = DEFAULT_DATE.date().isoformat()
- assert len(recorded_calls) == 1
- assert_calls_equal(
- recorded_calls[0],
- Call(
- 4,
- date(2019, 1, 1),
- f"dag {self.dag.dag_id} ran on {ds_templated}.",
- Named(ds_templated, "unchanged"),
- ),
- )
+ dr = self.create_dag_run()
+ ti = TaskInstance(task=ret.operator, run_id=dr.run_id)
+ rendered_op_args = ti.render_templates().op_args
+ assert len(rendered_op_args) == 4
+ assert rendered_op_args[0] == 4
+ assert rendered_op_args[1] == date(2019, 1, 1)
+ assert rendered_op_args[2] == f"dag {self.dag_id} ran on {self.ds_templated}."
+ assert rendered_op_args[3] == Named(self.ds_templated, "unchanged")
def test_python_callable_keyword_arguments_are_templatized(self):
"""Test PythonOperator op_kwargs are templatized"""
- recorded_calls = []
- task = task_decorator(
- # a Mock instance cannot be used as a callable function or test fails with a
- # TypeError: Object of type Mock is not JSON serializable
- build_recording_function(recorded_calls),
- dag=self.dag,
- )
- ret = task(an_int=4, a_date=date(2019, 1, 1), a_templated_string="dag {{dag.dag_id}} ran on {{ds}}.")
- self.dag.create_dagrun(
- run_id=DagRunType.MANUAL,
- execution_date=DEFAULT_DATE,
- data_interval=(DEFAULT_DATE, DEFAULT_DATE),
- start_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
- ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ @task_decorator
+ def kwargs_task(an_int, a_date, a_templated_string):
+ raise RuntimeError("Should not executed")
- assert len(recorded_calls) == 1
- assert_calls_equal(
- recorded_calls[0],
- Call(
- an_int=4,
- a_date=date(2019, 1, 1),
- a_templated_string=f"dag {self.dag.dag_id} ran on {DEFAULT_DATE.date().isoformat()}.",
- ),
- )
+ with self.dag:
+ ret = kwargs_task(
+ an_int=4, a_date=date(2019, 1, 1), a_templated_string="dag {{dag.dag_id}} ran on {{ds}}."
+ )
+
+ dr = self.create_dag_run()
+ ti = TaskInstance(task=ret.operator, run_id=dr.run_id)
+ rendered_op_kwargs = ti.render_templates().op_kwargs
+ assert rendered_op_kwargs["an_int"] == 4
+ assert rendered_op_kwargs["a_date"] == date(2019, 1, 1)
+ assert rendered_op_kwargs["a_templated_string"] == f"dag {self.dag_id} ran on {self.ds_templated}."
def test_manual_task_id(self):
"""Test manually setting task_id"""
@@ -415,6 +345,7 @@ class TestAirflowTaskDecorator:
def do_run():
return 4
+ self.dag.default_args["owner"] = "airflow"
with self.dag:
ret = do_run()
assert ret.operator.owner == "airflow"
diff --git a/tests/decorators/test_python_virtualenv.py b/tests/decorators/test_python_virtualenv.py
index 032ec34aa5..88121c5db3 100644
--- a/tests/decorators/test_python_virtualenv.py
+++ b/tests/decorators/test_python_virtualenv.py
@@ -19,7 +19,6 @@ from __future__ import annotations
import datetime
import sys
-from datetime import timedelta
from subprocess import CalledProcessError
import pytest
@@ -28,18 +27,6 @@ from airflow.decorators import task
from airflow.utils import timezone
DEFAULT_DATE = timezone.datetime(2016, 1, 1)
-END_DATE = timezone.datetime(2016, 1, 2)
-INTERVAL = timedelta(hours=12)
-FROZEN_NOW = timezone.datetime(2016, 1, 2, 12, 1, 1)
-
-TI_CONTEXT_ENV_VARS = [
- "AIRFLOW_CTX_DAG_ID",
- "AIRFLOW_CTX_TASK_ID",
- "AIRFLOW_CTX_EXECUTION_DATE",
- "AIRFLOW_CTX_DAG_RUN_ID",
-]
-
-
PYTHON_VERSION = sys.version_info[0]
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index c011c5cb35..8f6c089f08 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -20,16 +20,18 @@ from __future__ import annotations
import copy
import logging
import os
+import re
import sys
-import unittest.mock
import warnings
from collections import namedtuple
from datetime import date, datetime, timedelta
from subprocess import CalledProcessError
+from unittest import mock
import pytest
+from slugify import slugify
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.models import DAG, DagRun, TaskInstance as TI
from airflow.models.baseoperator import BaseOperator
from airflow.models.taskinstance import clear_task_instances, set_current_context
@@ -45,88 +47,108 @@ from airflow.utils import timezone
from airflow.utils.context import AirflowContextDeprecationWarning, Context
from airflow.utils.python_virtualenv import prepare_virtualenv
from airflow.utils.session import create_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, State
from airflow.utils.trigger_rule import TriggerRule
-from airflow.utils.types import DagRunType
+from airflow.utils.types import NOTSET, DagRunType
from tests.test_utils import AIRFLOW_MAIN_FOLDER
from tests.test_utils.db import clear_db_runs
DEFAULT_DATE = timezone.datetime(2016, 1, 1)
-END_DATE = timezone.datetime(2016, 1, 2)
-INTERVAL = timedelta(hours=12)
-FROZEN_NOW = timezone.datetime(2016, 1, 2, 12, 1, 1)
-
-TI_CONTEXT_ENV_VARS = [
- "AIRFLOW_CTX_DAG_ID",
- "AIRFLOW_CTX_TASK_ID",
- "AIRFLOW_CTX_EXECUTION_DATE",
- "AIRFLOW_CTX_DAG_RUN_ID",
-]
-
TEMPLATE_SEARCHPATH = os.path.join(AIRFLOW_MAIN_FOLDER, "tests", "config_templates")
+LOGGER_NAME = "airflow.task.operators"
-class Call:
- def __init__(self, *args, **kwargs):
- self.args = args
- self.kwargs = kwargs
-
-
-def build_recording_function(calls_collection):
- """
- We can not use a Mock instance as a PythonOperator callable function or some tests fail with a
- TypeError: Object of type Mock is not JSON serializable
- Then using this custom function recording custom Call objects for further testing
- (replacing Mock.assert_called_with assertion method)
- """
-
- def recording_function(*args, **kwargs):
- calls_collection.append(Call(*args, **kwargs))
-
- return recording_function
-
+class BasePythonTest:
+ """Base test class for TestPythonOperator and TestPythonSensor classes"""
-def assert_calls_equal(first: Call, second: Call) -> None:
- assert isinstance(first, Call)
- assert isinstance(second, Call)
- assert first.args == second.args
- # eliminate context (conf, dag_run, task_instance, etc.)
- test_args = ["an_int", "a_date", "a_templated_string"]
- first.kwargs = {key: value for (key, value) in first.kwargs.items() if key in test_args}
- second.kwargs = {key: value for (key, value) in second.kwargs.items() if key in test_args}
- assert first.kwargs == second.kwargs
+ opcls: type[BaseOperator]
+ dag_id: str
+ task_id: str
+ run_id: str
+ dag: DAG
+ ds_templated: str
+ default_date: datetime = DEFAULT_DATE
+
+ @pytest.fixture(autouse=True)
+ def base_tests_setup(self, request, create_task_instance_of_operator, dag_maker):
+ self.dag_id = f"dag_{slugify(request.cls.__name__)}"
+ self.task_id = f"task_{slugify(request.node.name, max_length=40)}"
+ self.run_id = f"run_{slugify(request.node.name, max_length=40)}"
+ self.ds_templated = self.default_date.date().isoformat()
+ self.ti_maker = create_task_instance_of_operator
+ self.dag_maker = dag_maker
+ self.dag = self.dag_maker(self.dag_id, template_searchpath=TEMPLATE_SEARCHPATH).dag
+ clear_db_runs()
+ yield
+ clear_db_runs()
+
+ @staticmethod
+ def assert_expected_task_states(dag_run: DagRun, expected_states: dict):
+ """Helper function that asserts `TaskInstances` of a given `task_id` are in a given state."""
+ asserts = []
+ for ti in dag_run.get_task_instances():
+ try:
+ expected = expected_states[ti.task_id]
+ except KeyError:
+ asserts.append(f"Unexpected task id {ti.task_id!r} found, expected {expected_states.keys()}")
+ continue
+
+ if ti.state != expected:
+ asserts.append(f"Task {ti.task_id!r} has state {ti.state!r} instead of expected {expected!r}")
+ if asserts:
+ pytest.fail("\n".join(asserts))
+
+ @staticmethod
+ def default_kwargs(**kwargs):
+ """Default arguments for specific Operator."""
+ return kwargs
+
+ def create_dag_run(self) -> DagRun:
+ return self.dag.create_dagrun(
+ state=DagRunState.RUNNING,
+ start_date=self.dag_maker.start_date,
+ session=self.dag_maker.session,
+ execution_date=self.default_date,
+ run_type=DagRunType.MANUAL,
+ )
+ def create_ti(self, fn, **kwargs) -> TI:
+ """Create TaskInstance for class defined Operator."""
+ return self.ti_maker(
+ self.opcls,
+ python_callable=fn,
+ **self.default_kwargs(**kwargs),
+ dag_id=self.dag_id,
+ task_id=self.task_id,
+ execution_date=self.default_date,
+ )
-class TestPythonBase(unittest.TestCase):
- """Base test class for TestPythonOperator and TestPythonSensor classes"""
+ def run_as_operator(self, fn, **kwargs):
+ """Run task by direct call ``run`` method."""
+ with self.dag:
+ task = self.opcls(task_id=self.task_id, python_callable=fn, **self.default_kwargs(**kwargs))
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
+ task.run(start_date=self.default_date, end_date=self.default_date)
+ return task
- with create_session() as session:
- session.query(DagRun).delete()
- session.query(TI).delete()
+ def run_as_task(self, fn, **kwargs):
+ """Create TaskInstance and run it."""
+ ti = self.create_ti(fn, **kwargs)
+ ti.run()
+ return ti.task
- def setUp(self):
- super().setUp()
- self.dag = DAG("test_dag", default_args={"owner": "airflow", "start_date": DEFAULT_DATE})
- self.addCleanup(self.dag.clear)
- self.clear_run()
- self.addCleanup(self.clear_run)
+ def render_templates(self, fn, **kwargs):
+ """Create TaskInstance and render templates without actual run."""
+ return self.create_ti(fn, **kwargs).render_templates()
- def tearDown(self):
- super().tearDown()
- with create_session() as session:
- session.query(DagRun).delete()
- session.query(TI).delete()
+class TestPythonOperator(BasePythonTest):
+ opcls = PythonOperator
- def clear_run(self):
+ @pytest.fixture(autouse=True)
+ def setup_tests(self):
self.run = False
-
-class TestPythonOperator(TestPythonBase):
def do_run(self):
self.run = True
@@ -135,105 +157,58 @@ class TestPythonOperator(TestPythonBase):
def test_python_operator_run(self):
"""Tests that the python callable is invoked on task run."""
- task = PythonOperator(python_callable=self.do_run, task_id="python_operator", dag=self.dag)
+ ti = self.create_ti(self.do_run)
assert not self.is_run()
- task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ ti.run()
assert self.is_run()
- def test_python_operator_python_callable_is_callable(self):
- """Tests that PythonOperator will only instantiate if
- the python_callable argument is callable."""
- not_callable = {}
- with pytest.raises(AirflowException):
- PythonOperator(python_callable=not_callable, task_id="python_operator", dag=self.dag)
- not_callable = None
- with pytest.raises(AirflowException):
- PythonOperator(python_callable=not_callable, task_id="python_operator", dag=self.dag)
+ @pytest.mark.parametrize("not_callable", [{}, None])
+ def test_python_operator_python_callable_is_callable(self, not_callable):
+ """Tests that PythonOperator will only instantiate if the python_callable argument is callable."""
+ with pytest.raises(AirflowException, match="`python_callable` param must be callable"):
+ PythonOperator(python_callable=not_callable, task_id="python_operator")
def test_python_callable_arguments_are_templatized(self):
"""Test PythonOperator op_args are templatized"""
- recorded_calls = []
-
# Create a named tuple and ensure it is still preserved
# after the rendering is done
Named = namedtuple("Named", ["var1", "var2"])
named_tuple = Named("{{ ds }}", "unchanged")
- task = PythonOperator(
- task_id="python_operator",
- # a Mock instance cannot be used as a callable function or test fails with a
- # TypeError: Object of type Mock is not JSON serializable
- python_callable=build_recording_function(recorded_calls),
+ task = self.render_templates(
+ lambda: 0,
op_args=[4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple],
- dag=self.dag,
- )
-
- self.dag.create_dagrun(
- run_type=DagRunType.MANUAL,
- execution_date=DEFAULT_DATE,
- data_interval=(DEFAULT_DATE, DEFAULT_DATE),
- start_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
- task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
- ds_templated = DEFAULT_DATE.date().isoformat()
- assert 1 == len(recorded_calls)
- assert_calls_equal(
- recorded_calls[0],
- Call(
- 4,
- date(2019, 1, 1),
- f"dag {self.dag.dag_id} ran on {ds_templated}.",
- Named(ds_templated, "unchanged"),
- ),
)
+ rendered_op_args = task.op_args
+ assert len(rendered_op_args) == 4
+ assert rendered_op_args[0] == 4
+ assert rendered_op_args[1] == date(2019, 1, 1)
+ assert rendered_op_args[2] == f"dag {self.dag_id} ran on {self.ds_templated}."
+ assert rendered_op_args[3] == Named(self.ds_templated, "unchanged")
def test_python_callable_keyword_arguments_are_templatized(self):
"""Test PythonOperator op_kwargs are templatized"""
- recorded_calls = []
-
- task = PythonOperator(
- task_id="python_operator",
- # a Mock instance cannot be used as a callable function or test fails with a
- # TypeError: Object of type Mock is not JSON serializable
- python_callable=build_recording_function(recorded_calls),
+ task = self.render_templates(
+ lambda: 0,
op_kwargs={
"an_int": 4,
"a_date": date(2019, 1, 1),
"a_templated_string": "dag {{dag.dag_id}} ran on {{ds}}.",
},
- dag=self.dag,
- )
-
- self.dag.create_dagrun(
- run_type=DagRunType.MANUAL,
- execution_date=DEFAULT_DATE,
- data_interval=(DEFAULT_DATE, DEFAULT_DATE),
- start_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
- task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
- assert 1 == len(recorded_calls)
- assert_calls_equal(
- recorded_calls[0],
- Call(
- an_int=4,
- a_date=date(2019, 1, 1),
- a_templated_string=f"dag {self.dag.dag_id} ran on {DEFAULT_DATE.date().isoformat()}.",
- ),
)
+ rendered_op_kwargs = task.op_kwargs
+ assert rendered_op_kwargs["an_int"] == 4
+ assert rendered_op_kwargs["a_date"] == date(2019, 1, 1)
+ assert rendered_op_kwargs["a_templated_string"] == f"dag {self.dag_id} ran on {self.ds_templated}."
def test_python_operator_shallow_copy_attr(self):
def not_callable(x):
- return x
+ assert False, "Should not be triggered"
original_task = PythonOperator(
python_callable=not_callable,
- task_id="python_operator",
op_kwargs={"certain_attrs": ""},
- dag=self.dag,
+ task_id=self.task_id,
)
new_task = copy.deepcopy(original_task)
# shallow copy op_kwargs
@@ -242,383 +217,213 @@ class TestPythonOperator(TestPythonBase):
assert id(original_task.python_callable) == id(new_task.python_callable)
def test_conflicting_kwargs(self):
- self.dag.create_dagrun(
- run_type=DagRunType.MANUAL,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- state=State.RUNNING,
- external_trigger=False,
- )
-
# dag is not allowed since it is a reserved keyword
def func(dag):
- # An ValueError should be triggered since we're using dag as a
- # reserved keyword
+ # A ValueError should be triggered since we're using dag as a reserved keyword
raise RuntimeError(f"Should not be triggered, dag: {dag}")
- python_operator = PythonOperator(
- task_id="python_operator", op_args=[1], python_callable=func, dag=self.dag
- )
-
- with pytest.raises(ValueError) as ctx:
- python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- assert "dag" in str(ctx.value), "'dag' not found in the exception"
+ ti = self.create_ti(func, op_args=[1])
+ error_message = re.escape("The key 'dag' in args is a part of kwargs and therefore reserved.")
+ with pytest.raises(ValueError, match=error_message):
+ ti.run()
def test_provide_context_does_not_fail(self):
- """
- ensures that provide_context doesn't break dags in 2.0
- """
- self.dag.create_dagrun(
- run_type=DagRunType.MANUAL,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- state=State.RUNNING,
- external_trigger=False,
- )
+ """Ensures that provide_context doesn't break dags in 2.0."""
def func(custom, dag):
assert 1 == custom, "custom should be 1"
assert dag is not None, "dag should be set"
- python_operator = PythonOperator(
- task_id="python_operator",
- op_kwargs={"custom": 1},
- python_callable=func,
- provide_context=True,
- dag=self.dag,
- )
- python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ with pytest.warns(RemovedInAirflow3Warning):
+ self.run_as_task(func, op_kwargs={"custom": 1}, provide_context=True)
def test_context_with_conflicting_op_args(self):
- self.dag.create_dagrun(
- run_type=DagRunType.MANUAL,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- state=State.RUNNING,
- external_trigger=False,
- )
-
def func(custom, dag):
assert 1 == custom, "custom should be 1"
assert dag is not None, "dag should be set"
- python_operator = PythonOperator(
- task_id="python_operator", op_kwargs={"custom": 1}, python_callable=func, dag=self.dag
- )
- python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ self.run_as_task(func, op_kwargs={"custom": 1})
def test_context_with_kwargs(self):
- self.dag.create_dagrun(
- run_type=DagRunType.MANUAL,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- state=State.RUNNING,
- external_trigger=False,
- )
-
def func(**context):
# check if context is being set
assert len(context) > 0, "Context has not been injected"
- python_operator = PythonOperator(
- task_id="python_operator", op_kwargs={"custom": 1}, python_callable=func, dag=self.dag
- )
- python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
- def test_return_value_log_with_show_return_value_in_logs_default(self):
- self.dag.create_dagrun(
- run_type=DagRunType.MANUAL,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- state=State.RUNNING,
- external_trigger=False,
- )
-
- def func():
- return "test_return_value"
-
- python_operator = PythonOperator(task_id="python_operator", python_callable=func, dag=self.dag)
-
- with self.assertLogs("airflow.task.operators", level=logging.INFO) as cm:
- python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ self.run_as_task(func, op_kwargs={"custom": 1})
- assert (
- "INFO:airflow.task.operators:Done. Returned value was: test_return_value" in cm.output
- ), "Return value should be shown"
-
- def test_return_value_log_with_show_return_value_in_logs_false(self):
- self.dag.create_dagrun(
- run_type=DagRunType.MANUAL,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- state=State.RUNNING,
- external_trigger=False,
- )
+ @pytest.mark.parametrize(
+ "show_return_value_in_logs, should_shown",
+ [
+ pytest.param(NOTSET, True, id="default"),
+ pytest.param(True, True, id="show"),
+ pytest.param(False, False, id="hide"),
+ ],
+ )
+ def test_return_value_log(self, show_return_value_in_logs, should_shown, caplog):
+ caplog.set_level(logging.INFO, logger=LOGGER_NAME)
def func():
return "test_return_value"
- python_operator = PythonOperator(
- task_id="python_operator",
- python_callable=func,
- dag=self.dag,
- show_return_value_in_logs=False,
- )
-
- with self.assertLogs("airflow.task.operators", level=logging.INFO) as cm:
- python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
- assert (
- "INFO:airflow.task.operators:Done. Returned value was: test_return_value" not in cm.output
- ), "Return value should not be shown"
- assert (
- "INFO:airflow.task.operators:Done. Returned value not shown" in cm.output
- ), "Log message that the option is turned off should be shown"
+ if show_return_value_in_logs is NOTSET:
+ self.run_as_task(func)
+ else:
+ self.run_as_task(func, show_return_value_in_logs=show_return_value_in_logs)
+ if should_shown:
+ assert "Done. Returned value was: test_return_value" in caplog.messages
+ assert "Done. Returned value not shown" not in caplog.messages
+ else:
+ assert "Done. Returned value was: test_return_value" not in caplog.messages
+ assert "Done. Returned value not shown" in caplog.messages
-class TestBranchOperator(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
-
- with create_session() as session:
- session.query(DagRun).delete()
- session.query(TI).delete()
-
- def setUp(self):
- self.dag = DAG(
- "branch_operator_test",
- default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
- schedule=INTERVAL,
- )
- self.branch_1 = EmptyOperator(task_id="branch_1", dag=self.dag)
- self.branch_2 = EmptyOperator(task_id="branch_2", dag=self.dag)
- self.branch_3 = None
+class TestBranchOperator(BasePythonTest):
+ opcls = BranchPythonOperator
- def tearDown(self):
- super().tearDown()
-
- with create_session() as session:
- session.query(DagRun).delete()
- session.query(TI).delete()
+ @pytest.fixture(autouse=True)
+ def setup_tests(self):
+ self.branch_1 = EmptyOperator(task_id="branch_1")
+ self.branch_2 = EmptyOperator(task_id="branch_2")
def test_with_dag_run(self):
- branch_op = BranchPythonOperator(
- task_id="make_choice", dag=self.dag, python_callable=lambda: "branch_1"
- )
-
- self.branch_1.set_upstream(branch_op)
- self.branch_2.set_upstream(branch_op)
- self.dag.clear()
+ with self.dag:
+ branch_op = BranchPythonOperator(task_id=self.task_id, python_callable=lambda: "branch_1")
+ branch_op >> [self.branch_1, self.branch_2]
- dr = self.dag.create_dagrun(
- run_type=DagRunType.MANUAL,
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
+ dr = self.create_dag_run()
+ branch_op.run(start_date=self.default_date, end_date=self.default_date)
+ self.assert_expected_task_states(
+ dr, {self.task_id: State.SUCCESS, "branch_1": State.NONE, "branch_2": State.SKIPPED}
)
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
- tis = dr.get_task_instances()
- for ti in tis:
- if ti.task_id == "make_choice":
- assert ti.state == State.SUCCESS
- elif ti.task_id == "branch_1":
- assert ti.state == State.NONE
- elif ti.task_id == "branch_2":
- assert ti.state == State.SKIPPED
- else:
- raise ValueError(f"Invalid task id {ti.task_id} found!")
-
def test_with_skip_in_branch_downstream_dependencies(self):
- branch_op = BranchPythonOperator(
- task_id="make_choice", dag=self.dag, python_callable=lambda: "branch_1"
- )
-
- branch_op >> self.branch_1 >> self.branch_2
- branch_op >> self.branch_2
- self.dag.clear()
+ with self.dag:
+ branch_op = BranchPythonOperator(task_id=self.task_id, python_callable=lambda: "branch_1")
+ branch_op >> self.branch_1 >> self.branch_2
+ branch_op >> self.branch_2
- dr = self.dag.create_dagrun(
- run_type=DagRunType.MANUAL,
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
+ dr = self.create_dag_run()
+ branch_op.run(start_date=self.default_date, end_date=self.default_date)
+ self.assert_expected_task_states(
+ dr, {self.task_id: State.SUCCESS, "branch_1": State.NONE, "branch_2": State.NONE}
)
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
- tis = dr.get_task_instances()
- for ti in tis:
- if ti.task_id == "make_choice":
- assert ti.state == State.SUCCESS
- elif ti.task_id == "branch_1":
- assert ti.state == State.NONE
- elif ti.task_id == "branch_2":
- assert ti.state == State.NONE
- else:
- raise ValueError(f"Invalid task id {ti.task_id} found!")
-
def test_with_skip_in_branch_downstream_dependencies2(self):
- branch_op = BranchPythonOperator(
- task_id="make_choice", dag=self.dag, python_callable=lambda: "branch_2"
- )
-
- branch_op >> self.branch_1 >> self.branch_2
- branch_op >> self.branch_2
- self.dag.clear()
+ with self.dag:
+ branch_op = BranchPythonOperator(task_id=self.task_id, python_callable=lambda: "branch_2")
+ branch_op >> self.branch_1 >> self.branch_2
+ branch_op >> self.branch_2
- dr = self.dag.create_dagrun(
- run_type=DagRunType.MANUAL,
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
+ dr = self.create_dag_run()
+ branch_op.run(start_date=self.default_date, end_date=self.default_date)
+ self.assert_expected_task_states(
+ dr, {self.task_id: State.SUCCESS, "branch_1": State.SKIPPED, "branch_2": State.NONE}
)
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
- tis = dr.get_task_instances()
- for ti in tis:
- if ti.task_id == "make_choice":
- assert ti.state == State.SUCCESS
- elif ti.task_id == "branch_1":
- assert ti.state == State.SKIPPED
- elif ti.task_id == "branch_2":
- assert ti.state == State.NONE
- else:
- raise ValueError(f"Invalid task id {ti.task_id} found!")
-
def test_xcom_push(self):
- branch_op = BranchPythonOperator(
- task_id="make_choice", dag=self.dag, python_callable=lambda: "branch_1"
- )
-
- self.branch_1.set_upstream(branch_op)
- self.branch_2.set_upstream(branch_op)
- self.dag.clear()
-
- dr = self.dag.create_dagrun(
- run_type=DagRunType.MANUAL,
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
-
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ with self.dag:
+ branch_op = BranchPythonOperator(task_id=self.task_id, python_callable=lambda: "branch_1")
+ branch_op >> [self.branch_1, self.branch_2]
- tis = dr.get_task_instances()
- for ti in tis:
- if ti.task_id == "make_choice":
- assert ti.xcom_pull(task_ids="make_choice") == "branch_1"
+ dr = self.create_dag_run()
+ branch_op.run(start_date=self.default_date, end_date=self.default_date)
+ for ti in dr.get_task_instances():
+ if ti.task_id == self.task_id:
+ assert ti.xcom_pull(task_ids=self.task_id) == "branch_1"
+ break
+ else:
+ pytest.fail(f"{self.task_id!r} not found.")
def test_clear_skipped_downstream_task(self):
"""
After a downstream task is skipped by BranchPythonOperator, clearing the skipped task
should not cause it to be executed.
"""
- branch_op = BranchPythonOperator(
- task_id="make_choice", dag=self.dag, python_callable=lambda: "branch_1"
- )
- branches = [self.branch_1, self.branch_2]
- branch_op >> branches
- self.dag.clear()
-
- dr = self.dag.create_dagrun(
- run_type=DagRunType.MANUAL,
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
-
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ with self.dag:
+ branch_op = BranchPythonOperator(task_id=self.task_id, python_callable=lambda: "branch_1")
+ branches = [self.branch_1, self.branch_2]
+ branch_op >> branches
+ dr = self.create_dag_run()
+ branch_op.run(start_date=self.default_date, end_date=self.default_date)
for task in branches:
- task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ task.run(start_date=self.default_date, end_date=self.default_date)
- tis = dr.get_task_instances()
- for ti in tis:
- if ti.task_id == "make_choice":
- assert ti.state == State.SUCCESS
- elif ti.task_id == "branch_1":
- assert ti.state == State.SUCCESS
- elif ti.task_id == "branch_2":
- assert ti.state == State.SKIPPED
- else:
- raise ValueError(f"Invalid task id {ti.task_id} found!")
+ expected_states = {
+ self.task_id: State.SUCCESS,
+ "branch_1": State.SUCCESS,
+ "branch_2": State.SKIPPED,
+ }
- children_tis = [ti for ti in tis if ti.task_id in branch_op.get_direct_relative_ids()]
+ self.assert_expected_task_states(dr, expected_states)
# Clear the children tasks.
+ tis = dr.get_task_instances()
+ children_tis = [ti for ti in tis if ti.task_id in branch_op.get_direct_relative_ids()]
with create_session() as session:
- clear_task_instances(children_tis, session=session, dag=self.dag)
+ clear_task_instances(children_tis, session=session, dag=branch_op.dag)
# Run the cleared tasks again.
for task in branches:
- task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ task.run(start_date=self.default_date, end_date=self.default_date)
# Check if the states are correct after children tasks are cleared.
- for ti in dr.get_task_instances():
- if ti.task_id == "make_choice":
- assert ti.state == State.SUCCESS
- elif ti.task_id == "branch_1":
- assert ti.state == State.SUCCESS
- elif ti.task_id == "branch_2":
- assert ti.state == State.SKIPPED
- else:
- raise ValueError(f"Invalid task id {ti.task_id} found!")
+ self.assert_expected_task_states(dr, expected_states)
def test_raise_exception_on_no_accepted_type_return(self):
- branch_op = BranchPythonOperator(task_id="make_choice", dag=self.dag, python_callable=lambda: 5)
- self.dag.clear()
- with pytest.raises(AirflowException) as ctx:
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- assert "must be either None, a task ID, or an Iterable of IDs" in str(ctx.value)
+ ti = self.create_ti(lambda: 5)
+ with pytest.raises(AirflowException, match="must be either None, a task ID, or an Iterable of IDs"):
+ ti.run()
def test_raise_exception_on_invalid_task_id(self):
- branch_op = BranchPythonOperator(
- task_id="make_choice", dag=self.dag, python_callable=lambda: "some_task_id"
- )
- self.dag.clear()
- with pytest.raises(AirflowException) as ctx:
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- assert "Invalid tasks found: {'some_task_id'}" in str(ctx.value)
+ ti = self.create_ti(lambda: "some_task_id")
+ with pytest.raises(AirflowException, match="Invalid tasks found: {'some_task_id'}"):
+ ti.run()
+ @pytest.mark.parametrize(
+ "choice,expected_states",
+ [
+ ("task1", [State.SUCCESS, State.SUCCESS, State.SUCCESS]),
+ ("join", [State.SUCCESS, State.SKIPPED, State.SUCCESS]),
+ ],
+ )
+ def test_empty_branch(self, choice, expected_states):
+ """
+ Tests that BranchPythonOperator handles empty branches properly.
+ """
+ with self.dag:
+ branch = BranchPythonOperator(task_id=self.task_id, python_callable=lambda: choice)
+ task1 = EmptyOperator(task_id="task1")
+ join = EmptyOperator(task_id="join", trigger_rule="none_failed_min_one_success")
-class TestShortCircuitOperator:
- def setup(self):
- with create_session() as session:
- session.query(DagRun).delete()
- session.query(TI).delete()
+ branch >> [task1, join]
+ task1 >> join
- self.dag = DAG(
- "short_circuit_op_test",
- start_date=DEFAULT_DATE,
- schedule=INTERVAL,
- )
+ dr = self.create_dag_run()
+ task_ids = [self.task_id, "task1", "join"]
+ tis = {ti.task_id: ti for ti in dr.task_instances}
- with self.dag:
- self.op1 = EmptyOperator(task_id="op1")
- self.op2 = EmptyOperator(task_id="op2")
- self.op1.set_downstream(self.op2)
+ for task_id in task_ids: # Mimic the specific order the scheduling would run the tests.
+ task_instance = tis[task_id]
+ task_instance.refresh_from_task(self.dag.get_task(task_id))
+ task_instance.run()
- def teardown(self):
- with create_session() as session:
- session.query(DagRun).delete()
- session.query(TI).delete()
+ def get_state(ti):
+ ti.refresh_from_db()
+ return ti.state
- def _assert_expected_task_states(self, dagrun, expected_states):
- """Helper function that asserts `TaskInstances` of a given `task_id` are in a given state."""
+ assert [get_state(tis[task_id]) for task_id in task_ids] == expected_states
- tis = dagrun.get_task_instances()
- for ti in tis:
- try:
- expected_state = expected_states[ti.task_id]
- except KeyError:
- raise ValueError(f"Invalid task id {ti.task_id} found!")
- else:
- assert ti.state == expected_state
+
+class TestShortCircuitOperator(BasePythonTest):
+ opcls = ShortCircuitOperator
+
+ @pytest.fixture(autouse=True)
+ def setup_tests(self):
+ self.task_id = "short_circuit"
+ self.op1 = EmptyOperator(task_id="op1")
+ self.op2 = EmptyOperator(task_id="op2")
all_downstream_skipped_states = {
"short_circuit": State.SUCCESS,
@@ -725,62 +530,41 @@ class TestShortCircuitOperator:
Checking the behavior of the ShortCircuitOperator in several scenarios enabling/disabling the skipping
of downstream tasks, both short-circuiting modes, and various trigger rules of downstream tasks.
"""
-
- self.short_circuit = ShortCircuitOperator(
- task_id="short_circuit",
- python_callable=lambda: callable_return,
- ignore_downstream_trigger_rules=test_ignore_downstream_trigger_rules,
- dag=self.dag,
- )
- self.short_circuit.set_downstream(self.op1)
- self.op2.trigger_rule = test_trigger_rule
- self.dag.clear()
-
- dagrun = self.dag.create_dagrun(
- run_type=DagRunType.MANUAL,
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
-
- self.short_circuit.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- self.op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- self.op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
- assert self.short_circuit.ignore_downstream_trigger_rules == test_ignore_downstream_trigger_rules
- assert self.short_circuit.trigger_rule == TriggerRule.ALL_SUCCESS
+ with self.dag:
+ short_circuit = ShortCircuitOperator(
+ task_id="short_circuit",
+ python_callable=lambda: callable_return,
+ ignore_downstream_trigger_rules=test_ignore_downstream_trigger_rules,
+ )
+ short_circuit >> self.op1 >> self.op2
+ self.op2.trigger_rule = test_trigger_rule
+
+ dr = self.create_dag_run()
+ short_circuit.run(start_date=self.default_date, end_date=self.default_date)
+ self.op1.run(start_date=self.default_date, end_date=self.default_date)
+ self.op2.run(start_date=self.default_date, end_date=self.default_date)
+
+ assert short_circuit.ignore_downstream_trigger_rules == test_ignore_downstream_trigger_rules
+ assert short_circuit.trigger_rule == TriggerRule.ALL_SUCCESS
assert self.op1.trigger_rule == TriggerRule.ALL_SUCCESS
assert self.op2.trigger_rule == test_trigger_rule
-
- self._assert_expected_task_states(dagrun, expected_task_states)
+ self.assert_expected_task_states(dr, expected_task_states)
def test_clear_skipped_downstream_task(self):
"""
After a downstream task is skipped by ShortCircuitOperator, clearing the skipped task
should not cause it to be executed.
"""
-
- self.short_circuit = ShortCircuitOperator(
- task_id="short_circuit",
- python_callable=lambda: False,
- dag=self.dag,
- )
- self.short_circuit.set_downstream(self.op1)
- self.dag.clear()
-
- dagrun = self.dag.create_dagrun(
- run_type=DagRunType.MANUAL,
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
-
- self.short_circuit.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- self.op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- self.op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
- assert self.short_circuit.ignore_downstream_trigger_rules
- assert self.short_circuit.trigger_rule == TriggerRule.ALL_SUCCESS
+ with self.dag:
+ short_circuit = ShortCircuitOperator(task_id="short_circuit", python_callable=lambda: False)
+ short_circuit >> self.op1 >> self.op2
+ dr = self.create_dag_run()
+
+ short_circuit.run(start_date=self.default_date, end_date=self.default_date)
+ self.op1.run(start_date=self.default_date, end_date=self.default_date)
+ self.op2.run(start_date=self.default_date, end_date=self.default_date)
+ assert short_circuit.ignore_downstream_trigger_rules
+ assert short_circuit.trigger_rule == TriggerRule.ALL_SUCCESS
assert self.op1.trigger_rule == TriggerRule.ALL_SUCCESS
assert self.op2.trigger_rule == TriggerRule.ALL_SUCCESS
@@ -789,82 +573,45 @@ class TestShortCircuitOperator:
"op1": State.SKIPPED,
"op2": State.SKIPPED,
}
- self._assert_expected_task_states(dagrun, expected_states)
+ self.assert_expected_task_states(dr, expected_states)
# Clear downstream task "op1" that was previously executed.
- tis = dagrun.get_task_instances()
-
+ tis = dr.get_task_instances()
with create_session() as session:
- clear_task_instances([ti for ti in tis if ti.task_id == "op1"], session=session, dag=self.dag)
-
- self.op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- self._assert_expected_task_states(dagrun, expected_states)
+ clear_task_instances(
+ [ti for ti in tis if ti.task_id == "op1"], session=session, dag=short_circuit.dag
+ )
+ self.op1.run(start_date=self.default_date, end_date=self.default_date)
+ self.assert_expected_task_states(dr, expected_states)
def test_xcom_push(self):
- short_op_push_xcom = ShortCircuitOperator(
- task_id="push_xcom_from_shortcircuit", dag=self.dag, python_callable=lambda: "signature"
- )
-
- short_op_no_push_xcom = ShortCircuitOperator(
- task_id="do_not_push_xcom_from_shortcircuit", dag=self.dag, python_callable=lambda: False
- )
-
- self.dag.clear()
- dr = self.dag.create_dagrun(
- run_type=DagRunType.MANUAL,
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
+ with self.dag:
+ short_op_push_xcom = ShortCircuitOperator(
+ task_id="push_xcom_from_shortcircuit", python_callable=lambda: "signature"
+ )
+ short_op_no_push_xcom = ShortCircuitOperator(
+ task_id="do_not_push_xcom_from_shortcircuit", python_callable=lambda: False
+ )
- short_op_push_xcom.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- short_op_no_push_xcom.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ dr = self.create_dag_run()
+ short_op_push_xcom.run(start_date=self.default_date, end_date=self.default_date)
+ short_op_no_push_xcom.run(start_date=self.default_date, end_date=self.default_date)
tis = dr.get_task_instances()
- xcom_value_short_op_push_xcom = tis[0].xcom_pull(
- task_ids="push_xcom_from_shortcircuit", key="return_value"
- )
- assert xcom_value_short_op_push_xcom == "signature"
-
- xcom_value_short_op_no_push_xcom = tis[0].xcom_pull(
- task_ids="do_not_push_xcom_from_shortcircuit", key="return_value"
- )
- assert xcom_value_short_op_no_push_xcom is None
+ assert tis[0].xcom_pull(task_ids=short_op_push_xcom.task_id, key="return_value") == "signature"
+ assert tis[0].xcom_pull(task_ids=short_op_no_push_xcom.task_id, key="return_value") is None
virtualenv_string_args: list[str] = []
-class TestPythonVirtualenvOperator(unittest.TestCase):
- def setUp(self):
- super().setUp()
- self.dag = DAG(
- "test_dag",
- default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
- template_searchpath=TEMPLATE_SEARCHPATH,
- schedule=INTERVAL,
- )
- self.dag.create_dagrun(
- run_type=DagRunType.MANUAL,
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
- self.addCleanup(self.dag.clear)
-
- def tearDown(self):
- super().tearDown()
- with create_session() as session:
- session.query(DagRun).delete()
- session.query(TI).delete()
-
- def _run_as_operator(self, fn, python_version=sys.version_info[0], **kwargs):
+class TestPythonVirtualenvOperator(BasePythonTest):
+ opcls = PythonVirtualenvOperator
- task = PythonVirtualenvOperator(
- python_callable=fn, python_version=python_version, task_id="task", dag=self.dag, **kwargs
- )
- task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- return task
+ @staticmethod
+ def default_kwargs(*, python_version=sys.version_info[0], **kwargs):
+ kwargs["python_version"] = python_version
+ return kwargs
def test_template_fields(self):
assert set(PythonOperator.template_fields).issubset(PythonVirtualenvOperator.template_fields)
@@ -874,7 +621,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
"""Ensure dill is correctly installed."""
import dill # noqa: F401
- self._run_as_operator(f, use_dill=True, system_site_packages=False)
+ self.run_as_task(f, use_dill=True, system_site_packages=False)
def test_no_requirements(self):
"""Tests that the python callable is invoked on task run."""
@@ -882,7 +629,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
def f():
pass
- self._run_as_operator(f)
+ self.run_as_task(f)
def test_no_system_site_packages(self):
def f():
@@ -892,13 +639,13 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
return True
raise Exception
- self._run_as_operator(f, system_site_packages=False, requirements=["dill"])
+ self.run_as_task(f, system_site_packages=False, requirements=["dill"])
def test_system_site_packages(self):
def f():
import funcsigs # noqa: F401
- self._run_as_operator(f, requirements=["funcsigs"], system_site_packages=True)
+ self.run_as_task(f, requirements=["funcsigs"], system_site_packages=True)
def test_with_requirements_pinned(self):
def f():
@@ -907,44 +654,44 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
if funcsigs.__version__ != "0.4":
raise Exception
- self._run_as_operator(f, requirements=["funcsigs==0.4"])
+ self.run_as_task(f, requirements=["funcsigs==0.4"])
def test_unpinned_requirements(self):
def f():
import funcsigs # noqa: F401
- self._run_as_operator(f, requirements=["funcsigs", "dill"], system_site_packages=False)
+ self.run_as_task(f, requirements=["funcsigs", "dill"], system_site_packages=False)
def test_range_requirements(self):
def f():
import funcsigs # noqa: F401
- self._run_as_operator(f, requirements=["funcsigs>1.0", "dill"], system_site_packages=False)
+ self.run_as_task(f, requirements=["funcsigs>1.0", "dill"], system_site_packages=False)
def test_requirements_file(self):
def f():
import funcsigs # noqa: F401
- self._run_as_operator(f, requirements="requirements.txt", system_site_packages=False)
+ self.run_as_operator(f, requirements="requirements.txt", system_site_packages=False)
- @unittest.mock.patch("airflow.operators.python.prepare_virtualenv")
+ @mock.patch("airflow.operators.python.prepare_virtualenv")
def test_pip_install_options(self, mocked_prepare_virtualenv):
def f():
import funcsigs # noqa: F401
mocked_prepare_virtualenv.side_effect = prepare_virtualenv
- self._run_as_operator(
+ self.run_as_task(
f,
requirements=["funcsigs==0.4"],
system_site_packages=False,
pip_install_options=["--no-deps"],
)
mocked_prepare_virtualenv.assert_called_with(
- venv_directory=unittest.mock.ANY,
- python_bin=unittest.mock.ANY,
+ venv_directory=mock.ANY,
+ python_bin=mock.ANY,
system_site_packages=False,
- requirements_file_path=unittest.mock.ANY,
+ requirements_file_path=mock.ANY,
pip_install_options=["--no-deps"],
)
@@ -954,7 +701,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
assert funcsigs.__version__ == "1.0.2"
- self._run_as_operator(
+ self.run_as_operator(
f,
requirements="requirements.txt",
use_dill=True,
@@ -967,7 +714,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
raise Exception
with pytest.raises(CalledProcessError):
- self._run_as_operator(f)
+ self.run_as_task(f)
def test_python_3(self):
def f():
@@ -980,13 +727,13 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
return
raise Exception
- self._run_as_operator(f, python_version=3, use_dill=False, requirements=["dill"])
+ self.run_as_task(f, python_version=3, use_dill=False, requirements=["dill"])
def test_without_dill(self):
def f(a):
return a
- self._run_as_operator(f, system_site_packages=False, use_dill=False, op_args=[4])
+ self.run_as_task(f, system_site_packages=False, use_dill=False, op_args=[4])
def test_string_args(self):
def f():
@@ -995,7 +742,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
if virtualenv_string_args[0] != virtualenv_string_args[2]:
raise Exception
- self._run_as_operator(f, string_args=[1, 2, 1])
+ self.run_as_task(f, string_args=[1, 2, 1])
def test_with_args(self):
def f(a, b, c=False, d=False):
@@ -1004,37 +751,38 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
else:
raise Exception
- self._run_as_operator(f, op_args=[0, 1], op_kwargs={"c": True})
+ self.run_as_task(f, op_args=[0, 1], op_kwargs={"c": True})
def test_return_none(self):
def f():
return None
- task = self._run_as_operator(f)
+ task = self.run_as_task(f)
assert task.execute_callable() is None
def test_return_false(self):
def f():
return False
- task = self._run_as_operator(f)
+ task = self.run_as_task(f)
assert task.execute_callable() is False
def test_lambda(self):
with pytest.raises(AirflowException):
- PythonVirtualenvOperator(python_callable=lambda x: 4, task_id="task", dag=self.dag)
+ PythonVirtualenvOperator(python_callable=lambda x: 4, task_id=self.task_id)
def test_nonimported_as_arg(self):
def f(_):
return None
- self._run_as_operator(f, op_args=[datetime.utcnow()])
+ self.run_as_task(f, op_args=[datetime.utcnow()])
def test_context(self):
def f(templates_dict):
return templates_dict["ds"]
- self._run_as_operator(f, templates_dict={"ds": "{{ ds }}"})
+ task = self.run_as_task(f, templates_dict={"ds": "{{ ds }}"})
+ assert task.templates_dict == {"ds": self.ds_templated}
# This test might take longer than the default 60 seconds as it is serializing a lot of
# context using dill (which is apparently slow).
@@ -1078,7 +826,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
):
pass
- self._run_as_operator(f, use_dill=True, system_site_packages=True, requirements=None)
+ self.run_as_operator(f, use_dill=True, system_site_packages=True, requirements=None)
@pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning")
def test_pendulum_context(self):
@@ -1112,7 +860,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
):
pass
- self._run_as_operator(f, use_dill=True, system_site_packages=False, requirements=["pendulum"])
+ self.run_as_task(f, use_dill=True, system_site_packages=False, requirements=["pendulum"])
@pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning")
def test_base_context(self):
@@ -1140,7 +888,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
):
pass
- self._run_as_operator(f, use_dill=True, system_site_packages=False, requirements=None)
+ self.run_as_task(f, use_dill=True, system_site_packages=False, requirements=None)
def test_deepcopy(self):
"""Test that PythonVirtualenvOperator instances are deep-copyable."""
@@ -1148,13 +896,37 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
def f():
return 1
- task = PythonVirtualenvOperator(
- python_callable=f,
- task_id="task",
- dag=self.dag,
- )
+ task = PythonVirtualenvOperator(python_callable=f, task_id="task")
copy.deepcopy(task)
+ def test_virtualenv_serializable_context_fields(self, create_task_instance):
+ """Ensure all template context fields are listed in the operator.
+
+ This exists mainly so when a field is added to the context, we remember to
+ also add it to PythonVirtualenvOperator.
+ """
+ # These are intentionally NOT serialized into the virtual environment:
+ # * Variables pointing to the task instance itself.
+ # * Variables that are accessor instances.
+ intentionally_excluded_context_keys = [
+ "task_instance",
+ "ti",
+ "var", # Accessor for Variable; var->json and var->value.
+ "conn", # Accessor for Connection.
+ ]
+
+ ti = create_task_instance(dag_id=self.dag_id, task_id=self.task_id, schedule=None)
+ context = ti.get_template_context()
+
+ declared_keys = {
+ *PythonVirtualenvOperator.BASE_SERIALIZABLE_CONTEXT_KEYS,
+ *PythonVirtualenvOperator.PENDULUM_SERIALIZABLE_CONTEXT_KEYS,
+ *PythonVirtualenvOperator.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS,
+ *intentionally_excluded_context_keys,
+ }
+
+ assert set(context) == declared_keys
+
DEFAULT_ARGS = {
"owner": "test",
@@ -1221,7 +993,7 @@ def get_all_the_context(**context):
assert context == current_context._context
-@pytest.fixture()
+@pytest.fixture
def clear_db():
clear_db_runs()
yield
@@ -1239,76 +1011,3 @@ class TestCurrentContextRuntime:
with DAG(dag_id="edge_case_context_dag", default_args=DEFAULT_ARGS):
op = PythonOperator(python_callable=get_all_the_context, task_id="get_all_the_context")
op.run(ignore_first_depends_on_past=True, ignore_ti_state=True)
-
-
-@pytest.mark.parametrize(
- "choice,expected_states",
- [
- ("task1", [State.SUCCESS, State.SUCCESS, State.SUCCESS]),
- ("join", [State.SUCCESS, State.SKIPPED, State.SUCCESS]),
- ],
-)
-def test_empty_branch(dag_maker, choice, expected_states):
- """
- Tests that BranchPythonOperator handles empty branches properly.
- """
- with dag_maker(
- "test_empty_branch",
- start_date=DEFAULT_DATE,
- ) as dag:
- branch = BranchPythonOperator(task_id="branch", python_callable=lambda: choice)
- task1 = EmptyOperator(task_id="task1")
- join = EmptyOperator(task_id="join", trigger_rule="none_failed_min_one_success")
-
- branch >> [task1, join]
- task1 >> join
-
- dag.clear(start_date=DEFAULT_DATE)
- dag_run = dag_maker.create_dagrun()
-
- task_ids = ["branch", "task1", "join"]
- tis = {ti.task_id: ti for ti in dag_run.task_instances}
-
- for task_id in task_ids: # Mimic the specific order the scheduling would run the tests.
- task_instance = tis[task_id]
- task_instance.refresh_from_task(dag.get_task(task_id))
- task_instance.run()
-
- def get_state(ti):
- ti.refresh_from_db()
- return ti.state
-
- assert [get_state(tis[task_id]) for task_id in task_ids] == expected_states
-
-
-def test_virtualenv_serializable_context_fields(create_task_instance):
- """Ensure all template context fields are listed in the operator.
-
- This exists mainly so when a field is added to the context, we remember to
- also add it to PythonVirtualenvOperator.
- """
- # These are intentionally NOT serialized into the virtual environment:
- # * Variables pointing to the task instance itself.
- # * Variables that are accessor instances.
- intentionally_excluded_context_keys = [
- "task_instance",
- "ti",
- "var", # Accessor for Variable; var->json and var->value.
- "conn", # Accessor for Connection.
- ]
-
- ti = create_task_instance(
- dag_id="test_virtualenv_serializable_context_fields",
- task_id="test_virtualenv_serializable_context_fields_task",
- schedule=None,
- )
- context = ti.get_template_context()
-
- declared_keys = {
- *PythonVirtualenvOperator.BASE_SERIALIZABLE_CONTEXT_KEYS,
- *PythonVirtualenvOperator.PENDULUM_SERIALIZABLE_CONTEXT_KEYS,
- *PythonVirtualenvOperator.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS,
- *intentionally_excluded_context_keys,
- }
-
- assert set(context) == declared_keys
diff --git a/tests/sensors/test_python.py b/tests/sensors/test_python.py
index f3c258185d..73a0ddffe4 100644
--- a/tests/sensors/test_python.py
+++ b/tests/sensors/test_python.py
@@ -24,112 +24,52 @@ import pytest
from airflow.exceptions import AirflowSensorTimeout
from airflow.sensors.python import PythonSensor
-from airflow.utils.state import State
-from airflow.utils.timezone import datetime
-from airflow.utils.types import DagRunType
-from tests.operators.test_python import Call, assert_calls_equal, build_recording_function
+from tests.operators.test_python import BasePythonTest
-DEFAULT_DATE = datetime(2015, 1, 1)
+class TestPythonSensor(BasePythonTest):
+ opcls = PythonSensor
-class TestPythonSensor:
- def test_python_sensor_true(self, dag_maker):
- with dag_maker():
- op = PythonSensor(task_id="python_sensor_check_true", python_callable=lambda: True)
- op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ def test_python_sensor_true(self):
+ self.run_as_task(fn=lambda: True)
- def test_python_sensor_false(self, dag_maker):
- with dag_maker():
- op = PythonSensor(
- task_id="python_sensor_check_false",
- timeout=0.01,
- poke_interval=0.01,
- python_callable=lambda: False,
- )
+ def test_python_sensor_false(self):
with pytest.raises(AirflowSensorTimeout):
- op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ self.run_as_task(lambda: False, timeout=0.01, poke_interval=0.01)
- def test_python_sensor_raise(self, dag_maker):
- with dag_maker():
- op = PythonSensor(task_id="python_sensor_check_raise", python_callable=lambda: 1 / 0)
+ def test_python_sensor_raise(self):
with pytest.raises(ZeroDivisionError):
- op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ self.run_as_task(lambda: 1 / 0)
- def test_python_callable_arguments_are_templatized(self, dag_maker):
+ def test_python_callable_arguments_are_templatized(self):
"""Test PythonSensor op_args are templatized"""
- recorded_calls = []
-
# Create a named tuple and ensure it is still preserved
# after the rendering is done
Named = namedtuple("Named", ["var1", "var2"])
named_tuple = Named("{{ ds }}", "unchanged")
- with dag_maker() as dag:
- task = PythonSensor(
- task_id="python_sensor",
- timeout=0.01,
- poke_interval=0.3,
- # a Mock instance cannot be used as a callable function or test fails with a
- # TypeError: Object of type Mock is not JSON serializable
- python_callable=build_recording_function(recorded_calls),
- op_args=[4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple],
- )
-
- dag.create_dagrun(
- run_type=DagRunType.MANUAL,
- execution_date=DEFAULT_DATE,
- data_interval=(DEFAULT_DATE, DEFAULT_DATE),
- start_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
- with pytest.raises(AirflowSensorTimeout):
- task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
- ds_templated = DEFAULT_DATE.date().isoformat()
- assert_calls_equal(
- recorded_calls[0],
- Call(
- 4,
- date(2019, 1, 1),
- f"dag {dag.dag_id} ran on {ds_templated}.",
- Named(ds_templated, "unchanged"),
- ),
+ task = self.render_templates(
+ lambda: 0,
+ op_args=[4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple],
)
-
- def test_python_callable_keyword_arguments_are_templatized(self, dag_maker):
+ rendered_op_args = task.op_args
+ assert len(rendered_op_args) == 4
+ assert rendered_op_args[0] == 4
+ assert rendered_op_args[1] == date(2019, 1, 1)
+ assert rendered_op_args[2] == f"dag {self.dag_id} ran on {self.ds_templated}."
+ assert rendered_op_args[3] == Named(self.ds_templated, "unchanged")
+
+ def test_python_callable_keyword_arguments_are_templatized(self):
"""Test PythonSensor op_kwargs are templatized"""
- recorded_calls = []
-
- with dag_maker() as dag:
- task = PythonSensor(
- task_id="python_sensor",
- timeout=0.01,
- poke_interval=0.01,
- # a Mock instance cannot be used as a callable function or test fails with a
- # TypeError: Object of type Mock is not JSON serializable
- python_callable=build_recording_function(recorded_calls),
- op_kwargs={
- "an_int": 4,
- "a_date": date(2019, 1, 1),
- "a_templated_string": "dag {{dag.dag_id}} ran on {{ds}}.",
- },
- )
-
- dag.create_dagrun(
- run_type=DagRunType.MANUAL,
- execution_date=DEFAULT_DATE,
- data_interval=(DEFAULT_DATE, DEFAULT_DATE),
- start_date=DEFAULT_DATE,
- state=State.RUNNING,
- )
- with pytest.raises(AirflowSensorTimeout):
- task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
- assert_calls_equal(
- recorded_calls[0],
- Call(
- an_int=4,
- a_date=date(2019, 1, 1),
- a_templated_string=f"dag {dag.dag_id} ran on {DEFAULT_DATE.date().isoformat()}.",
- ),
+ task = self.render_templates(
+ lambda: 0,
+ op_kwargs={
+ "an_int": 4,
+ "a_date": date(2019, 1, 1),
+ "a_templated_string": "dag {{dag.dag_id}} ran on {{ds}}.",
+ },
)
+ rendered_op_kwargs = task.op_kwargs
+ assert rendered_op_kwargs["an_int"] == 4
+ assert rendered_op_kwargs["a_date"] == date(2019, 1, 1)
+ assert rendered_op_kwargs["a_templated_string"] == f"dag {self.dag_id} ran on {self.ds_templated}."