Posted to commits@airflow.apache.org by pi...@apache.org on 2023/03/06 21:47:06 UTC

[airflow] 22/37: Refactor python operators/sensor tests (#28493)

This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch v2-5-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 061338fad1a9ec4bf12b1aad482b3a72f7d3551c
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Thu Dec 22 12:32:06 2022 +0400

    Refactor python operators/sensor tests (#28493)
    
    (cherry picked from commit 884fca8d114ce8e0c982747937a1014f3b5e7491)
---
 tests/conftest.py                          |    8 +-
 tests/decorators/test_python.py            |  143 +---
 tests/decorators/test_python_virtualenv.py |   13 -
 tests/operators/test_python.py             | 1003 ++++++++++------------------
 tests/sensors/test_python.py               |  124 +---
 5 files changed, 426 insertions(+), 865 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 0d4d1170f0..d71d8eb0f0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -22,6 +22,7 @@ import subprocess
 import sys
 from contextlib import ExitStack, suppress
 from datetime import datetime, timedelta
+from typing import TYPE_CHECKING
 
 import freezegun
 import pytest
@@ -46,6 +47,9 @@ from tests.test_utils.perf.perf_kit.sqlalchemy import (  # noqa isort:skip
     trace_queries,
 )
 
+if TYPE_CHECKING:
+    from airflow.models.taskinstance import TaskInstance
+
 
 @pytest.fixture()
 def reset_environment():
@@ -741,7 +745,7 @@ def create_task_instance(dag_maker, create_dummy_dag):
         run_type=None,
         data_interval=None,
         **kwargs,
-    ):
+    ) -> TaskInstance:
         if execution_date is None:
             from airflow.utils import timezone
 
@@ -775,7 +779,7 @@ def create_task_instance_of_operator(dag_maker):
         execution_date=None,
         session=None,
         **operator_kwargs,
-    ):
+    ) -> TaskInstance:
         with dag_maker(dag_id=dag_id, session=session):
             operator_class(**operator_kwargs)
         if execution_date is None:
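
The new TYPE_CHECKING block lets conftest.py annotate the fixture factories with a
TaskInstance return type without importing the model at runtime, avoiding import cost
and circular-import risk. A minimal sketch of the same pattern, with a hypothetical
heavy_module standing in for airflow.models.taskinstance:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only evaluated by static type checkers; never imported at runtime.
        from heavy_module import HeavyClass  # hypothetical module

    def make_instance() -> HeavyClass:
        # With the __future__ import the annotation above is a plain string,
        # so the real import can be deferred to call time.
        from heavy_module import HeavyClass
        return HeavyClass()
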
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index 47a908db77..1bbad51a0b 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -37,41 +37,20 @@ from airflow.utils import timezone
 from airflow.utils.state import State
 from airflow.utils.task_group import TaskGroup
 from airflow.utils.types import DagRunType
-from tests.operators.test_python import Call, assert_calls_equal, build_recording_function
-from tests.test_utils.db import clear_db_runs
+from tests.operators.test_python import BasePythonTest
 
 DEFAULT_DATE = timezone.datetime(2016, 1, 1)
-END_DATE = timezone.datetime(2016, 1, 2)
-INTERVAL = timedelta(hours=12)
-FROZEN_NOW = timezone.datetime(2016, 1, 2, 12, 1, 1)
 
-TI_CONTEXT_ENV_VARS = [
-    "AIRFLOW_CTX_DAG_ID",
-    "AIRFLOW_CTX_TASK_ID",
-    "AIRFLOW_CTX_EXECUTION_DATE",
-    "AIRFLOW_CTX_DAG_RUN_ID",
-]
 
-
-class TestAirflowTaskDecorator:
-    def setup_class(self):
-        clear_db_runs()
-
-    def setup_method(self):
-        self.dag = DAG("test_dag", default_args={"owner": "airflow", "start_date": DEFAULT_DATE})
-        self.run = False
-
-    def teardown_method(self):
-        self.dag.clear()
-        self.run = False
-        clear_db_runs()
+class TestAirflowTaskDecorator(BasePythonTest):
+    default_date = DEFAULT_DATE
 
     def test_python_operator_python_callable_is_callable(self):
         """Tests that @task will only instantiate if
         the python_callable argument is callable."""
         not_callable = {}
         with pytest.raises(TypeError):
-            task_decorator(not_callable, dag=self.dag)
+            task_decorator(not_callable)
 
     @pytest.mark.parametrize(
         "resolve",
@@ -155,13 +134,7 @@ class TestAirflowTaskDecorator:
         with self.dag:
             res = identity2(8, 4)
 
-        dr = self.dag.create_dagrun(
-            run_id=DagRunType.MANUAL.value,
-            start_date=timezone.utcnow(),
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-
+        dr = self.create_dag_run()
         res.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
         ti = dr.get_task_instances()[0]
@@ -179,13 +152,7 @@ class TestAirflowTaskDecorator:
         with self.dag:
             ident = identity_tuple(35, 36)
 
-        dr = self.dag.create_dagrun(
-            run_id=DagRunType.MANUAL.value,
-            start_date=timezone.utcnow(),
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-
+        dr = self.create_dag_run()
         ident.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
         ti = dr.get_task_instances()[0]
@@ -227,15 +194,9 @@ class TestAirflowTaskDecorator:
 
         with self.dag:
             ret = add_number(2)
-        self.dag.create_dagrun(
-            run_id=DagRunType.MANUAL,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
 
+        self.create_dag_run()
         with pytest.raises(AirflowException):
-
             ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
     def test_fail_multiple_outputs_no_dict(self):
@@ -245,84 +206,53 @@ class TestAirflowTaskDecorator:
 
         with self.dag:
             ret = add_number(2)
-        self.dag.create_dagrun(
-            run_id=DagRunType.MANUAL,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
 
+        self.create_dag_run()
         with pytest.raises(AirflowException):
-
             ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
     def test_python_callable_arguments_are_templatized(self):
         """Test @task op_args are templatized"""
-        recorded_calls = []
+
+        @task_decorator
+        def arg_task(*args):
+            raise RuntimeError("Should not be executed")
 
         # Create a named tuple and ensure it is still preserved
         # after the rendering is done
         Named = namedtuple("Named", ["var1", "var2"])
         named_tuple = Named("{{ ds }}", "unchanged")
 
-        task = task_decorator(
-            # a Mock instance cannot be used as a callable function or test fails with a
-            # TypeError: Object of type Mock is not JSON serializable
-            build_recording_function(recorded_calls),
-            dag=self.dag,
-        )
-        ret = task(4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple)
-
-        self.dag.create_dagrun(
-            run_id=DagRunType.MANUAL,
-            execution_date=DEFAULT_DATE,
-            data_interval=(DEFAULT_DATE, DEFAULT_DATE),
-            start_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-        ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        with self.dag:
+            ret = arg_task(4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple)
 
-        ds_templated = DEFAULT_DATE.date().isoformat()
-        assert len(recorded_calls) == 1
-        assert_calls_equal(
-            recorded_calls[0],
-            Call(
-                4,
-                date(2019, 1, 1),
-                f"dag {self.dag.dag_id} ran on {ds_templated}.",
-                Named(ds_templated, "unchanged"),
-            ),
-        )
+        dr = self.create_dag_run()
+        ti = TaskInstance(task=ret.operator, run_id=dr.run_id)
+        rendered_op_args = ti.render_templates().op_args
+        assert len(rendered_op_args) == 4
+        assert rendered_op_args[0] == 4
+        assert rendered_op_args[1] == date(2019, 1, 1)
+        assert rendered_op_args[2] == f"dag {self.dag_id} ran on {self.ds_templated}."
+        assert rendered_op_args[3] == Named(self.ds_templated, "unchanged")
 
     def test_python_callable_keyword_arguments_are_templatized(self):
         """Test PythonOperator op_kwargs are templatized"""
-        recorded_calls = []
 
-        task = task_decorator(
-            # a Mock instance cannot be used as a callable function or test fails with a
-            # TypeError: Object of type Mock is not JSON serializable
-            build_recording_function(recorded_calls),
-            dag=self.dag,
-        )
-        ret = task(an_int=4, a_date=date(2019, 1, 1), a_templated_string="dag {{dag.dag_id}} ran on {{ds}}.")
-        self.dag.create_dagrun(
-            run_id=DagRunType.MANUAL,
-            execution_date=DEFAULT_DATE,
-            data_interval=(DEFAULT_DATE, DEFAULT_DATE),
-            start_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-        ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        @task_decorator
+        def kwargs_task(an_int, a_date, a_templated_string):
+            raise RuntimeError("Should not be executed")
 
-        assert len(recorded_calls) == 1
-        assert_calls_equal(
-            recorded_calls[0],
-            Call(
-                an_int=4,
-                a_date=date(2019, 1, 1),
-                a_templated_string=f"dag {self.dag.dag_id} ran on {DEFAULT_DATE.date().isoformat()}.",
-            ),
-        )
+        with self.dag:
+            ret = kwargs_task(
+                an_int=4, a_date=date(2019, 1, 1), a_templated_string="dag {{dag.dag_id}} ran on {{ds}}."
+            )
+
+        dr = self.create_dag_run()
+        ti = TaskInstance(task=ret.operator, run_id=dr.run_id)
+        rendered_op_kwargs = ti.render_templates().op_kwargs
+        assert rendered_op_kwargs["an_int"] == 4
+        assert rendered_op_kwargs["a_date"] == date(2019, 1, 1)
+        assert rendered_op_kwargs["a_templated_string"] == f"dag {self.dag_id} ran on {self.ds_templated}."
 
     def test_manual_task_id(self):
         """Test manually setting task_id"""
@@ -415,6 +345,7 @@ class TestAirflowTaskDecorator:
         def do_run():
             return 4
 
+        self.dag.default_args["owner"] = "airflow"
         with self.dag:
             ret = do_run()
         assert ret.operator.owner == "airflow"
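
The recurring change in this file: rather than executing each task and recording its
call arguments, the tests now build a TaskInstance for the decorated task and call
render_templates(), then assert directly on the rendered op_args/op_kwargs. A condensed
sketch of that idea, assuming the self.dag and create_dag_run helpers introduced by
this commit:

    from airflow.models.taskinstance import TaskInstance

    with self.dag:
        ret = arg_task("dag {{ dag.dag_id }} ran on {{ ds }}.")

    dr = self.create_dag_run()
    ti = TaskInstance(task=ret.operator, run_id=dr.run_id)
    # render_templates() resolves the Jinja template fields in place and
    # returns the task, so op_args can be inspected without running anything.
    rendered_op_args = ti.render_templates().op_args
    assert "ran on" in rendered_op_args[0]
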
diff --git a/tests/decorators/test_python_virtualenv.py b/tests/decorators/test_python_virtualenv.py
index 032ec34aa5..88121c5db3 100644
--- a/tests/decorators/test_python_virtualenv.py
+++ b/tests/decorators/test_python_virtualenv.py
@@ -19,7 +19,6 @@ from __future__ import annotations
 
 import datetime
 import sys
-from datetime import timedelta
 from subprocess import CalledProcessError
 
 import pytest
@@ -28,18 +27,6 @@ from airflow.decorators import task
 from airflow.utils import timezone
 
 DEFAULT_DATE = timezone.datetime(2016, 1, 1)
-END_DATE = timezone.datetime(2016, 1, 2)
-INTERVAL = timedelta(hours=12)
-FROZEN_NOW = timezone.datetime(2016, 1, 2, 12, 1, 1)
-
-TI_CONTEXT_ENV_VARS = [
-    "AIRFLOW_CTX_DAG_ID",
-    "AIRFLOW_CTX_TASK_ID",
-    "AIRFLOW_CTX_EXECUTION_DATE",
-    "AIRFLOW_CTX_DAG_RUN_ID",
-]
-
-
 PYTHON_VERSION = sys.version_info[0]
 
 
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index c011c5cb35..8f6c089f08 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -20,16 +20,18 @@ from __future__ import annotations
 import copy
 import logging
 import os
+import re
 import sys
-import unittest.mock
 import warnings
 from collections import namedtuple
 from datetime import date, datetime, timedelta
 from subprocess import CalledProcessError
+from unittest import mock
 
 import pytest
+from slugify import slugify
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
 from airflow.models import DAG, DagRun, TaskInstance as TI
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.taskinstance import clear_task_instances, set_current_context
@@ -45,88 +47,108 @@ from airflow.utils import timezone
 from airflow.utils.context import AirflowContextDeprecationWarning, Context
 from airflow.utils.python_virtualenv import prepare_virtualenv
 from airflow.utils.session import create_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, State
 from airflow.utils.trigger_rule import TriggerRule
-from airflow.utils.types import DagRunType
+from airflow.utils.types import NOTSET, DagRunType
 from tests.test_utils import AIRFLOW_MAIN_FOLDER
 from tests.test_utils.db import clear_db_runs
 
 DEFAULT_DATE = timezone.datetime(2016, 1, 1)
-END_DATE = timezone.datetime(2016, 1, 2)
-INTERVAL = timedelta(hours=12)
-FROZEN_NOW = timezone.datetime(2016, 1, 2, 12, 1, 1)
-
-TI_CONTEXT_ENV_VARS = [
-    "AIRFLOW_CTX_DAG_ID",
-    "AIRFLOW_CTX_TASK_ID",
-    "AIRFLOW_CTX_EXECUTION_DATE",
-    "AIRFLOW_CTX_DAG_RUN_ID",
-]
-
 TEMPLATE_SEARCHPATH = os.path.join(AIRFLOW_MAIN_FOLDER, "tests", "config_templates")
+LOGGER_NAME = "airflow.task.operators"
 
 
-class Call:
-    def __init__(self, *args, **kwargs):
-        self.args = args
-        self.kwargs = kwargs
-
-
-def build_recording_function(calls_collection):
-    """
-    We can not use a Mock instance as a PythonOperator callable function or some tests fail with a
-    TypeError: Object of type Mock is not JSON serializable
-    Then using this custom function recording custom Call objects for further testing
-    (replacing Mock.assert_called_with assertion method)
-    """
-
-    def recording_function(*args, **kwargs):
-        calls_collection.append(Call(*args, **kwargs))
-
-    return recording_function
-
+class BasePythonTest:
+    """Base test class for TestPythonOperator and TestPythonSensor classes"""
 
-def assert_calls_equal(first: Call, second: Call) -> None:
-    assert isinstance(first, Call)
-    assert isinstance(second, Call)
-    assert first.args == second.args
-    # eliminate context (conf, dag_run, task_instance, etc.)
-    test_args = ["an_int", "a_date", "a_templated_string"]
-    first.kwargs = {key: value for (key, value) in first.kwargs.items() if key in test_args}
-    second.kwargs = {key: value for (key, value) in second.kwargs.items() if key in test_args}
-    assert first.kwargs == second.kwargs
+    opcls: type[BaseOperator]
+    dag_id: str
+    task_id: str
+    run_id: str
+    dag: DAG
+    ds_templated: str
+    default_date: datetime = DEFAULT_DATE
+
+    @pytest.fixture(autouse=True)
+    def base_tests_setup(self, request, create_task_instance_of_operator, dag_maker):
+        self.dag_id = f"dag_{slugify(request.cls.__name__)}"
+        self.task_id = f"task_{slugify(request.node.name, max_length=40)}"
+        self.run_id = f"run_{slugify(request.node.name, max_length=40)}"
+        self.ds_templated = self.default_date.date().isoformat()
+        self.ti_maker = create_task_instance_of_operator
+        self.dag_maker = dag_maker
+        self.dag = self.dag_maker(self.dag_id, template_searchpath=TEMPLATE_SEARCHPATH).dag
+        clear_db_runs()
+        yield
+        clear_db_runs()
+
+    @staticmethod
+    def assert_expected_task_states(dag_run: DagRun, expected_states: dict):
+        """Helper function that asserts `TaskInstances` of a given `task_id` are in a given state."""
+        asserts = []
+        for ti in dag_run.get_task_instances():
+            try:
+                expected = expected_states[ti.task_id]
+            except KeyError:
+                asserts.append(f"Unexpected task id {ti.task_id!r} found, expected {expected_states.keys()}")
+                continue
+
+            if ti.state != expected:
+                asserts.append(f"Task {ti.task_id!r} has state {ti.state!r} instead of expected {expected!r}")
+        if asserts:
+            pytest.fail("\n".join(asserts))
+
+    @staticmethod
+    def default_kwargs(**kwargs):
+        """Default arguments for specific Operator."""
+        return kwargs
+
+    def create_dag_run(self) -> DagRun:
+        return self.dag.create_dagrun(
+            state=DagRunState.RUNNING,
+            start_date=self.dag_maker.start_date,
+            session=self.dag_maker.session,
+            execution_date=self.default_date,
+            run_type=DagRunType.MANUAL,
+        )
 
+    def create_ti(self, fn, **kwargs) -> TI:
+        """Create TaskInstance for class defined Operator."""
+        return self.ti_maker(
+            self.opcls,
+            python_callable=fn,
+            **self.default_kwargs(**kwargs),
+            dag_id=self.dag_id,
+            task_id=self.task_id,
+            execution_date=self.default_date,
+        )
 
-class TestPythonBase(unittest.TestCase):
-    """Base test class for TestPythonOperator and TestPythonSensor classes"""
+    def run_as_operator(self, fn, **kwargs):
+        """Run task by direct call ``run`` method."""
+        with self.dag:
+            task = self.opcls(task_id=self.task_id, python_callable=fn, **self.default_kwargs(**kwargs))
 
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
+        task.run(start_date=self.default_date, end_date=self.default_date)
+        return task
 
-        with create_session() as session:
-            session.query(DagRun).delete()
-            session.query(TI).delete()
+    def run_as_task(self, fn, **kwargs):
+        """Create TaskInstance and run it."""
+        ti = self.create_ti(fn, **kwargs)
+        ti.run()
+        return ti.task
 
-    def setUp(self):
-        super().setUp()
-        self.dag = DAG("test_dag", default_args={"owner": "airflow", "start_date": DEFAULT_DATE})
-        self.addCleanup(self.dag.clear)
-        self.clear_run()
-        self.addCleanup(self.clear_run)
+    def render_templates(self, fn, **kwargs):
+        """Create TaskInstance and render templates without actual run."""
+        return self.create_ti(fn, **kwargs).render_templates()
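
BasePythonTest replaces the old unittest-style setup with a single autouse pytest
fixture: pytest runs it around every test in the class without an explicit decorator
on each test, and slugify turns the test name into collision-free dag/task/run ids so
tests sharing one database do not interfere. A stripped-down sketch of the pattern,
with illustrative names:

    import pytest
    from slugify import slugify

    class BaseTest:
        @pytest.fixture(autouse=True)
        def _setup(self, request):
            # request.node.name is the running test's name, e.g.
            # "test_python_operator_run"; slugify makes it identifier-safe.
            self.task_id = f"task_{slugify(request.node.name, max_length=40)}"
            yield  # the test body executes here
            # cleanup goes after the yield, replacing tearDown()
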
 
-    def tearDown(self):
-        super().tearDown()
 
-        with create_session() as session:
-            session.query(DagRun).delete()
-            session.query(TI).delete()
+class TestPythonOperator(BasePythonTest):
+    opcls = PythonOperator
 
-    def clear_run(self):
+    @pytest.fixture(autouse=True)
+    def setup_tests(self):
         self.run = False
 
-
-class TestPythonOperator(TestPythonBase):
     def do_run(self):
         self.run = True
 
@@ -135,105 +157,58 @@ class TestPythonOperator(TestPythonBase):
 
     def test_python_operator_run(self):
         """Tests that the python callable is invoked on task run."""
-        task = PythonOperator(python_callable=self.do_run, task_id="python_operator", dag=self.dag)
+        ti = self.create_ti(self.do_run)
         assert not self.is_run()
-        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        ti.run()
         assert self.is_run()
 
-    def test_python_operator_python_callable_is_callable(self):
-        """Tests that PythonOperator will only instantiate if
-        the python_callable argument is callable."""
-        not_callable = {}
-        with pytest.raises(AirflowException):
-            PythonOperator(python_callable=not_callable, task_id="python_operator", dag=self.dag)
-        not_callable = None
-        with pytest.raises(AirflowException):
-            PythonOperator(python_callable=not_callable, task_id="python_operator", dag=self.dag)
+    @pytest.mark.parametrize("not_callable", [{}, None])
+    def test_python_operator_python_callable_is_callable(self, not_callable):
+        """Tests that PythonOperator will only instantiate if the python_callable argument is callable."""
+        with pytest.raises(AirflowException, match="`python_callable` param must be callable"):
+            PythonOperator(python_callable=not_callable, task_id="python_operator")
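
The two hand-rolled negative cases above are folded into one parametrized test, and
the expected message is checked via pytest.raises(match=...). Because match= is
interpreted as a regular expression, literal messages are wrapped in re.escape
elsewhere in this commit. A self-contained sketch of both ideas (require_callable is
a hypothetical stand-in for the operator's validation):

    import re

    import pytest

    def require_callable(fn):
        if not callable(fn):
            raise TypeError("`fn` param must be callable")
        return fn

    @pytest.mark.parametrize("not_callable", [{}, None])
    def test_rejects_non_callables(not_callable):
        # match= takes a regex, so re.escape protects any literal
        # special characters in the expected message.
        with pytest.raises(TypeError, match=re.escape("`fn` param must be callable")):
            require_callable(not_callable)
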
 
     def test_python_callable_arguments_are_templatized(self):
         """Test PythonOperator op_args are templatized"""
-        recorded_calls = []
-
         # Create a named tuple and ensure it is still preserved
         # after the rendering is done
         Named = namedtuple("Named", ["var1", "var2"])
         named_tuple = Named("{{ ds }}", "unchanged")
 
-        task = PythonOperator(
-            task_id="python_operator",
-            # a Mock instance cannot be used as a callable function or test fails with a
-            # TypeError: Object of type Mock is not JSON serializable
-            python_callable=build_recording_function(recorded_calls),
+        task = self.render_templates(
+            lambda: 0,
             op_args=[4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple],
-            dag=self.dag,
-        )
-
-        self.dag.create_dagrun(
-            run_type=DagRunType.MANUAL,
-            execution_date=DEFAULT_DATE,
-            data_interval=(DEFAULT_DATE, DEFAULT_DATE),
-            start_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-        ds_templated = DEFAULT_DATE.date().isoformat()
-        assert 1 == len(recorded_calls)
-        assert_calls_equal(
-            recorded_calls[0],
-            Call(
-                4,
-                date(2019, 1, 1),
-                f"dag {self.dag.dag_id} ran on {ds_templated}.",
-                Named(ds_templated, "unchanged"),
-            ),
         )
+        rendered_op_args = task.op_args
+        assert len(rendered_op_args) == 4
+        assert rendered_op_args[0] == 4
+        assert rendered_op_args[1] == date(2019, 1, 1)
+        assert rendered_op_args[2] == f"dag {self.dag_id} ran on {self.ds_templated}."
+        assert rendered_op_args[3] == Named(self.ds_templated, "unchanged")
 
     def test_python_callable_keyword_arguments_are_templatized(self):
         """Test PythonOperator op_kwargs are templatized"""
-        recorded_calls = []
-
-        task = PythonOperator(
-            task_id="python_operator",
-            # a Mock instance cannot be used as a callable function or test fails with a
-            # TypeError: Object of type Mock is not JSON serializable
-            python_callable=build_recording_function(recorded_calls),
+        task = self.render_templates(
+            lambda: 0,
             op_kwargs={
                 "an_int": 4,
                 "a_date": date(2019, 1, 1),
                 "a_templated_string": "dag {{dag.dag_id}} ran on {{ds}}.",
             },
-            dag=self.dag,
-        )
-
-        self.dag.create_dagrun(
-            run_type=DagRunType.MANUAL,
-            execution_date=DEFAULT_DATE,
-            data_interval=(DEFAULT_DATE, DEFAULT_DATE),
-            start_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-        assert 1 == len(recorded_calls)
-        assert_calls_equal(
-            recorded_calls[0],
-            Call(
-                an_int=4,
-                a_date=date(2019, 1, 1),
-                a_templated_string=f"dag {self.dag.dag_id} ran on {DEFAULT_DATE.date().isoformat()}.",
-            ),
         )
+        rendered_op_kwargs = task.op_kwargs
+        assert rendered_op_kwargs["an_int"] == 4
+        assert rendered_op_kwargs["a_date"] == date(2019, 1, 1)
+        assert rendered_op_kwargs["a_templated_string"] == f"dag {self.dag_id} ran on {self.ds_templated}."
 
     def test_python_operator_shallow_copy_attr(self):
         def not_callable(x):
-            return x
+            assert False, "Should not be triggered"
 
         original_task = PythonOperator(
             python_callable=not_callable,
-            task_id="python_operator",
             op_kwargs={"certain_attrs": ""},
-            dag=self.dag,
+            task_id=self.task_id,
         )
         new_task = copy.deepcopy(original_task)
         # shallow copy op_kwargs
@@ -242,383 +217,213 @@ class TestPythonOperator(TestPythonBase):
         assert id(original_task.python_callable) == id(new_task.python_callable)
 
     def test_conflicting_kwargs(self):
-        self.dag.create_dagrun(
-            run_type=DagRunType.MANUAL,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            state=State.RUNNING,
-            external_trigger=False,
-        )
-
         # dag is not allowed since it is a reserved keyword
         def func(dag):
-            # An ValueError should be triggered since we're using dag as a
-            # reserved keyword
+            # A ValueError should be triggered since we're using dag as a reserved keyword
             raise RuntimeError(f"Should not be triggered, dag: {dag}")
 
-        python_operator = PythonOperator(
-            task_id="python_operator", op_args=[1], python_callable=func, dag=self.dag
-        )
-
-        with pytest.raises(ValueError) as ctx:
-            python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-        assert "dag" in str(ctx.value), "'dag' not found in the exception"
+        ti = self.create_ti(func, op_args=[1])
+        error_message = re.escape("The key 'dag' in args is a part of kwargs and therefore reserved.")
+        with pytest.raises(ValueError, match=error_message):
+            ti.run()
 
     def test_provide_context_does_not_fail(self):
-        """
-        ensures that provide_context doesn't break dags in 2.0
-        """
-        self.dag.create_dagrun(
-            run_type=DagRunType.MANUAL,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            state=State.RUNNING,
-            external_trigger=False,
-        )
+        """Ensures that provide_context doesn't break dags in 2.0."""
 
         def func(custom, dag):
             assert 1 == custom, "custom should be 1"
             assert dag is not None, "dag should be set"
 
-        python_operator = PythonOperator(
-            task_id="python_operator",
-            op_kwargs={"custom": 1},
-            python_callable=func,
-            provide_context=True,
-            dag=self.dag,
-        )
-        python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        with pytest.warns(RemovedInAirflow3Warning):
+            self.run_as_task(func, op_kwargs={"custom": 1}, provide_context=True)
 
     def test_context_with_conflicting_op_args(self):
-        self.dag.create_dagrun(
-            run_type=DagRunType.MANUAL,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            state=State.RUNNING,
-            external_trigger=False,
-        )
-
         def func(custom, dag):
             assert 1 == custom, "custom should be 1"
             assert dag is not None, "dag should be set"
 
-        python_operator = PythonOperator(
-            task_id="python_operator", op_kwargs={"custom": 1}, python_callable=func, dag=self.dag
-        )
-        python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        self.run_as_task(func, op_kwargs={"custom": 1})
 
     def test_context_with_kwargs(self):
-        self.dag.create_dagrun(
-            run_type=DagRunType.MANUAL,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            state=State.RUNNING,
-            external_trigger=False,
-        )
-
         def func(**context):
             # check if context is being set
             assert len(context) > 0, "Context has not been injected"
 
-        python_operator = PythonOperator(
-            task_id="python_operator", op_kwargs={"custom": 1}, python_callable=func, dag=self.dag
-        )
-        python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-    def test_return_value_log_with_show_return_value_in_logs_default(self):
-        self.dag.create_dagrun(
-            run_type=DagRunType.MANUAL,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            state=State.RUNNING,
-            external_trigger=False,
-        )
-
-        def func():
-            return "test_return_value"
-
-        python_operator = PythonOperator(task_id="python_operator", python_callable=func, dag=self.dag)
-
-        with self.assertLogs("airflow.task.operators", level=logging.INFO) as cm:
-            python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        self.run_as_task(func, op_kwargs={"custom": 1})
 
-        assert (
-            "INFO:airflow.task.operators:Done. Returned value was: test_return_value" in cm.output
-        ), "Return value should be shown"
-
-    def test_return_value_log_with_show_return_value_in_logs_false(self):
-        self.dag.create_dagrun(
-            run_type=DagRunType.MANUAL,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            state=State.RUNNING,
-            external_trigger=False,
-        )
+    @pytest.mark.parametrize(
+        "show_return_value_in_logs, should_shown",
+        [
+            pytest.param(NOTSET, True, id="default"),
+            pytest.param(True, True, id="show"),
+            pytest.param(False, False, id="hide"),
+        ],
+    )
+    def test_return_value_log(self, show_return_value_in_logs, should_shown, caplog):
+        caplog.set_level(logging.INFO, logger=LOGGER_NAME)
 
         def func():
             return "test_return_value"
 
-        python_operator = PythonOperator(
-            task_id="python_operator",
-            python_callable=func,
-            dag=self.dag,
-            show_return_value_in_logs=False,
-        )
-
-        with self.assertLogs("airflow.task.operators", level=logging.INFO) as cm:
-            python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-        assert (
-            "INFO:airflow.task.operators:Done. Returned value was: test_return_value" not in cm.output
-        ), "Return value should not be shown"
-        assert (
-            "INFO:airflow.task.operators:Done. Returned value not shown" in cm.output
-        ), "Log message that the option is turned off should be shown"
+        if show_return_value_in_logs is NOTSET:
+            self.run_as_task(func)
+        else:
+            self.run_as_task(func, show_return_value_in_logs=show_return_value_in_logs)
 
+        if should_shown:
+            assert "Done. Returned value was: test_return_value" in caplog.messages
+            assert "Done. Returned value not shown" not in caplog.messages
+        else:
+            assert "Done. Returned value was: test_return_value" not in caplog.messages
+            assert "Done. Returned value not shown" in caplog.messages
 
-class TestBranchOperator(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-
-        with create_session() as session:
-            session.query(DagRun).delete()
-            session.query(TI).delete()
-
-    def setUp(self):
-        self.dag = DAG(
-            "branch_operator_test",
-            default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
-            schedule=INTERVAL,
-        )
 
-        self.branch_1 = EmptyOperator(task_id="branch_1", dag=self.dag)
-        self.branch_2 = EmptyOperator(task_id="branch_2", dag=self.dag)
-        self.branch_3 = None
+class TestBranchOperator(BasePythonTest):
+    opcls = BranchPythonOperator
 
-    def tearDown(self):
-        super().tearDown()
-
-        with create_session() as session:
-            session.query(DagRun).delete()
-            session.query(TI).delete()
+    @pytest.fixture(autouse=True)
+    def setup_tests(self):
+        self.branch_1 = EmptyOperator(task_id="branch_1")
+        self.branch_2 = EmptyOperator(task_id="branch_2")
 
     def test_with_dag_run(self):
-        branch_op = BranchPythonOperator(
-            task_id="make_choice", dag=self.dag, python_callable=lambda: "branch_1"
-        )
-
-        self.branch_1.set_upstream(branch_op)
-        self.branch_2.set_upstream(branch_op)
-        self.dag.clear()
+        with self.dag:
+            branch_op = BranchPythonOperator(task_id=self.task_id, python_callable=lambda: "branch_1")
+            branch_op >> [self.branch_1, self.branch_2]
 
-        dr = self.dag.create_dagrun(
-            run_type=DagRunType.MANUAL,
-            start_date=timezone.utcnow(),
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
+        dr = self.create_dag_run()
+        branch_op.run(start_date=self.default_date, end_date=self.default_date)
+        self.assert_expected_task_states(
+            dr, {self.task_id: State.SUCCESS, "branch_1": State.NONE, "branch_2": State.SKIPPED}
         )
 
-        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-        tis = dr.get_task_instances()
-        for ti in tis:
-            if ti.task_id == "make_choice":
-                assert ti.state == State.SUCCESS
-            elif ti.task_id == "branch_1":
-                assert ti.state == State.NONE
-            elif ti.task_id == "branch_2":
-                assert ti.state == State.SKIPPED
-            else:
-                raise ValueError(f"Invalid task id {ti.task_id} found!")
-
     def test_with_skip_in_branch_downstream_dependencies(self):
-        branch_op = BranchPythonOperator(
-            task_id="make_choice", dag=self.dag, python_callable=lambda: "branch_1"
-        )
-
-        branch_op >> self.branch_1 >> self.branch_2
-        branch_op >> self.branch_2
-        self.dag.clear()
+        with self.dag:
+            branch_op = BranchPythonOperator(task_id=self.task_id, python_callable=lambda: "branch_1")
+            branch_op >> self.branch_1 >> self.branch_2
+            branch_op >> self.branch_2
 
-        dr = self.dag.create_dagrun(
-            run_type=DagRunType.MANUAL,
-            start_date=timezone.utcnow(),
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
+        dr = self.create_dag_run()
+        branch_op.run(start_date=self.default_date, end_date=self.default_date)
+        self.assert_expected_task_states(
+            dr, {self.task_id: State.SUCCESS, "branch_1": State.NONE, "branch_2": State.NONE}
         )
 
-        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-        tis = dr.get_task_instances()
-        for ti in tis:
-            if ti.task_id == "make_choice":
-                assert ti.state == State.SUCCESS
-            elif ti.task_id == "branch_1":
-                assert ti.state == State.NONE
-            elif ti.task_id == "branch_2":
-                assert ti.state == State.NONE
-            else:
-                raise ValueError(f"Invalid task id {ti.task_id} found!")
-
     def test_with_skip_in_branch_downstream_dependencies2(self):
-        branch_op = BranchPythonOperator(
-            task_id="make_choice", dag=self.dag, python_callable=lambda: "branch_2"
-        )
-
-        branch_op >> self.branch_1 >> self.branch_2
-        branch_op >> self.branch_2
-        self.dag.clear()
+        with self.dag:
+            branch_op = BranchPythonOperator(task_id=self.task_id, python_callable=lambda: "branch_2")
+            branch_op >> self.branch_1 >> self.branch_2
+            branch_op >> self.branch_2
 
-        dr = self.dag.create_dagrun(
-            run_type=DagRunType.MANUAL,
-            start_date=timezone.utcnow(),
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
+        dr = self.create_dag_run()
+        branch_op.run(start_date=self.default_date, end_date=self.default_date)
+        self.assert_expected_task_states(
+            dr, {self.task_id: State.SUCCESS, "branch_1": State.SKIPPED, "branch_2": State.NONE}
         )
 
-        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-        tis = dr.get_task_instances()
-        for ti in tis:
-            if ti.task_id == "make_choice":
-                assert ti.state == State.SUCCESS
-            elif ti.task_id == "branch_1":
-                assert ti.state == State.SKIPPED
-            elif ti.task_id == "branch_2":
-                assert ti.state == State.NONE
-            else:
-                raise ValueError(f"Invalid task id {ti.task_id} found!")
-
     def test_xcom_push(self):
-        branch_op = BranchPythonOperator(
-            task_id="make_choice", dag=self.dag, python_callable=lambda: "branch_1"
-        )
-
-        self.branch_1.set_upstream(branch_op)
-        self.branch_2.set_upstream(branch_op)
-        self.dag.clear()
-
-        dr = self.dag.create_dagrun(
-            run_type=DagRunType.MANUAL,
-            start_date=timezone.utcnow(),
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-
-        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        with self.dag:
+            branch_op = BranchPythonOperator(task_id=self.task_id, python_callable=lambda: "branch_1")
+            branch_op >> [self.branch_1, self.branch_2]
 
-        tis = dr.get_task_instances()
-        for ti in tis:
-            if ti.task_id == "make_choice":
-                assert ti.xcom_pull(task_ids="make_choice") == "branch_1"
+        dr = self.create_dag_run()
+        branch_op.run(start_date=self.default_date, end_date=self.default_date)
+        for ti in dr.get_task_instances():
+            if ti.task_id == self.task_id:
+                assert ti.xcom_pull(task_ids=self.task_id) == "branch_1"
+                break
+        else:
+            pytest.fail(f"{self.task_id!r} not found.")
 
     def test_clear_skipped_downstream_task(self):
         """
         After a downstream task is skipped by BranchPythonOperator, clearing the skipped task
         should not cause it to be executed.
         """
-        branch_op = BranchPythonOperator(
-            task_id="make_choice", dag=self.dag, python_callable=lambda: "branch_1"
-        )
-        branches = [self.branch_1, self.branch_2]
-        branch_op >> branches
-        self.dag.clear()
-
-        dr = self.dag.create_dagrun(
-            run_type=DagRunType.MANUAL,
-            start_date=timezone.utcnow(),
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-
-        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        with self.dag:
+            branch_op = BranchPythonOperator(task_id=self.task_id, python_callable=lambda: "branch_1")
+            branches = [self.branch_1, self.branch_2]
+            branch_op >> branches
 
+        dr = self.create_dag_run()
+        branch_op.run(start_date=self.default_date, end_date=self.default_date)
         for task in branches:
-            task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+            task.run(start_date=self.default_date, end_date=self.default_date)
 
-        tis = dr.get_task_instances()
-        for ti in tis:
-            if ti.task_id == "make_choice":
-                assert ti.state == State.SUCCESS
-            elif ti.task_id == "branch_1":
-                assert ti.state == State.SUCCESS
-            elif ti.task_id == "branch_2":
-                assert ti.state == State.SKIPPED
-            else:
-                raise ValueError(f"Invalid task id {ti.task_id} found!")
+        expected_states = {
+            self.task_id: State.SUCCESS,
+            "branch_1": State.SUCCESS,
+            "branch_2": State.SKIPPED,
+        }
 
-        children_tis = [ti for ti in tis if ti.task_id in branch_op.get_direct_relative_ids()]
+        self.assert_expected_task_states(dr, expected_states)
 
         # Clear the children tasks.
+        tis = dr.get_task_instances()
+        children_tis = [ti for ti in tis if ti.task_id in branch_op.get_direct_relative_ids()]
         with create_session() as session:
-            clear_task_instances(children_tis, session=session, dag=self.dag)
+            clear_task_instances(children_tis, session=session, dag=branch_op.dag)
 
         # Run the cleared tasks again.
         for task in branches:
-            task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+            task.run(start_date=self.default_date, end_date=self.default_date)
 
         # Check if the states are correct after children tasks are cleared.
-        for ti in dr.get_task_instances():
-            if ti.task_id == "make_choice":
-                assert ti.state == State.SUCCESS
-            elif ti.task_id == "branch_1":
-                assert ti.state == State.SUCCESS
-            elif ti.task_id == "branch_2":
-                assert ti.state == State.SKIPPED
-            else:
-                raise ValueError(f"Invalid task id {ti.task_id} found!")
+        self.assert_expected_task_states(dr, expected_states)
 
     def test_raise_exception_on_no_accepted_type_return(self):
-        branch_op = BranchPythonOperator(task_id="make_choice", dag=self.dag, python_callable=lambda: 5)
-        self.dag.clear()
-        with pytest.raises(AirflowException) as ctx:
-            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-        assert "must be either None, a task ID, or an Iterable of IDs" in str(ctx.value)
+        ti = self.create_ti(lambda: 5)
+        with pytest.raises(AirflowException, match="must be either None, a task ID, or an Iterable of IDs"):
+            ti.run()
 
     def test_raise_exception_on_invalid_task_id(self):
-        branch_op = BranchPythonOperator(
-            task_id="make_choice", dag=self.dag, python_callable=lambda: "some_task_id"
-        )
-        self.dag.clear()
-        with pytest.raises(AirflowException) as ctx:
-            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-        assert "Invalid tasks found: {'some_task_id'}" in str(ctx.value)
+        ti = self.create_ti(lambda: "some_task_id")
+        with pytest.raises(AirflowException, match="Invalid tasks found: {'some_task_id'}"):
+            ti.run()
 
+    @pytest.mark.parametrize(
+        "choice,expected_states",
+        [
+            ("task1", [State.SUCCESS, State.SUCCESS, State.SUCCESS]),
+            ("join", [State.SUCCESS, State.SKIPPED, State.SUCCESS]),
+        ],
+    )
+    def test_empty_branch(self, choice, expected_states):
+        """
+        Tests that BranchPythonOperator handles empty branches properly.
+        """
+        with self.dag:
+            branch = BranchPythonOperator(task_id=self.task_id, python_callable=lambda: choice)
+            task1 = EmptyOperator(task_id="task1")
+            join = EmptyOperator(task_id="join", trigger_rule="none_failed_min_one_success")
 
-class TestShortCircuitOperator:
-    def setup(self):
-        with create_session() as session:
-            session.query(DagRun).delete()
-            session.query(TI).delete()
+            branch >> [task1, join]
+            task1 >> join
 
-        self.dag = DAG(
-            "short_circuit_op_test",
-            start_date=DEFAULT_DATE,
-            schedule=INTERVAL,
-        )
+        dr = self.create_dag_run()
+        task_ids = [self.task_id, "task1", "join"]
+        tis = {ti.task_id: ti for ti in dr.task_instances}
 
-        with self.dag:
-            self.op1 = EmptyOperator(task_id="op1")
-            self.op2 = EmptyOperator(task_id="op2")
-            self.op1.set_downstream(self.op2)
+        for task_id in task_ids:  # Mimic the specific order the scheduler would run the tasks.
+            task_instance = tis[task_id]
+            task_instance.refresh_from_task(self.dag.get_task(task_id))
+            task_instance.run()
 
-    def teardown(self):
-        with create_session() as session:
-            session.query(DagRun).delete()
-            session.query(TI).delete()
+        def get_state(ti):
+            ti.refresh_from_db()
+            return ti.state
 
-    def _assert_expected_task_states(self, dagrun, expected_states):
-        """Helper function that asserts `TaskInstances` of a given `task_id` are in a given state."""
+        assert [get_state(tis[task_id]) for task_id in task_ids] == expected_states
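
The join task's trigger rule is what makes the empty branch work:
"none_failed_min_one_success" lets join run once the branch has skipped task1,
whereas the default "all_success" would cascade the skip into join. Restated as a
fragment under the same assumptions as the test above:

    from airflow.operators.empty import EmptyOperator

    # join runs provided no upstream failed and at least one succeeded,
    # so a skip produced by the branch does not propagate into it.
    join = EmptyOperator(
        task_id="join",
        trigger_rule="none_failed_min_one_success",
    )
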
 
-        tis = dagrun.get_task_instances()
-        for ti in tis:
-            try:
-                expected_state = expected_states[ti.task_id]
-            except KeyError:
-                raise ValueError(f"Invalid task id {ti.task_id} found!")
-            else:
-                assert ti.state == expected_state
+
+class TestShortCircuitOperator(BasePythonTest):
+    opcls = ShortCircuitOperator
+
+    @pytest.fixture(autouse=True)
+    def setup_tests(self):
+        self.task_id = "short_circuit"
+        self.op1 = EmptyOperator(task_id="op1")
+        self.op2 = EmptyOperator(task_id="op2")
 
     all_downstream_skipped_states = {
         "short_circuit": State.SUCCESS,
@@ -725,62 +530,41 @@ class TestShortCircuitOperator:
         Checking the behavior of the ShortCircuitOperator in several scenarios enabling/disabling the skipping
         of downstream tasks, both short-circuiting modes, and various trigger rules of downstream tasks.
         """
-
-        self.short_circuit = ShortCircuitOperator(
-            task_id="short_circuit",
-            python_callable=lambda: callable_return,
-            ignore_downstream_trigger_rules=test_ignore_downstream_trigger_rules,
-            dag=self.dag,
-        )
-        self.short_circuit.set_downstream(self.op1)
-        self.op2.trigger_rule = test_trigger_rule
-        self.dag.clear()
-
-        dagrun = self.dag.create_dagrun(
-            run_type=DagRunType.MANUAL,
-            start_date=timezone.utcnow(),
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-
-        self.short_circuit.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-        self.op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-        self.op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-        assert self.short_circuit.ignore_downstream_trigger_rules == test_ignore_downstream_trigger_rules
-        assert self.short_circuit.trigger_rule == TriggerRule.ALL_SUCCESS
+        with self.dag:
+            short_circuit = ShortCircuitOperator(
+                task_id="short_circuit",
+                python_callable=lambda: callable_return,
+                ignore_downstream_trigger_rules=test_ignore_downstream_trigger_rules,
+            )
+            short_circuit >> self.op1 >> self.op2
+            self.op2.trigger_rule = test_trigger_rule
+
+        dr = self.create_dag_run()
+        short_circuit.run(start_date=self.default_date, end_date=self.default_date)
+        self.op1.run(start_date=self.default_date, end_date=self.default_date)
+        self.op2.run(start_date=self.default_date, end_date=self.default_date)
+
+        assert short_circuit.ignore_downstream_trigger_rules == test_ignore_downstream_trigger_rules
+        assert short_circuit.trigger_rule == TriggerRule.ALL_SUCCESS
         assert self.op1.trigger_rule == TriggerRule.ALL_SUCCESS
         assert self.op2.trigger_rule == test_trigger_rule
-
-        self._assert_expected_task_states(dagrun, expected_task_states)
+        self.assert_expected_task_states(dr, expected_task_states)
 
     def test_clear_skipped_downstream_task(self):
         """
         After a downstream task is skipped by ShortCircuitOperator, clearing the skipped task
         should not cause it to be executed.
         """
-
-        self.short_circuit = ShortCircuitOperator(
-            task_id="short_circuit",
-            python_callable=lambda: False,
-            dag=self.dag,
-        )
-        self.short_circuit.set_downstream(self.op1)
-        self.dag.clear()
-
-        dagrun = self.dag.create_dagrun(
-            run_type=DagRunType.MANUAL,
-            start_date=timezone.utcnow(),
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-
-        self.short_circuit.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-        self.op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-        self.op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-        assert self.short_circuit.ignore_downstream_trigger_rules
-        assert self.short_circuit.trigger_rule == TriggerRule.ALL_SUCCESS
+        with self.dag:
+            short_circuit = ShortCircuitOperator(task_id="short_circuit", python_callable=lambda: False)
+            short_circuit >> self.op1 >> self.op2
+        dr = self.create_dag_run()
+
+        short_circuit.run(start_date=self.default_date, end_date=self.default_date)
+        self.op1.run(start_date=self.default_date, end_date=self.default_date)
+        self.op2.run(start_date=self.default_date, end_date=self.default_date)
+        assert short_circuit.ignore_downstream_trigger_rules
+        assert short_circuit.trigger_rule == TriggerRule.ALL_SUCCESS
         assert self.op1.trigger_rule == TriggerRule.ALL_SUCCESS
         assert self.op2.trigger_rule == TriggerRule.ALL_SUCCESS
 
@@ -789,82 +573,45 @@ class TestShortCircuitOperator:
             "op1": State.SKIPPED,
             "op2": State.SKIPPED,
         }
-        self._assert_expected_task_states(dagrun, expected_states)
+        self.assert_expected_task_states(dr, expected_states)
 
         # Clear downstream task "op1" that was previously executed.
-        tis = dagrun.get_task_instances()
-
+        tis = dr.get_task_instances()
         with create_session() as session:
-            clear_task_instances([ti for ti in tis if ti.task_id == "op1"], session=session, dag=self.dag)
-
-        self.op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-        self._assert_expected_task_states(dagrun, expected_states)
+            clear_task_instances(
+                [ti for ti in tis if ti.task_id == "op1"], session=session, dag=short_circuit.dag
+            )
+        self.op1.run(start_date=self.default_date, end_date=self.default_date)
+        self.assert_expected_task_states(dr, expected_states)
 
     def test_xcom_push(self):
-        short_op_push_xcom = ShortCircuitOperator(
-            task_id="push_xcom_from_shortcircuit", dag=self.dag, python_callable=lambda: "signature"
-        )
-
-        short_op_no_push_xcom = ShortCircuitOperator(
-            task_id="do_not_push_xcom_from_shortcircuit", dag=self.dag, python_callable=lambda: False
-        )
-
-        self.dag.clear()
-        dr = self.dag.create_dagrun(
-            run_type=DagRunType.MANUAL,
-            start_date=timezone.utcnow(),
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
+        with self.dag:
+            short_op_push_xcom = ShortCircuitOperator(
+                task_id="push_xcom_from_shortcircuit", python_callable=lambda: "signature"
+            )
+            short_op_no_push_xcom = ShortCircuitOperator(
+                task_id="do_not_push_xcom_from_shortcircuit", python_callable=lambda: False
+            )
 
-        short_op_push_xcom.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-        short_op_no_push_xcom.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        dr = self.create_dag_run()
+        short_op_push_xcom.run(start_date=self.default_date, end_date=self.default_date)
+        short_op_no_push_xcom.run(start_date=self.default_date, end_date=self.default_date)
 
         tis = dr.get_task_instances()
-        xcom_value_short_op_push_xcom = tis[0].xcom_pull(
-            task_ids="push_xcom_from_shortcircuit", key="return_value"
-        )
-        assert xcom_value_short_op_push_xcom == "signature"
-
-        xcom_value_short_op_no_push_xcom = tis[0].xcom_pull(
-            task_ids="do_not_push_xcom_from_shortcircuit", key="return_value"
-        )
-        assert xcom_value_short_op_no_push_xcom is None
+        assert tis[0].xcom_pull(task_ids=short_op_push_xcom.task_id, key="return_value") == "signature"
+        assert tis[0].xcom_pull(task_ids=short_op_no_push_xcom.task_id, key="return_value") is None
 
 
 virtualenv_string_args: list[str] = []
 
 
-class TestPythonVirtualenvOperator(unittest.TestCase):
-    def setUp(self):
-        super().setUp()
-        self.dag = DAG(
-            "test_dag",
-            default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
-            template_searchpath=TEMPLATE_SEARCHPATH,
-            schedule=INTERVAL,
-        )
-        self.dag.create_dagrun(
-            run_type=DagRunType.MANUAL,
-            start_date=timezone.utcnow(),
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-        self.addCleanup(self.dag.clear)
-
-    def tearDown(self):
-        super().tearDown()
-        with create_session() as session:
-            session.query(DagRun).delete()
-            session.query(TI).delete()
-
-    def _run_as_operator(self, fn, python_version=sys.version_info[0], **kwargs):
+class TestPythonVirtualenvOperator(BasePythonTest):
+    opcls = PythonVirtualenvOperator
 
-        task = PythonVirtualenvOperator(
-            python_callable=fn, python_version=python_version, task_id="task", dag=self.dag, **kwargs
-        )
-        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-        return task
+    @staticmethod
+    def default_kwargs(*, python_version=sys.version_info[0], **kwargs):
+        kwargs["python_version"] = python_version
+        return kwargs
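
default_kwargs is a small template-method hook: the base class helpers (create_ti,
run_as_task, run_as_operator) route every call through it, so a subclass can inject
operator-specific defaults in one place. Restating the shape from this diff:

    class BasePythonTest:
        @staticmethod
        def default_kwargs(**kwargs):
            # Base hook: pass the kwargs through unchanged.
            return kwargs

    class TestPythonVirtualenvOperator(BasePythonTest):
        @staticmethod
        def default_kwargs(*, python_version=3, **kwargs):
            # Subclass hook: every operator the base helpers build now
            # receives a python_version, overridable per call.
            kwargs["python_version"] = python_version
            return kwargs
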
 
     def test_template_fields(self):
         assert set(PythonOperator.template_fields).issubset(PythonVirtualenvOperator.template_fields)
@@ -874,7 +621,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
             """Ensure dill is correctly installed."""
             import dill  # noqa: F401
 
-        self._run_as_operator(f, use_dill=True, system_site_packages=False)
+        self.run_as_task(f, use_dill=True, system_site_packages=False)
 
     def test_no_requirements(self):
         """Tests that the python callable is invoked on task run."""
@@ -882,7 +629,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
         def f():
             pass
 
-        self._run_as_operator(f)
+        self.run_as_task(f)
 
     def test_no_system_site_packages(self):
         def f():
@@ -892,13 +639,13 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
                 return True
             raise Exception
 
-        self._run_as_operator(f, system_site_packages=False, requirements=["dill"])
+        self.run_as_task(f, system_site_packages=False, requirements=["dill"])
 
     def test_system_site_packages(self):
         def f():
             import funcsigs  # noqa: F401
 
-        self._run_as_operator(f, requirements=["funcsigs"], system_site_packages=True)
+        self.run_as_task(f, requirements=["funcsigs"], system_site_packages=True)
 
     def test_with_requirements_pinned(self):
         def f():
@@ -907,44 +654,44 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
             if funcsigs.__version__ != "0.4":
                 raise Exception
 
-        self._run_as_operator(f, requirements=["funcsigs==0.4"])
+        self.run_as_task(f, requirements=["funcsigs==0.4"])
 
     def test_unpinned_requirements(self):
         def f():
             import funcsigs  # noqa: F401
 
-        self._run_as_operator(f, requirements=["funcsigs", "dill"], system_site_packages=False)
+        self.run_as_task(f, requirements=["funcsigs", "dill"], system_site_packages=False)
 
     def test_range_requirements(self):
         def f():
             import funcsigs  # noqa: F401
 
-        self._run_as_operator(f, requirements=["funcsigs>1.0", "dill"], system_site_packages=False)
+        self.run_as_task(f, requirements=["funcsigs>1.0", "dill"], system_site_packages=False)
 
     def test_requirements_file(self):
         def f():
             import funcsigs  # noqa: F401
 
-        self._run_as_operator(f, requirements="requirements.txt", system_site_packages=False)
+        self.run_as_operator(f, requirements="requirements.txt", system_site_packages=False)
 
-    @unittest.mock.patch("airflow.operators.python.prepare_virtualenv")
+    @mock.patch("airflow.operators.python.prepare_virtualenv")
     def test_pip_install_options(self, mocked_prepare_virtualenv):
         def f():
             import funcsigs  # noqa: F401
 
         mocked_prepare_virtualenv.side_effect = prepare_virtualenv
 
-        self._run_as_operator(
+        self.run_as_task(
             f,
             requirements=["funcsigs==0.4"],
             system_site_packages=False,
             pip_install_options=["--no-deps"],
         )
         mocked_prepare_virtualenv.assert_called_with(
-            venv_directory=unittest.mock.ANY,
-            python_bin=unittest.mock.ANY,
+            venv_directory=mock.ANY,
+            python_bin=mock.ANY,
             system_site_packages=False,
-            requirements_file_path=unittest.mock.ANY,
+            requirements_file_path=mock.ANY,
             pip_install_options=["--no-deps"],
         )
 
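The hunk above keeps prepare_virtualenv's real behavior while patching it, by
wiring the original function in as side_effect so the call can still be
asserted on afterwards. That pass-through pattern is standard unittest.mock; a
minimal self-contained sketch (the prepare function here is invented):

    from unittest import mock


    def prepare(venv_directory, system_site_packages):
        return f"{venv_directory}:{system_site_packages}"


    with mock.patch(f"{__name__}.prepare", side_effect=prepare) as mocked:
        # The patched name still runs the real function...
        assert prepare("/tmp/venv", system_site_packages=False) == "/tmp/venv:False"

    # ...while the mock records how it was called.
    mocked.assert_called_with(mock.ANY, system_site_packages=False)
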
@@ -954,7 +701,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
 
             assert funcsigs.__version__ == "1.0.2"
 
-        self._run_as_operator(
+        self.run_as_operator(
             f,
             requirements="requirements.txt",
             use_dill=True,
@@ -967,7 +714,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
             raise Exception
 
         with pytest.raises(CalledProcessError):
-            self._run_as_operator(f)
+            self.run_as_task(f)
 
     def test_python_3(self):
         def f():
@@ -980,13 +727,13 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
                 return
             raise Exception
 
-        self._run_as_operator(f, python_version=3, use_dill=False, requirements=["dill"])
+        self.run_as_task(f, python_version=3, use_dill=False, requirements=["dill"])
 
     def test_without_dill(self):
         def f(a):
             return a
 
-        self._run_as_operator(f, system_site_packages=False, use_dill=False, op_args=[4])
+        self.run_as_task(f, system_site_packages=False, use_dill=False, op_args=[4])
 
     def test_string_args(self):
         def f():
@@ -995,7 +742,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
             if virtualenv_string_args[0] != virtualenv_string_args[2]:
                 raise Exception
 
-        self._run_as_operator(f, string_args=[1, 2, 1])
+        self.run_as_task(f, string_args=[1, 2, 1])
 
     def test_with_args(self):
         def f(a, b, c=False, d=False):
@@ -1004,37 +751,38 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
             else:
                 raise Exception
 
-        self._run_as_operator(f, op_args=[0, 1], op_kwargs={"c": True})
+        self.run_as_task(f, op_args=[0, 1], op_kwargs={"c": True})
 
     def test_return_none(self):
         def f():
             return None
 
-        task = self._run_as_operator(f)
+        task = self.run_as_task(f)
         assert task.execute_callable() is None
 
     def test_return_false(self):
         def f():
             return False
 
-        task = self._run_as_operator(f)
+        task = self.run_as_task(f)
         assert task.execute_callable() is False
 
     def test_lambda(self):
         with pytest.raises(AirflowException):
-            PythonVirtualenvOperator(python_callable=lambda x: 4, task_id="task", dag=self.dag)
+            PythonVirtualenvOperator(python_callable=lambda x: 4, task_id=self.task_id)
 
     def test_nonimported_as_arg(self):
         def f(_):
             return None
 
-        self._run_as_operator(f, op_args=[datetime.utcnow()])
+        self.run_as_task(f, op_args=[datetime.utcnow()])
 
     def test_context(self):
         def f(templates_dict):
             return templates_dict["ds"]
 
-        self._run_as_operator(f, templates_dict={"ds": "{{ ds }}"})
+        task = self.run_as_task(f, templates_dict={"ds": "{{ ds }}"})
+        assert task.templates_dict == {"ds": self.ds_templated}
 
     # These tests might take longer than the default 60 seconds, as they serialize a lot of
     # context using dill (which is apparently slow).
@@ -1078,7 +826,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
         ):
             pass
 
-        self._run_as_operator(f, use_dill=True, system_site_packages=True, requirements=None)
+        self.run_as_operator(f, use_dill=True, system_site_packages=True, requirements=None)
 
     @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning")
     def test_pendulum_context(self):
@@ -1112,7 +860,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
         ):
             pass
 
-        self._run_as_operator(f, use_dill=True, system_site_packages=False, requirements=["pendulum"])
+        self.run_as_task(f, use_dill=True, system_site_packages=False, requirements=["pendulum"])
 
     @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning")
     def test_base_context(self):
@@ -1140,7 +888,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
         ):
             pass
 
-        self._run_as_operator(f, use_dill=True, system_site_packages=False, requirements=None)
+        self.run_as_task(f, use_dill=True, system_site_packages=False, requirements=None)
 
     def test_deepcopy(self):
         """Test that PythonVirtualenvOperator are deep-copyable."""
@@ -1148,13 +896,37 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
         def f():
             return 1
 
-        task = PythonVirtualenvOperator(
-            python_callable=f,
-            task_id="task",
-            dag=self.dag,
-        )
+        task = PythonVirtualenvOperator(python_callable=f, task_id="task")
         copy.deepcopy(task)
 
+    def test_virtualenv_serializable_context_fields(self, create_task_instance):
+        """Ensure all template context fields are listed in the operator.
+
+        This exists mainly so that, when a field is added to the context, we
+        remember to also add it to PythonVirtualenvOperator.
+        """
+        # These are intentionally NOT serialized into the virtual environment:
+        # * Variables pointing to the task instance itself.
+        # * Variables that are accessor instances.
+        intentionally_excluded_context_keys = [
+            "task_instance",
+            "ti",
+            "var",  # Accessor for Variable; var->json and var->value.
+            "conn",  # Accessor for Connection.
+        ]
+
+        ti = create_task_instance(dag_id=self.dag_id, task_id=self.task_id, schedule=None)
+        context = ti.get_template_context()
+
+        declared_keys = {
+            *PythonVirtualenvOperator.BASE_SERIALIZABLE_CONTEXT_KEYS,
+            *PythonVirtualenvOperator.PENDULUM_SERIALIZABLE_CONTEXT_KEYS,
+            *PythonVirtualenvOperator.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS,
+            *intentionally_excluded_context_keys,
+        }
+
+        assert set(context) == declared_keys
+
 
 DEFAULT_ARGS = {
     "owner": "test",
@@ -1221,7 +993,7 @@ def get_all_the_context(**context):
         assert context == current_context._context
 
 
-@pytest.fixture()
+@pytest.fixture
 def clear_db():
     clear_db_runs()
     yield
@@ -1239,76 +1011,3 @@ class TestCurrentContextRuntime:
         with DAG(dag_id="edge_case_context_dag", default_args=DEFAULT_ARGS):
             op = PythonOperator(python_callable=get_all_the_context, task_id="get_all_the_context")
             op.run(ignore_first_depends_on_past=True, ignore_ti_state=True)
-
-
-@pytest.mark.parametrize(
-    "choice,expected_states",
-    [
-        ("task1", [State.SUCCESS, State.SUCCESS, State.SUCCESS]),
-        ("join", [State.SUCCESS, State.SKIPPED, State.SUCCESS]),
-    ],
-)
-def test_empty_branch(dag_maker, choice, expected_states):
-    """
-    Tests that BranchPythonOperator handles empty branches properly.
-    """
-    with dag_maker(
-        "test_empty_branch",
-        start_date=DEFAULT_DATE,
-    ) as dag:
-        branch = BranchPythonOperator(task_id="branch", python_callable=lambda: choice)
-        task1 = EmptyOperator(task_id="task1")
-        join = EmptyOperator(task_id="join", trigger_rule="none_failed_min_one_success")
-
-        branch >> [task1, join]
-        task1 >> join
-
-    dag.clear(start_date=DEFAULT_DATE)
-    dag_run = dag_maker.create_dagrun()
-
-    task_ids = ["branch", "task1", "join"]
-    tis = {ti.task_id: ti for ti in dag_run.task_instances}
-
-    for task_id in task_ids:  # Mimic the specific order the scheduling would run the tests.
-        task_instance = tis[task_id]
-        task_instance.refresh_from_task(dag.get_task(task_id))
-        task_instance.run()
-
-    def get_state(ti):
-        ti.refresh_from_db()
-        return ti.state
-
-    assert [get_state(tis[task_id]) for task_id in task_ids] == expected_states
-
-
-def test_virtualenv_serializable_context_fields(create_task_instance):
-    """Ensure all template context fields are listed in the operator.
-
-    This exists mainly so when a field is added to the context, we remember to
-    also add it to PythonVirtualenvOperator.
-    """
-    # These are intentionally NOT serialized into the virtual environment:
-    # * Variables pointing to the task instance itself.
-    # * Variables that are accessor instances.
-    intentionally_excluded_context_keys = [
-        "task_instance",
-        "ti",
-        "var",  # Accessor for Variable; var->json and var->value.
-        "conn",  # Accessor for Connection.
-    ]
-
-    ti = create_task_instance(
-        dag_id="test_virtualenv_serializable_context_fields",
-        task_id="test_virtualenv_serializable_context_fields_task",
-        schedule=None,
-    )
-    context = ti.get_template_context()
-
-    declared_keys = {
-        *PythonVirtualenvOperator.BASE_SERIALIZABLE_CONTEXT_KEYS,
-        *PythonVirtualenvOperator.PENDULUM_SERIALIZABLE_CONTEXT_KEYS,
-        *PythonVirtualenvOperator.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS,
-        *intentionally_excluded_context_keys,
-    }
-
-    assert set(context) == declared_keys
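
The module-level test_virtualenv_serializable_context_fields removed above is
the same check that now lives on TestPythonVirtualenvOperator (added in an
earlier hunk of this file). The guard it implements is plain set equality:
every key the runtime template context exposes must appear either in one of
the operator's declared serializable-key sets or in the explicit exclusion
list. A self-contained sketch, with invented key sets standing in for the real
class attributes:

    # Invented stand-ins for the real *_SERIALIZABLE_CONTEXT_KEYS attributes.
    BASE_KEYS = {"ds", "ts"}
    PENDULUM_KEYS = {"execution_date"}
    AIRFLOW_KEYS = {"dag", "task"}
    # Keys deliberately never serialized into the virtual environment.
    EXCLUDED_KEYS = {"task_instance", "ti", "var", "conn"}

    # Pretend context as a task instance would expose it at runtime.
    runtime_context = dict.fromkeys(
        ["ds", "ts", "execution_date", "dag", "task", "task_instance", "ti", "var", "conn"]
    )

    declared = BASE_KEYS | PENDULUM_KEYS | AIRFLOW_KEYS | EXCLUDED_KEYS
    # Fails loudly the moment a new context key appears without being declared.
    assert set(runtime_context) == declared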
diff --git a/tests/sensors/test_python.py b/tests/sensors/test_python.py
index f3c258185d..73a0ddffe4 100644
--- a/tests/sensors/test_python.py
+++ b/tests/sensors/test_python.py
@@ -24,112 +24,52 @@ import pytest
 
 from airflow.exceptions import AirflowSensorTimeout
 from airflow.sensors.python import PythonSensor
-from airflow.utils.state import State
-from airflow.utils.timezone import datetime
-from airflow.utils.types import DagRunType
-from tests.operators.test_python import Call, assert_calls_equal, build_recording_function
+from tests.operators.test_python import BasePythonTest
 
-DEFAULT_DATE = datetime(2015, 1, 1)
 
+class TestPythonSensor(BasePythonTest):
+    opcls = PythonSensor
 
-class TestPythonSensor:
-    def test_python_sensor_true(self, dag_maker):
-        with dag_maker():
-            op = PythonSensor(task_id="python_sensor_check_true", python_callable=lambda: True)
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+    def test_python_sensor_true(self):
+        self.run_as_task(fn=lambda: True)
 
-    def test_python_sensor_false(self, dag_maker):
-        with dag_maker():
-            op = PythonSensor(
-                task_id="python_sensor_check_false",
-                timeout=0.01,
-                poke_interval=0.01,
-                python_callable=lambda: False,
-            )
+    def test_python_sensor_false(self):
         with pytest.raises(AirflowSensorTimeout):
-            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+            self.run_as_task(lambda: False, timeout=0.01, poke_interval=0.01)
 
-    def test_python_sensor_raise(self, dag_maker):
-        with dag_maker():
-            op = PythonSensor(task_id="python_sensor_check_raise", python_callable=lambda: 1 / 0)
+    def test_python_sensor_raise(self):
         with pytest.raises(ZeroDivisionError):
-            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+            self.run_as_task(lambda: 1 / 0)
 
-    def test_python_callable_arguments_are_templatized(self, dag_maker):
+    def test_python_callable_arguments_are_templatized(self):
         """Test PythonSensor op_args are templatized"""
-        recorded_calls = []
-
         # Create a named tuple and ensure it is still preserved
         # after the rendering is done
         Named = namedtuple("Named", ["var1", "var2"])
         named_tuple = Named("{{ ds }}", "unchanged")
 
-        with dag_maker() as dag:
-            task = PythonSensor(
-                task_id="python_sensor",
-                timeout=0.01,
-                poke_interval=0.3,
-                # a Mock instance cannot be used as a callable function or test fails with a
-                # TypeError: Object of type Mock is not JSON serializable
-                python_callable=build_recording_function(recorded_calls),
-                op_args=[4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple],
-            )
-
-        dag.create_dagrun(
-            run_type=DagRunType.MANUAL,
-            execution_date=DEFAULT_DATE,
-            data_interval=(DEFAULT_DATE, DEFAULT_DATE),
-            start_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-        with pytest.raises(AirflowSensorTimeout):
-            task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-        ds_templated = DEFAULT_DATE.date().isoformat()
-        assert_calls_equal(
-            recorded_calls[0],
-            Call(
-                4,
-                date(2019, 1, 1),
-                f"dag {dag.dag_id} ran on {ds_templated}.",
-                Named(ds_templated, "unchanged"),
-            ),
+        task = self.render_templates(
+            lambda: 0,
+            op_args=[4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple],
         )
-
-    def test_python_callable_keyword_arguments_are_templatized(self, dag_maker):
+        rendered_op_args = task.op_args
+        assert len(rendered_op_args) == 4
+        assert rendered_op_args[0] == 4
+        assert rendered_op_args[1] == date(2019, 1, 1)
+        assert rendered_op_args[2] == f"dag {self.dag_id} ran on {self.ds_templated}."
+        assert rendered_op_args[3] == Named(self.ds_templated, "unchanged")
+
+    def test_python_callable_keyword_arguments_are_templatized(self):
         """Test PythonSensor op_kwargs are templatized"""
-        recorded_calls = []
-
-        with dag_maker() as dag:
-            task = PythonSensor(
-                task_id="python_sensor",
-                timeout=0.01,
-                poke_interval=0.01,
-                # a Mock instance cannot be used as a callable function or test fails with a
-                # TypeError: Object of type Mock is not JSON serializable
-                python_callable=build_recording_function(recorded_calls),
-                op_kwargs={
-                    "an_int": 4,
-                    "a_date": date(2019, 1, 1),
-                    "a_templated_string": "dag {{dag.dag_id}} ran on {{ds}}.",
-                },
-            )
-
-        dag.create_dagrun(
-            run_type=DagRunType.MANUAL,
-            execution_date=DEFAULT_DATE,
-            data_interval=(DEFAULT_DATE, DEFAULT_DATE),
-            start_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-        with pytest.raises(AirflowSensorTimeout):
-            task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-        assert_calls_equal(
-            recorded_calls[0],
-            Call(
-                an_int=4,
-                a_date=date(2019, 1, 1),
-                a_templated_string=f"dag {dag.dag_id} ran on {DEFAULT_DATE.date().isoformat()}.",
-            ),
+        task = self.render_templates(
+            lambda: 0,
+            op_kwargs={
+                "an_int": 4,
+                "a_date": date(2019, 1, 1),
+                "a_templated_string": "dag {{dag.dag_id}} ran on {{ds}}.",
+            },
         )
+        rendered_op_kwargs = task.op_kwargs
+        assert rendered_op_kwargs["an_int"] == 4
+        assert rendered_op_kwargs["a_date"] == date(2019, 1, 1)
+        assert rendered_op_kwargs["a_templated_string"] == f"dag {self.dag_id} ran on {self.ds_templated}."