Posted to commits@airflow.apache.org by as...@apache.org on 2021/12/07 15:28:32 UTC

[airflow] branch main updated: Improve handling of edge cases in airflow.models by applying mypy (#20000)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7d8e3b8  Improve handling of edge cases in airflow.models by applying mypy (#20000)
7d8e3b8 is described below

commit 7d8e3b828af0ac90261c341f5cb0e57da75e6a83
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Tue Dec 7 15:28:01 2021 +0000

    Improve handling of edge cases in airflow.models by applying mypy (#20000)
    
    * Fix many of the mypy typing issues in airflow.models.dag
    
    To fix these, I also needed to fix a few other mistakes in functions
    that are used/called by DAG's methods
    
    * Fix timetable-related typing errors in dag.py
    
    Also moved the sentinel value implementation to a utils module. This
    should be useful when fixing typing issues in other modules.
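
    A minimal sketch of that sentinel pattern, assuming the ArgNotSet /
    NOTSET names this change adds to airflow/utils/types.py (the real
    module may carry extra detail):

        from typing import Union

        class ArgNotSet:
            """Sentinel type for annotations where None is itself a valid value."""

        NOTSET = ArgNotSet()

        def run(interval: Union[None, int, ArgNotSet] = NOTSET) -> None:
            # "argument not passed" stays distinguishable from an explicit None
            if isinstance(interval, ArgNotSet):
                interval = 60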
    
    * Add note about assert allowed inside a TYPE_CHECKING conditional
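
    A toy illustration of why the note matters (hedged; the CONTRIBUTING.rst
    hunk below shows the documented form). The assert never executes at
    runtime because TYPE_CHECKING is False, but mypy treats the condition
    as true and narrows the type:

        from typing import TYPE_CHECKING, Union

        def bump(x: Union[int, str]) -> int:
            if TYPE_CHECKING:
                # Skipped at runtime; the caller must guarantee an int here.
                # mypy narrows x to int for the return statement below.
                assert isinstance(x, int)
            return x + 1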
    
    * Fix docs build of airflow.models.dagrun
    
    * Apply NEW_SESSION to dag, dagrun, ti and operator.subdag
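
    A simplified sketch of what NEW_SESSION is for, assuming a definition
    along the lines of airflow/utils/session.py (names other than
    NEW_SESSION/provide_session are illustrative):

        import functools
        from typing import Callable, TypeVar, cast

        from sqlalchemy import create_engine
        from sqlalchemy.orm import Session, sessionmaker

        RT = TypeVar("RT")

        # None at runtime, but typed as Session, so defaults like
        # `session: Session = NEW_SESSION` type-check without
        # Optional[Session] leaking into every function body.
        NEW_SESSION: Session = cast(Session, None)

        def provide_session(func: Callable[..., RT]) -> Callable[..., RT]:
            """Stand-in for the real decorator: inject a session if none given."""
            @functools.wraps(func)
            def wrapper(*args, **kwargs) -> RT:
                if kwargs.get("session") is None:  # NEW_SESSION is None at runtime
                    # Illustrative factory; Airflow builds its own from settings.
                    kwargs["session"] = sessionmaker(bind=create_engine("sqlite://"))()
                return func(*args, **kwargs)
            return wrapper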
    
    Co-authored-by: Tzu-ping Chung <tp...@astronomer.io>
---
 CONTRIBUTING.rst                              |   8 +
 airflow/models/base.py                        |   2 +-
 airflow/models/baseoperator.py                |  42 ++--
 airflow/models/dag.py                         | 282 ++++++++++++++------------
 airflow/models/dagbag.py                      |  12 +-
 airflow/models/dagrun.py                      |  49 ++---
 airflow/models/serialized_dag.py              |   2 +-
 airflow/models/taskinstance.py                | 108 ++++++----
 airflow/models/variable.py                    |   2 +-
 airflow/operators/subdag.py                   |   6 +-
 airflow/serialization/serialized_objects.py   |   1 -
 airflow/settings.py                           |   4 +-
 airflow/timetables/base.py                    |   2 +-
 airflow/utils/context.pyi                     |   7 +
 airflow/utils/file.py                         |  14 +-
 airflow/utils/state.py                        |   1 +
 airflow/utils/timezone.py                     |  29 ++-
 airflow/utils/types.py                        |  19 ++
 tests/models/test_dag.py                      |   4 -
 tests/models/test_taskinstance.py             |   4 +-
 tests/serialization/test_dag_serialization.py |   6 +-
 21 files changed, 363 insertions(+), 241 deletions(-)

diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index 1f00945..886ff27 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -936,6 +936,14 @@ you should do:
     if not some_predicate():
         handle_the_case()
 
+The one exception to this is when you need an assert purely for type checking (which should be almost a last resort); in that case you can do this:
+
+.. code-block:: python
+
+    if TYPE_CHECKING:
+        assert isinstance(x, MyClass)
+
+
 Database Session Handling
 -------------------------
 
diff --git a/airflow/models/base.py b/airflow/models/base.py
index 29a5320..439308d 100644
--- a/airflow/models/base.py
+++ b/airflow/models/base.py
@@ -63,4 +63,4 @@ def get_id_collation_args():
 
 COLLATION_ARGS = get_id_collation_args()
 
-StringID: Type[String] = functools.partial(String, length=ID_LEN, **COLLATION_ARGS)
+StringID: Type[String] = functools.partial(String, length=ID_LEN, **COLLATION_ARGS)  # type: ignore
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 86ec47e..88b1590 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -25,6 +25,7 @@ import warnings
 from abc import ABCMeta, abstractmethod
 from datetime import datetime, timedelta
 from inspect import signature
+from types import FunctionType
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -46,6 +47,7 @@ from typing import (
 
 import attr
 import jinja2
+import pendulum
 from dateutil.relativedelta import relativedelta
 from sqlalchemy.orm import Session
 from sqlalchemy.orm.exc import NoResultFound
@@ -87,7 +89,7 @@ TaskStateChangeCallback = Callable[[Context], None]
 TaskPreExecuteHook = Callable[[Context], None]
 TaskPostExecuteHook = Callable[[Context, Any], None]
 
-T = TypeVar('T', bound=Callable)
+T = TypeVar('T', bound=FunctionType)
 
 
 class BaseOperatorMeta(abc.ABCMeta):
@@ -483,6 +485,12 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
     # Set to True before calling execute method
     _lock_for_execution = False
 
+    _dag: Optional["DAG"] = None
+
+    # subdag parameter is only set for SubDagOperator.
+    # Setting it to None by default as other Operators do not have that field
+    subdag: Optional["DAG"] = None
+
     def __init__(
         self,
         task_id: str,
@@ -612,7 +620,8 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
         self.pool = Pool.DEFAULT_POOL_NAME if pool is None else pool
         self.pool_slots = pool_slots
         if self.pool_slots < 1:
-            raise AirflowException(f"pool slots for {self.task_id} in dag {dag.dag_id} cannot be less than 1")
+            dag_str = f" in dag {dag.dag_id}" if dag else ""
+            raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot be less than 1")
         self.sla = sla
         self.execution_timeout = execution_timeout
         self.on_execute_callback = on_execute_callback
@@ -636,7 +645,8 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
                 self.log.debug("max_retry_delay isn't a timedelta object, assuming secs")
                 self.max_retry_delay = timedelta(seconds=max_retry_delay)
 
-        self.params = ParamsDict(params)
+        # At execution time this becomes a normal dict
+        self.params: Union[ParamsDict, dict] = ParamsDict(params)
         if priority_weight is not None and not isinstance(priority_weight, int):
             raise AirflowException(
                 f"`priority_weight` for task '{self.task_id}' only accepts integers, "
@@ -673,15 +683,10 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
         # Private attributes
         self._upstream_task_ids: Set[str] = set()
         self._downstream_task_ids: Set[str] = set()
-        self._dag = None
-
-        self.dag = dag or DagContext.get_current_dag()
-
-        # subdag parameter is only set for SubDagOperator.
-        # Setting it to None by default as other Operators do not have that field
-        from airflow.models.dag import DAG
 
-        self.subdag: Optional[DAG] = None
+        dag = dag or DagContext.get_current_dag()
+        if dag:
+            self.dag = dag
 
         self._log = logging.getLogger("airflow.task.operators")
 
@@ -811,7 +816,7 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
     @property
     def dag(self) -> 'DAG':
         """Returns the Operator's DAG if set, otherwise raises an error"""
-        if self.has_dag():
+        if self._dag:
             return self._dag
         else:
             raise AirflowException(f'Operator {self} has not been assigned to a DAG yet')
@@ -840,7 +845,7 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
 
     def has_dag(self):
         """Returns True if the Operator has been assigned to a DAG."""
-        return getattr(self, '_dag', None) is not None
+        return self._dag is not None
 
     @property
     def dag_id(self) -> str:
@@ -1301,8 +1306,13 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
         from airflow.models import DagRun
         from airflow.utils.types import DagRunType
 
-        start_date = start_date or self.start_date
-        end_date = end_date or self.end_date or timezone.utcnow()
+        # Assertions for typing -- we need a dag for this function, and when we have a DAG we are
+        # _guaranteed_ to have start_date (else we couldn't have been added to a DAG)
+        if TYPE_CHECKING:
+            assert self.start_date
+
+        start_date = pendulum.instance(start_date or self.start_date)
+        end_date = pendulum.instance(end_date or self.end_date or timezone.utcnow())
 
         for info in self.dag.iter_dagrun_infos_between(start_date, end_date, align=False):
             ignore_depends_on_past = info.logical_date == start_date and ignore_first_depends_on_past
@@ -1325,7 +1335,7 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
                     execution_date=info.logical_date,
                     data_interval=info.data_interval,
                 )
-                ti = TaskInstance(self, run_id=None)
+                ti = TaskInstance(self, run_id=dr.run_id)
                 ti.dag_run = dr
                 session.add(dr)
                 session.flush()
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 52d73b8..5d35960 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -27,7 +27,7 @@ import sys
 import traceback
 import warnings
 from collections import OrderedDict
-from datetime import datetime, timedelta, tzinfo
+from datetime import datetime, timedelta
 from inspect import signature
 from typing import (
     TYPE_CHECKING,
@@ -51,8 +51,10 @@ import jinja2
 import pendulum
 from dateutil.relativedelta import relativedelta
 from jinja2.nativetypes import NativeEnvironment
+from pendulum.tz.timezone import Timezone
 from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, or_
 from sqlalchemy.orm import backref, joinedload, relationship
+from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql import expression
 
@@ -81,10 +83,10 @@ from airflow.utils.dates import cron_presets, date_range as utils_date_range
 from airflow.utils.file import correct_maybe_zipped
 from airflow.utils.helpers import validate_key
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, with_row_locks
-from airflow.utils.state import DagRunState, State
-from airflow.utils.types import DagRunType, EdgeInfoType
+from airflow.utils.state import DagRunState, State, TaskInstanceState
+from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType
 
 if TYPE_CHECKING:
     from airflow.utils.task_group import TaskGroup
@@ -95,11 +97,14 @@ log = logging.getLogger(__name__)
 DEFAULT_VIEW_PRESETS = ['tree', 'graph', 'duration', 'gantt', 'landing_times']
 ORIENTATION_PRESETS = ['LR', 'TB', 'RL', 'BT']
 
-ScheduleIntervalArgNotSet = type("ScheduleIntervalArgNotSet", (), {})
 
 DagStateChangeCallback = Callable[[Context], None]
-ScheduleInterval = Union[str, timedelta, relativedelta]
-ScheduleIntervalArg = Union[ScheduleInterval, None, Type[ScheduleIntervalArgNotSet]]
+ScheduleInterval = Union[None, str, timedelta, relativedelta]
+
+# FIXME: Ideally this should be Union[Literal[NOTSET], ScheduleInterval],
+# but Mypy cannot handle that right now. Track PEP 661 for progress.
+# See also: https://discuss.python.org/t/9126/7
+ScheduleIntervalArg = Union[ArgNotSet, ScheduleInterval]
 
 
 # Backward compatibility: If neither schedule_interval nor timetable is
@@ -145,9 +150,9 @@ def _get_model_data_interval(
     return DataInterval(start, end)
 
 
-def create_timetable(interval: ScheduleIntervalArg, timezone: tzinfo) -> Timetable:
+def create_timetable(interval: ScheduleIntervalArg, timezone: Timezone) -> Timetable:
     """Create a Timetable instance from a ``schedule_interval`` argument."""
-    if interval is ScheduleIntervalArgNotSet:
+    if interval is NOTSET:
         return DeltaDataIntervalTimetable(DEFAULT_SCHEDULE_INTERVAL)
     if interval is None:
         return NullTimetable()
@@ -319,11 +324,13 @@ class DAG(LoggingMixin):
     from a ZIP file or other DAG distribution format.
     """
 
+    parent_dag: Optional["DAG"] = None  # Gets set when DAGs are loaded
+
     def __init__(
         self,
         dag_id: str,
         description: Optional[str] = None,
-        schedule_interval: ScheduleIntervalArg = ScheduleIntervalArgNotSet,
+        schedule_interval: ScheduleIntervalArg = NOTSET,
         timetable: Optional[Timetable] = None,
         start_date: Optional[datetime] = None,
         end_date: Optional[datetime] = None,
@@ -356,15 +363,15 @@ class DAG(LoggingMixin):
         self.user_defined_macros = user_defined_macros
         self.user_defined_filters = user_defined_filters
         self.default_args = copy.deepcopy(default_args or {})
-        self.params = params or {}
+        params = params or {}
 
         # merging potentially conflicting default_args['params'] into params
         if 'params' in self.default_args:
-            self.params.update(self.default_args['params'])
+            params.update(self.default_args['params'])
             del self.default_args['params']
 
         # check self.params and convert them into ParamsDict
-        self.params = ParamsDict(self.params)
+        self.params = ParamsDict(params)
 
         if full_filepath:
             warnings.warn(
@@ -394,15 +401,19 @@ class DAG(LoggingMixin):
         self.task_dict: Dict[str, BaseOperator] = {}
 
         # set timezone from start_date
+        tz = None
         if start_date and start_date.tzinfo:
-            self.timezone = start_date.tzinfo
+            tzinfo = None if start_date.tzinfo else settings.TIMEZONE
+            tz = pendulum.instance(start_date, tz=tzinfo).timezone
         elif 'start_date' in self.default_args and self.default_args['start_date']:
-            if isinstance(self.default_args['start_date'], str):
-                self.default_args['start_date'] = timezone.parse(self.default_args['start_date'])
-            self.timezone = self.default_args['start_date'].tzinfo
+            date = self.default_args['start_date']
+            if not isinstance(date, datetime):
+                date = timezone.parse(date)
+                self.default_args['start_date'] = date
 
-        if not hasattr(self, 'timezone') or not self.timezone:
-            self.timezone = settings.TIMEZONE
+            tzinfo = None if date.tzinfo else settings.TIMEZONE
+            tz = pendulum.instance(date, tz=tzinfo).timezone
+        self.timezone = tz or settings.TIMEZONE
 
         # Apply the timezone we settled on to end_date if it wasn't supplied
         if 'end_date' in self.default_args and self.default_args['end_date']:
@@ -423,10 +434,10 @@ class DAG(LoggingMixin):
         # Calculate the DAG's timetable.
         if timetable is None:
             self.timetable = create_timetable(schedule_interval, self.timezone)
-            if schedule_interval is ScheduleIntervalArgNotSet:
+            if isinstance(schedule_interval, ArgNotSet):
                 schedule_interval = DEFAULT_SCHEDULE_INTERVAL
             self.schedule_interval: ScheduleInterval = schedule_interval
-        elif schedule_interval is ScheduleIntervalArgNotSet:
+        elif isinstance(schedule_interval, ArgNotSet):
             self.timetable = timetable
             self.schedule_interval = self.timetable.summary
         else:
@@ -436,7 +447,6 @@ class DAG(LoggingMixin):
             template_searchpath = [template_searchpath]
         self.template_searchpath = template_searchpath
         self.template_undefined = template_undefined
-        self.parent_dag: Optional[DAG] = None  # Gets set when DAGs are loaded
         self.last_loaded = timezone.utcnow()
         self.safe_dag_id = dag_id.replace('.', '__dot__')
         self.max_active_runs = max_active_runs
@@ -457,7 +467,6 @@ class DAG(LoggingMixin):
                 f'{ORIENTATION_PRESETS}, but get {orientation}'
             )
         self.catchup = catchup
-        self.is_subdag = False  # DagBag.bag_dag() will set this to True if appropriate
 
         self.partial = False
         self.on_success_callback = on_success_callback
@@ -480,7 +489,7 @@ class DAG(LoggingMixin):
 
         self.jinja_environment_kwargs = jinja_environment_kwargs
         self.render_template_as_native_obj = render_template_as_native_obj
-        self.tags = tags
+        self.tags = tags or []
         self._task_group = TaskGroup.create_root(self)
         self.validate_schedule_and_params()
 
@@ -554,21 +563,22 @@ class DAG(LoggingMixin):
 
     def date_range(
         self,
-        start_date: datetime,
+        start_date: pendulum.DateTime,
         num: Optional[int] = None,
-        end_date: Optional[datetime] = timezone.utcnow(),
+        end_date: Optional[datetime] = None,
     ) -> List[datetime]:
         message = "`DAG.date_range()` is deprecated."
         if num is not None:
-            result = utils_date_range(start_date=start_date, num=num)
-        else:
-            message += " Please use `DAG.iter_dagrun_infos_between(..., align=False)` instead."
-            result = [
-                info.logical_date
-                for info in self.iter_dagrun_infos_between(start_date, end_date, align=False)
-            ]
+            warnings.warn(message, category=DeprecationWarning, stacklevel=2)
+            return utils_date_range(start_date=start_date, num=num)
+        message += " Please use `DAG.iter_dagrun_infos_between(..., align=False)` instead."
         warnings.warn(message, category=DeprecationWarning, stacklevel=2)
-        return result
+        if end_date is None:
+            coerced_end_date = timezone.utcnow()
+        else:
+            coerced_end_date = end_date
+        it = self.iter_dagrun_infos_between(start_date, pendulum.instance(coerced_end_date), align=False)
+        return [info.logical_date for info in it]
 
     def is_fixed_time_schedule(self):
         warnings.warn(
@@ -706,6 +716,8 @@ class DAG(LoggingMixin):
         # Never schedule a subdag. It will be scheduled by its parent dag.
         if self.is_subdag:
             return None
+
+        data_interval = None
         if isinstance(last_automated_dagrun, datetime):
             warnings.warn(
                 "Passing a datetime to DAG.next_dagrun_info is deprecated. Use a DataInterval instead.",
@@ -755,17 +767,15 @@ class DAG(LoggingMixin):
         start_dates = [t.start_date for t in self.tasks if t.start_date]
         if self.start_date is not None:
             start_dates.append(self.start_date)
+        earliest = None
         if start_dates:
             earliest = timezone.coerce_datetime(min(start_dates))
-        else:
-            earliest = None
         end_dates = [t.end_date for t in self.tasks if t.end_date]
         if self.end_date is not None:
             end_dates.append(self.end_date)
+        latest = None
         if end_dates:
             latest = timezone.coerce_datetime(max(end_dates))
-        else:
-            latest = None
         return TimeRestriction(earliest, latest, self.catchup)
 
     def iter_dagrun_infos_between(
@@ -793,6 +803,8 @@ class DAG(LoggingMixin):
         """
         if earliest is None:
             earliest = self._time_restriction.earliest
+        if earliest is None:
+            raise ValueError("earliest was None and we had no value in time_restriction to fallback on")
         earliest = timezone.coerce_datetime(earliest)
         latest = timezone.coerce_datetime(latest)
 
@@ -843,7 +855,7 @@ class DAG(LoggingMixin):
             except Exception:
                 self.log.exception(
                     "Failed to fetch run info after data interval %s for DAG %r",
-                    info.data_interval,
+                    info.data_interval if info else "<NONE>",
                     self.dag_id,
                 )
                 break
@@ -891,13 +903,13 @@ class DAG(LoggingMixin):
         return dttm
 
     @provide_session
-    def get_last_dagrun(self, session=None, include_externally_triggered=False):
+    def get_last_dagrun(self, session=NEW_SESSION, include_externally_triggered=False):
         return get_last_dagrun(
             self.dag_id, session=session, include_externally_triggered=include_externally_triggered
         )
 
     @provide_session
-    def has_dag_runs(self, session=None, include_externally_triggered=True) -> bool:
+    def has_dag_runs(self, session=NEW_SESSION, include_externally_triggered=True) -> bool:
         return (
             get_last_dagrun(
                 self.dag_id, session=session, include_externally_triggered=include_externally_triggered
@@ -914,6 +926,10 @@ class DAG(LoggingMixin):
         self._dag_id = value
 
     @property
+    def is_subdag(self) -> bool:
+        return self.parent_dag is not None
+
+    @property
     def full_filepath(self) -> str:
         """:meta private:"""
         warnings.warn(
@@ -1042,7 +1058,7 @@ class DAG(LoggingMixin):
         return settings.ALLOW_FUTURE_EXEC_DATES and not self.timetable.can_run
 
     @provide_session
-    def get_concurrency_reached(self, session=None) -> bool:
+    def get_concurrency_reached(self, session=NEW_SESSION) -> bool:
         """
         Returns a boolean indicating whether the max_active_tasks limit for this DAG
         has been reached
@@ -1065,13 +1081,13 @@ class DAG(LoggingMixin):
         return self.get_concurrency_reached()
 
     @provide_session
-    def get_is_active(self, session=None) -> Optional[None]:
+    def get_is_active(self, session=NEW_SESSION) -> Optional[None]:
         """Returns a boolean indicating whether this DAG is active"""
         qry = session.query(DagModel).filter(DagModel.dag_id == self.dag_id)
         return qry.value(DagModel.is_active)
 
     @provide_session
-    def get_is_paused(self, session=None) -> Optional[None]:
+    def get_is_paused(self, session=NEW_SESSION) -> Optional[None]:
         """Returns a boolean indicating whether this DAG is paused"""
         qry = session.query(DagModel).filter(DagModel.dag_id == self.dag_id)
         return qry.value(DagModel.is_paused)
@@ -1087,14 +1103,14 @@ class DAG(LoggingMixin):
         return self.get_is_paused()
 
     @property
-    def normalized_schedule_interval(self) -> Optional[ScheduleInterval]:
+    def normalized_schedule_interval(self) -> ScheduleInterval:
         warnings.warn(
             "DAG.normalized_schedule_interval() is deprecated.",
             category=DeprecationWarning,
             stacklevel=2,
         )
         if isinstance(self.schedule_interval, str) and self.schedule_interval in cron_presets:
-            _schedule_interval = cron_presets.get(self.schedule_interval)  # type: Optional[ScheduleInterval]
+            _schedule_interval: ScheduleInterval = cron_presets.get(self.schedule_interval)
         elif self.schedule_interval == '@once':
             _schedule_interval = None
         else:
@@ -1102,7 +1118,7 @@ class DAG(LoggingMixin):
         return _schedule_interval
 
     @provide_session
-    def handle_callback(self, dagrun, success=True, reason=None, session=None):
+    def handle_callback(self, dagrun, success=True, reason=None, session=NEW_SESSION):
         """
         Triggers the appropriate callback depending on the value of success, namely the
         on_failure_callback or on_success_callback. This method gets the context of a
@@ -1146,7 +1162,7 @@ class DAG(LoggingMixin):
         return active_dates
 
     @provide_session
-    def get_num_active_runs(self, external_trigger=None, only_running=True, session=None):
+    def get_num_active_runs(self, external_trigger=None, only_running=True, session=NEW_SESSION):
         """
         Returns the number of active "running" dag runs
 
@@ -1174,7 +1190,7 @@ class DAG(LoggingMixin):
         self,
         execution_date: Optional[str] = None,
         run_id: Optional[str] = None,
-        session: Optional[Session] = None,
+        session: Session = NEW_SESSION,
     ):
         """
         Returns the dag run for a given execution date or run_id if it exists, otherwise
@@ -1195,7 +1211,7 @@ class DAG(LoggingMixin):
         return query.first()
 
     @provide_session
-    def get_dagruns_between(self, start_date, end_date, session=None):
+    def get_dagruns_between(self, start_date, end_date, session=NEW_SESSION):
         """
         Returns the list of dag runs between start_date (inclusive) and end_date (inclusive).
 
@@ -1223,9 +1239,9 @@ class DAG(LoggingMixin):
 
     @property
     def latest_execution_date(self):
-        """This attribute is deprecated. Please use `airflow.models.DAG.get_latest_execution_date` method."""
+        """This attribute is deprecated. Please use `airflow.models.DAG.get_latest_execution_date`."""
         warnings.warn(
-            "This attribute is deprecated. Please use `airflow.models.DAG.get_latest_execution_date` method.",
+            "This attribute is deprecated. Please use `airflow.models.DAG.get_latest_execution_date`.",
             DeprecationWarning,
             stacklevel=2,
         )
@@ -1297,7 +1313,7 @@ class DAG(LoggingMixin):
         base_date: datetime,
         num: int,
         *,
-        session: Session,
+        session: Session = NEW_SESSION,
     ) -> List[TaskInstance]:
         """Get ``num`` task instances before (including) ``base_date``.
 
@@ -1324,25 +1340,35 @@ class DAG(LoggingMixin):
 
     @provide_session
     def get_task_instances(
-        self, start_date=None, end_date=None, state=None, session=None
+        self,
+        start_date: Optional[datetime] = None,
+        end_date: Optional[datetime] = None,
+        state: Optional[List[TaskInstanceState]] = None,
+        session: Session = NEW_SESSION,
     ) -> List[TaskInstance]:
         if not start_date:
-            start_date = (timezone.utcnow() - timedelta(30)).date()
-            start_date = timezone.make_aware(datetime.combine(start_date, datetime.min.time()))
+            start_date = (timezone.utcnow() - timedelta(30)).replace(
+                hour=0, minute=0, second=0, microsecond=0
+            )
+
+        if state is None:
+            state = []
 
         return (
-            self._get_task_instances(
-                task_ids=None,
-                start_date=start_date,
-                end_date=end_date,
-                run_id=None,
-                state=state,
-                include_subdags=False,
-                include_parentdag=False,
-                include_dependent_dags=False,
-                exclude_task_ids=[],
-                as_pk_tuple=False,
-                session=session,
+            cast(
+                Query,
+                self._get_task_instances(
+                    task_ids=None,
+                    start_date=start_date,
+                    end_date=end_date,
+                    run_id=None,
+                    state=state,
+                    include_subdags=False,
+                    include_parentdag=False,
+                    include_dependent_dags=False,
+                    exclude_task_ids=cast(List[str], []),
+                    session=session,
+                ),
             )
             .join(TaskInstance.dag_run)
             .order_by(DagRun.execution_date)
@@ -1356,19 +1382,15 @@ class DAG(LoggingMixin):
         task_ids,
         start_date: Optional[datetime],
         end_date: Optional[datetime],
-        run_id: None,
-        state: Union[str, List[str]],
+        run_id: Optional[str],
+        state: Union[TaskInstanceState, List[TaskInstanceState]],
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
         exclude_task_ids: Collection[str],
-        as_pk_tuple: Literal[True],
         session: Session,
-        dag_bag: "DagBag" = None,
-        recursion_depth: int = 0,
-        max_recursion_depth: int = None,
-        visited_external_tis: Set[Tuple[str, str, datetime]] = None,
-    ) -> Set["TaskInstanceKey"]:
+        dag_bag: Optional["DagBag"] = ...,
+    ) -> Iterable[TaskInstance]:
         ...  # pragma: no cover
 
     @overload
@@ -1376,41 +1398,41 @@ class DAG(LoggingMixin):
         self,
         *,
         task_ids,
+        as_pk_tuple: Literal[True],
         start_date: Optional[datetime],
         end_date: Optional[datetime],
         run_id: Optional[str],
-        state: Union[str, List[str]],
+        state: Union[TaskInstanceState, List[TaskInstanceState]],
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
-        as_pk_tuple: Literal[False],
         exclude_task_ids: Collection[str],
         session: Session,
-        dag_bag: "DagBag" = None,
-        recursion_depth: int = 0,
-        max_recursion_depth: int = None,
-        visited_external_tis: Set[Tuple[str, str, datetime]] = None,
-    ) -> Iterable[TaskInstance]:
+        dag_bag: Optional["DagBag"] = ...,
+        recursion_depth: int = ...,
+        max_recursion_depth: int = ...,
+        visited_external_tis: Set[TaskInstanceKey] = ...,
+    ) -> Set["TaskInstanceKey"]:
         ...  # pragma: no cover
 
     def _get_task_instances(
         self,
         *,
         task_ids,
+        as_pk_tuple: Literal[True, None] = None,
         start_date: Optional[datetime],
         end_date: Optional[datetime],
         run_id: Optional[str],
-        state: Union[str, List[str]],
+        state: Union[TaskInstanceState, List[TaskInstanceState]],
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
-        as_pk_tuple: bool,
         exclude_task_ids: Collection[str],
         session: Session,
-        dag_bag: "DagBag" = None,
+        dag_bag: Optional["DagBag"] = None,
         recursion_depth: int = 0,
-        max_recursion_depth: int = None,
-        visited_external_tis: Set[Tuple[str, str, datetime]] = None,
+        max_recursion_depth: Optional[int] = None,
+        visited_external_tis: Optional[Set[TaskInstanceKey]] = None,
     ) -> Union[Iterable[TaskInstance], Set[TaskInstanceKey]]:
         TI = TaskInstance
 
@@ -1452,7 +1474,7 @@ class DAG(LoggingMixin):
             tis = tis.filter(DagRun.execution_date <= end_date)
 
         if state:
-            if isinstance(state, str):
+            if isinstance(state, (str, TaskInstanceState)):
                 tis = tis.filter(TaskInstance.state == state)
             elif len(state) == 1:
                 tis = tis.filter(TaskInstance.state == state[0])
@@ -1470,7 +1492,11 @@ class DAG(LoggingMixin):
                     tis = tis.filter(TaskInstance.state.in_(state))
 
         # Next, get any of them from our parent DAG (if there is one)
-        if include_parentdag and self.is_subdag and self.parent_dag is not None:
+        if include_parentdag and self.parent_dag is not None:
+
+            if visited_external_tis is None:
+                visited_external_tis = set()
+
             p_dag = self.parent_dag.partial_subset(
                 task_ids_or_regex=r"^{}$".format(self.dag_id.split('.')[1]),
                 include_upstream=False,
@@ -1611,7 +1637,7 @@ class DAG(LoggingMixin):
         future: Optional[bool] = False,
         past: Optional[bool] = False,
         commit: Optional[bool] = True,
-        session=None,
+        session=NEW_SESSION,
     ) -> List[TaskInstance]:
         """
         Set the state of a TaskInstance to the given state, and clear its downstream tasks that are
@@ -1747,10 +1773,10 @@ class DAG(LoggingMixin):
     def set_dag_runs_state(
         self,
         state: str = State.RUNNING,
-        session: Session = None,
+        session: Session = NEW_SESSION,
         start_date: Optional[datetime] = None,
         end_date: Optional[datetime] = None,
-        dag_ids: List[str] = None,
+        dag_ids: List[str] = [],
     ) -> None:
         warnings.warn(
             "This method is deprecated and will be removed in a future version.",
@@ -1769,22 +1795,22 @@ class DAG(LoggingMixin):
     def clear(
         self,
         task_ids=None,
-        start_date=None,
-        end_date=None,
-        only_failed=False,
-        only_running=False,
-        confirm_prompt=False,
-        include_subdags=True,
-        include_parentdag=True,
+        start_date: Optional[datetime] = None,
+        end_date: Optional[datetime] = None,
+        only_failed: bool = False,
+        only_running: bool = False,
+        confirm_prompt: bool = False,
+        include_subdags: bool = True,
+        include_parentdag: bool = True,
         dag_run_state: DagRunState = DagRunState.QUEUED,
-        dry_run=False,
-        session=None,
-        get_tis=False,
-        recursion_depth=0,
-        max_recursion_depth=None,
-        dag_bag=None,
+        dry_run: bool = False,
+        session: Session = NEW_SESSION,
+        get_tis: bool = False,
+        recursion_depth: int = 0,
+        max_recursion_depth: Optional[int] = None,
+        dag_bag: Optional["DagBag"] = None,
         exclude_task_ids: FrozenSet[str] = frozenset({}),
-    ):
+    ) -> Union[int, Iterable[TaskInstance]]:
         """
         Clears a set of task instances associated with the current dag for
         a specified date range.
@@ -1841,11 +1867,9 @@ class DAG(LoggingMixin):
         state = []
         if only_failed:
             state += [State.FAILED, State.UPSTREAM_FAILED]
-            only_failed = None
         if only_running:
             # Yes, having `+=` doesn't make sense, but this was the existing behaviour
             state += [State.RUNNING]
-            only_running = None
 
         tis = self._get_task_instances(
             task_ids=task_ids,
@@ -1856,7 +1880,6 @@ class DAG(LoggingMixin):
             include_subdags=include_subdags,
             include_parentdag=include_parentdag,
             include_dependent_dags=include_subdags,  # compat, yes this is not a typo
-            as_pk_tuple=False,
             session=session,
             dag_bag=dag_bag,
             exclude_task_ids=exclude_task_ids,
@@ -1865,7 +1888,7 @@ class DAG(LoggingMixin):
         if dry_run:
             return tis
 
-        tis = tis.all()
+        tis = list(tis)
 
         count = len(tis)
         do_it = True
@@ -2095,7 +2118,7 @@ class DAG(LoggingMixin):
         return d
 
     @provide_session
-    def pickle(self, session=None) -> DagPickle:
+    def pickle(self, session=NEW_SESSION) -> DagPickle:
         dag = session.query(DagModel).filter(DagModel.dag_id == self.dag_id).first()
         dp = None
         if dag and dag.pickle_id:
@@ -2278,7 +2301,7 @@ class DAG(LoggingMixin):
         external_trigger: Optional[bool] = False,
         conf: Optional[dict] = None,
         run_type: Optional[DagRunType] = None,
-        session=None,
+        session=NEW_SESSION,
         dag_hash: Optional[str] = None,
         creating_job_id: Optional[int] = None,
         data_interval: Optional[Tuple[datetime, datetime]] = None,
@@ -2367,7 +2390,7 @@ class DAG(LoggingMixin):
 
     @classmethod
     @provide_session
-    def bulk_sync_to_db(cls, dags: Collection["DAG"], session=None):
+    def bulk_sync_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION):
         """This method is deprecated in favor of bulk_write_to_db"""
         warnings.warn(
             "This method is deprecated and will be removed in a future version. Please use bulk_write_to_db",
@@ -2378,7 +2401,7 @@ class DAG(LoggingMixin):
 
     @classmethod
     @provide_session
-    def bulk_write_to_db(cls, dags: Collection["DAG"], session=None):
+    def bulk_write_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION):
         """
         Ensure the DagModel rows for the given dags are up-to-date in the dag table in the DB, including
         calculated fields.
@@ -2491,7 +2514,7 @@ class DAG(LoggingMixin):
             cls.bulk_write_to_db(dag.subdags, session=session)
 
     @provide_session
-    def sync_to_db(self, session=None):
+    def sync_to_db(self, session=NEW_SESSION):
         """
         Save attributes about this DAG to the DB. Note that this method
         can be called for both DAGs and SubDAGs. A SubDag is actually a
@@ -2510,7 +2533,7 @@ class DAG(LoggingMixin):
 
     @staticmethod
     @provide_session
-    def deactivate_unknown_dags(active_dag_ids, session=None):
+    def deactivate_unknown_dags(active_dag_ids, session=NEW_SESSION):
         """
         Given a list of known DAGs, deactivate any other DAGs that are
         marked as active in the ORM
@@ -2528,7 +2551,7 @@ class DAG(LoggingMixin):
 
     @staticmethod
     @provide_session
-    def deactivate_stale_dags(expiration_date, session=None):
+    def deactivate_stale_dags(expiration_date, session=NEW_SESSION):
         """
         Deactivate any DAGs that were last touched by the scheduler before
         the expiration date. These DAGs were likely deleted.
@@ -2554,7 +2577,7 @@ class DAG(LoggingMixin):
 
     @staticmethod
     @provide_session
-    def get_num_task_instances(dag_id, task_ids=None, states=None, session=None):
+    def get_num_task_instances(dag_id, task_ids=None, states=None, session=NEW_SESSION):
         """
         Returns the number of task instances in the given DAG.
 
@@ -2604,7 +2627,6 @@ class DAG(LoggingMixin):
                 'params',
                 '_pickle_id',
                 '_log',
-                'is_subdag',
                 'task_dict',
                 'template_searchpath',
                 'sla_miss_callback',
@@ -2621,13 +2643,14 @@ class DAG(LoggingMixin):
     def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType:
         """
         Returns edge information for the given pair of tasks if present, and
-        None if there is no information.
+        an empty edge if there is no information.
         """
         # Note - older serialized DAGs may not have edge_info being a dict at all
+        empty = cast(EdgeInfoType, {})
         if self.edge_info:
-            return self.edge_info.get(upstream_task_id, {}).get(downstream_task_id, {})
+            return self.edge_info.get(upstream_task_id, {}).get(downstream_task_id, empty)
         else:
-            return {}
+            return empty
 
     def set_edge_info(self, upstream_task_id: str, downstream_task_id: str, info: EdgeInfoType):
         """
@@ -2778,23 +2801,23 @@ class DagModel(Base):
 
     @staticmethod
     @provide_session
-    def get_dagmodel(dag_id, session=None):
+    def get_dagmodel(dag_id, session=NEW_SESSION):
         return session.query(DagModel).options(joinedload(DagModel.parent_dag)).get(dag_id)
 
     @classmethod
     @provide_session
-    def get_current(cls, dag_id, session=None):
+    def get_current(cls, dag_id, session=NEW_SESSION):
         return session.query(cls).filter(cls.dag_id == dag_id).first()
 
     @provide_session
-    def get_last_dagrun(self, session=None, include_externally_triggered=False):
+    def get_last_dagrun(self, session=NEW_SESSION, include_externally_triggered=False):
         return get_last_dagrun(
             self.dag_id, session=session, include_externally_triggered=include_externally_triggered
         )
 
     @staticmethod
     @provide_session
-    def get_paused_dag_ids(dag_ids: List[str], session: Session = None) -> Set[str]:
+    def get_paused_dag_ids(dag_ids: List[str], session: Session = NEW_SESSION) -> Set[str]:
         """
         Given a list of dag_ids, get a set of Paused Dag Ids
 
@@ -2837,7 +2860,7 @@ class DagModel(Base):
             return path
 
     @provide_session
-    def set_is_paused(self, is_paused: bool, including_subdags: bool = True, session=None) -> None:
+    def set_is_paused(self, is_paused: bool, including_subdags: bool = True, session=NEW_SESSION) -> None:
         """
         Pause/Un-pause a DAG.
 
@@ -2857,7 +2880,7 @@ class DagModel(Base):
 
     @classmethod
     @provide_session
-    def deactivate_deleted_dags(cls, alive_dag_filelocs: List[str], session=None):
+    def deactivate_deleted_dags(cls, alive_dag_filelocs: List[str], session=NEW_SESSION):
         """
         Set ``is_active=False`` on the DAGs for which the DAG files have been removed.
 
@@ -2913,6 +2936,7 @@ class DagModel(Base):
         :param most_recent_dag_run: DataInterval (or datetime) of most recent run of this dag, or none
             if not yet scheduled.
         """
+        most_recent_data_interval: Optional[DataInterval]
         if isinstance(most_recent_dag_run, datetime):
             warnings.warn(
                 "Passing a datetime to `DagModel.calculate_dagrun_date_fields` is deprecated. "
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index 9f5a135..2c737e2 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -218,8 +218,8 @@ class DagBag(LoggingMixin):
         root_dag_id = dag_id
         if dag_id in self.dags:
             dag = self.dags[dag_id]
-            if dag.is_subdag:
-                root_dag_id = dag.parent_dag.dag_id  # type: ignore
+            if dag.parent_dag:
+                root_dag_id = dag.parent_dag.dag_id
 
         # If DAG Model is absent, we can't check last_expired property. Is the DAG not yet synchronized?
         orm_dag = DagModel.get_current(root_dag_id, session=session)
@@ -234,7 +234,7 @@ class DagBag(LoggingMixin):
             self.dags = {
                 key: dag
                 for key, dag in self.dags.items()
-                if root_dag_id != key and not (dag.is_subdag and root_dag_id == dag.parent_dag.dag_id)
+                if root_dag_id != key and not (dag.parent_dag and root_dag_id == dag.parent_dag.dag_id)
             }
         if is_missing or is_expired:
             # Reprocess source file.
@@ -397,7 +397,6 @@ class DagBag(LoggingMixin):
         for (dag, mod) in top_level_dags:
             dag.fileloc = mod.__file__
             try:
-                dag.is_subdag = False
                 dag.timetable.validate()
                 self.bag_dag(dag=dag, root_dag=dag)
                 found_dags.append(dag)
@@ -451,7 +450,6 @@ class DagBag(LoggingMixin):
                 for subdag in subdags:
                     subdag.fileloc = dag.fileloc
                     subdag.parent_dag = dag
-                    subdag.is_subdag = True
                     self._bag_dag(dag=subdag, root_dag=root_dag, recursive=False)
 
             prev_dag = self.dags.get(dag.dag_id)
@@ -572,7 +570,7 @@ class DagBag(LoggingMixin):
         return report
 
     @provide_session
-    def sync_to_db(self, session: Optional[Session] = None):
+    def sync_to_db(self, session: Session = None):
         """Save attributes about list of DAG to the DB."""
         # To avoid circular import - airflow.models.dagbag -> airflow.models.dag -> airflow.models.dagbag
         from airflow.models.dag import DAG
@@ -628,7 +626,7 @@ class DagBag(LoggingMixin):
                 self.import_errors.update(dict(serialize_errors))
 
     @provide_session
-    def _sync_perm_for_dag(self, dag, session: Optional[Session] = None):
+    def _sync_perm_for_dag(self, dag, session: Session = None):
         """Sync DAG specific permissions, if necessary"""
         from airflow.security.permissions import DAG_ACTIONS, resource_name_for_dag
         from airflow.www.fab_security.sqla.models import Action, Permission, Resource
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 20ec7cd..f9ced3a 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -15,6 +15,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import os
 import warnings
 from datetime import datetime
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
@@ -48,10 +49,10 @@ from airflow.ti_deps.dep_context import DepContext
 from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES
 from airflow.utils import callback_requests, timezone
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, skip_locked, with_row_locks
 from airflow.utils.state import DagRunState, State, TaskInstanceState
-from airflow.utils.types import DagRunType
+from airflow.utils.types import NOTSET, ArgNotSet, DagRunType
 
 if TYPE_CHECKING:
     from airflow.models.dag import DAG
@@ -75,8 +76,6 @@ class DagRun(Base, LoggingMixin):
 
     __tablename__ = "dag_run"
 
-    __NO_VALUE = object()
-
     id = Column(Integer, primary_key=True)
     dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
     queued_at = Column(UtcDateTime)
@@ -96,7 +95,11 @@ class DagRun(Base, LoggingMixin):
     last_scheduling_decision = Column(UtcDateTime)
     dag_hash = Column(String(32))
 
-    dag = None
+    # Remove this `if` after upgrading Sphinx-AutoAPI
+    if not TYPE_CHECKING and "BUILDING_AIRFLOW_DOCS" in os.environ:
+        dag: "Optional[DAG]"
+    else:
+        dag: "Optional[DAG]" = None
 
     __table_args__ = (
         Index('dag_id_state', dag_id, _state),
@@ -138,7 +141,7 @@ class DagRun(Base, LoggingMixin):
         self,
         dag_id: Optional[str] = None,
         run_id: Optional[str] = None,
-        queued_at: Optional[datetime] = __NO_VALUE,
+        queued_at: Union[datetime, None, ArgNotSet] = NOTSET,  # type: ignore
         execution_date: Optional[datetime] = None,
         start_date: Optional[datetime] = None,
         external_trigger: Optional[bool] = None,
@@ -163,7 +166,7 @@ class DagRun(Base, LoggingMixin):
         self.conf = conf or {}
         if state is not None:
             self.state = state
-        if queued_at is self.__NO_VALUE:
+        if queued_at is NOTSET:
             self.queued_at = timezone.utcnow() if state == State.QUEUED else None
         else:
             self.queued_at = queued_at
@@ -203,7 +206,7 @@ class DagRun(Base, LoggingMixin):
         return synonym('_state', descriptor=property(self.get_state, self.set_state))
 
     @provide_session
-    def refresh_from_db(self, session: Session = None):
+    def refresh_from_db(self, session: Session = NEW_SESSION) -> None:
         """
         Reloads the current dagrun from the database
 
@@ -299,7 +302,7 @@ class DagRun(Base, LoggingMixin):
         external_trigger: Optional[bool] = None,
         no_backfills: bool = False,
         run_type: Optional[DagRunType] = None,
-        session: Session = None,
+        session: Session = NEW_SESSION,
         execution_start_date: Optional[datetime] = None,
         execution_end_date: Optional[datetime] = None,
     ) -> List["DagRun"]:
@@ -363,7 +366,7 @@ class DagRun(Base, LoggingMixin):
         dag_id: str,
         run_id: str,
         execution_date: datetime,
-        session: Session = None,
+        session: Session = NEW_SESSION,
     ) -> Optional['DagRun']:
         """
         Return an existing run for the DAG with a specific run_id or execution_date.
@@ -412,7 +415,7 @@ class DagRun(Base, LoggingMixin):
                 tis = tis.filter(TI.state == state)
             else:
                 # this is required to deal with NULL values
-                if None in state:
+                if TaskInstanceState.NONE in state:
                     if all(x is None for x in state):
                         tis = tis.filter(TI.state.is_(None))
                     else:
@@ -426,7 +429,7 @@ class DagRun(Base, LoggingMixin):
         return tis.all()
 
     @provide_session
-    def get_task_instance(self, task_id: str, session: Session = None) -> Optional[TI]:
+    def get_task_instance(self, task_id: str, session: Session = NEW_SESSION) -> Optional[TI]:
         """
         Returns the task instance specified by task_id for this dag run
 
@@ -454,7 +457,7 @@ class DagRun(Base, LoggingMixin):
 
     @provide_session
     def get_previous_dagrun(
-        self, state: Optional[DagRunState] = None, session: Session = None
+        self, state: Optional[DagRunState] = None, session: Session = NEW_SESSION
     ) -> Optional['DagRun']:
         """The previous DagRun, if there is one"""
         filters = [
@@ -466,7 +469,7 @@ class DagRun(Base, LoggingMixin):
         return session.query(DagRun).filter(*filters).order_by(DagRun.execution_date.desc()).first()
 
     @provide_session
-    def get_previous_scheduled_dagrun(self, session: Session = None) -> Optional['DagRun']:
+    def get_previous_scheduled_dagrun(self, session: Session = NEW_SESSION) -> Optional['DagRun']:
         """The previous, SCHEDULED DagRun, if there is one"""
         return (
             session.query(DagRun)
@@ -481,7 +484,7 @@ class DagRun(Base, LoggingMixin):
 
     @provide_session
     def update_state(
-        self, session: Session = None, execute_callbacks: bool = True
+        self, session: Session = NEW_SESSION, execute_callbacks: bool = True
     ) -> Tuple[List[TI], Optional[callback_requests.DagCallbackRequest]]:
         """
         Determines the overall state of the DagRun based on the state
@@ -528,7 +531,7 @@ class DagRun(Base, LoggingMixin):
         # if all roots finished and at least one failed, the run failed
         if not unfinished_tasks and any(leaf_ti.state in State.failed_states for leaf_ti in leaf_tis):
             self.log.error('Marking run %s failed', self)
-            self.set_state(State.FAILED)
+            self.set_state(DagRunState.FAILED)
             if execute_callbacks:
                 dag.handle_callback(self, success=False, reason='task_failure', session=session)
             elif dag.has_on_failure_callback:
@@ -543,7 +546,7 @@ class DagRun(Base, LoggingMixin):
         # if all leaves succeeded and no unfinished tasks, the run succeeded
         elif not unfinished_tasks and all(leaf_ti.state in State.success_states for leaf_ti in leaf_tis):
             self.log.info('Marking run %s successful', self)
-            self.set_state(State.SUCCESS)
+            self.set_state(DagRunState.SUCCESS)
             if execute_callbacks:
                 dag.handle_callback(self, success=True, reason='success', session=session)
             elif dag.has_on_success_callback:
@@ -564,7 +567,7 @@ class DagRun(Base, LoggingMixin):
             and not are_runnable_tasks
         ):
             self.log.error('Deadlock; marking run %s failed', self)
-            self.set_state(State.FAILED)
+            self.set_state(DagRunState.FAILED)
             if execute_callbacks:
                 dag.handle_callback(self, success=False, reason='all_tasks_deadlocked', session=session)
             elif dag.has_on_failure_callback:
@@ -578,9 +581,9 @@ class DagRun(Base, LoggingMixin):
 
         # finally, if the roots aren't done, the dag is still running
         else:
-            self.set_state(State.RUNNING)
+            self.set_state(DagRunState.RUNNING)
 
-        if self._state == State.FAILED or self._state == State.SUCCESS:
+        if self._state == DagRunState.FAILED or self._state == DagRunState.SUCCESS:
             msg = (
                 "DagRun Finished: dag_id=%s, execution_date=%s, run_id=%s, "
                 "run_start_date=%s, run_end_date=%s, run_duration=%s, "
@@ -613,7 +616,7 @@ class DagRun(Base, LoggingMixin):
         return schedulable_tis, callback
 
     @provide_session
-    def task_instance_scheduling_decisions(self, session: Session = None) -> TISchedulingDecision:
+    def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) -> TISchedulingDecision:
 
         schedulable_tis: List[TI] = []
         changed_tis = False
@@ -759,7 +762,7 @@ class DagRun(Base, LoggingMixin):
             Stats.timing(f'dagrun.duration.failed.{self.dag_id}', duration)
 
     @provide_session
-    def verify_integrity(self, session: Session = None):
+    def verify_integrity(self, session: Session = NEW_SESSION):
         """
         Verifies the DagRun by checking for removed tasks or tasks that are not in the
         database yet. It will set state to removed or add the task if required.
@@ -869,7 +872,7 @@ class DagRun(Base, LoggingMixin):
         )
 
     @provide_session
-    def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = None) -> int:
+    def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = NEW_SESSION) -> int:
         """
         Set the given task instances in to the scheduled state.
 
diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py
index a9e359e..0a68587 100644
--- a/airflow/models/serialized_dag.py
+++ b/airflow/models/serialized_dag.py
@@ -73,7 +73,7 @@ class SerializedDagModel(Base):
 
     dag_runs = relationship(
         DagRun,
-        primaryjoin=dag_id == foreign(DagRun.dag_id),
+        primaryjoin=dag_id == foreign(DagRun.dag_id),  # type: ignore
         backref=backref('serialized_dag', uselist=False, innerjoin=True),
     )
 
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 2638739..8d7bd36 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -27,7 +27,20 @@ from collections import defaultdict
 from datetime import datetime, timedelta
 from functools import partial
 from tempfile import NamedTemporaryFile
-from typing import IO, TYPE_CHECKING, Any, Iterable, List, NamedTuple, Optional, Tuple, Union
+from typing import (
+    IO,
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    NamedTuple,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
 from urllib.parse import quote
 
 import dill
@@ -86,13 +99,12 @@ from airflow.typing_compat import Literal
 from airflow.utils import timezone
 from airflow.utils.context import ConnectionAccessor, Context, VariableAccessor
 from airflow.utils.email import send_email
-from airflow.utils.helpers import is_container
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.net import get_hostname
 from airflow.utils.operator_helpers import context_to_airflow_vars
 from airflow.utils.platform import getuser
 from airflow.utils.retries import run_with_db_retries
-from airflow.utils.session import create_session, provide_session
+from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime
 from airflow.utils.state import DagRunState, State
 from airflow.utils.timeout import timeout
@@ -117,7 +129,7 @@ if TYPE_CHECKING:
 
 
 @contextlib.contextmanager
-def set_current_context(context: Context) -> None:
+def set_current_context(context: Context) -> Iterator[Context]:
     """
     Sets the current execution context to the provided context object.
     This method should be called once per Task execution, before calling operator.execute.
@@ -179,7 +191,9 @@ def clear_task_instances(
     :param activate_dag_runs: Deprecated parameter, do not pass
     """
     job_ids = []
-    task_id_by_key = defaultdict(lambda: defaultdict(lambda: defaultdict(set)))
+    task_id_by_key: Dict[str, Dict[str, Dict[int, Set[str]]]] = defaultdict(
+        lambda: defaultdict(lambda: defaultdict(set))
+    )
     for ti in tis:
         if ti.state == State.RUNNING:
             if ti.job_id:
@@ -404,7 +418,11 @@ class TaskInstance(Base, LoggingMixin):
     execution_date = association_proxy("dag_run", "execution_date")
 
     def __init__(
-        self, task, execution_date: Optional[datetime] = None, run_id: str = None, state: Optional[str] = None
+        self,
+        task: "BaseOperator",
+        execution_date: Optional[datetime] = None,
+        run_id: Optional[str] = None,
+        state: Optional[str] = None,
     ):
         super().__init__()
         self.dag_id = task.dag_id
@@ -562,7 +580,7 @@ class TaskInstance(Base, LoggingMixin):
     def generate_command(
         dag_id: str,
         task_id: str,
-        run_id: str = None,
+        run_id: str,
         mark_success: bool = False,
         ignore_all_deps: bool = False,
         ignore_depends_on_past: bool = False,
@@ -666,7 +684,7 @@ class TaskInstance(Base, LoggingMixin):
         )
 
     @provide_session
-    def current_state(self, session=None) -> str:
+    def current_state(self, session=NEW_SESSION) -> str:
         """
         Get the very latest state from the database, if a session is passed,
         we use and looking up the state becomes part of the session, otherwise
@@ -691,7 +709,7 @@ class TaskInstance(Base, LoggingMixin):
         return state
 
     @provide_session
-    def error(self, session=None):
+    def error(self, session=NEW_SESSION):
         """
         Forces the task instance's state to FAILED in the database.
 
@@ -704,7 +722,7 @@ class TaskInstance(Base, LoggingMixin):
         session.commit()
 
     @provide_session
-    def refresh_from_db(self, session=None, lock_for_update=False) -> None:
+    def refresh_from_db(self, session=NEW_SESSION, lock_for_update=False) -> None:
         """
         Refreshes the task instance from the database based on the primary key
 
@@ -760,7 +778,7 @@ class TaskInstance(Base, LoggingMixin):
 
         self.log.debug("Refreshed TaskInstance %s", self)
 
-    def refresh_from_task(self, task, pool_override=None):
+    def refresh_from_task(self, task: "BaseOperator", pool_override=None):
         """
         Copy common attributes from the given task.
 
@@ -780,7 +798,7 @@ class TaskInstance(Base, LoggingMixin):
         self.operator = task.task_type
 
     @provide_session
-    def clear_xcom_data(self, session=None):
+    def clear_xcom_data(self, session=NEW_SESSION):
         """
         Clears all XCom data from the database for the task instance
 
@@ -802,7 +820,7 @@ class TaskInstance(Base, LoggingMixin):
         return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number)
 
     @provide_session
-    def set_state(self, state: str, session=None):
+    def set_state(self, state: str, session=NEW_SESSION):
         """
         Set TaskInstance state.
 
@@ -830,7 +848,7 @@ class TaskInstance(Base, LoggingMixin):
         return self.state == State.UP_FOR_RETRY and not self.ready_for_retry()
 
     @provide_session
-    def are_dependents_done(self, session=None):
+    def are_dependents_done(self, session=NEW_SESSION):
         """
         Checks whether the immediate dependents of this task instance have succeeded or have been skipped.
         This is meant to be used by wait_for_downstream.
@@ -880,7 +898,7 @@ class TaskInstance(Base, LoggingMixin):
                 # XXX: This uses DAG internals, but as the outer comment
                 # said, the block is only reached for legacy reasons for
                 # development code, so that's OK-ish.
-                schedule = dag.timetable._schedule
+                schedule = dag.timetable._schedule  # type: ignore
             except AttributeError:
                 return None
             dt = pendulum.instance(self.execution_date)
@@ -908,7 +926,7 @@ class TaskInstance(Base, LoggingMixin):
 
     @provide_session
     def get_previous_ti(
-        self, state: Optional[str] = None, session: Session = None
+        self, state: Optional[str] = None, session: Session = NEW_SESSION
     ) -> Optional['TaskInstance']:
         """
         The task instance for the task that ran before this task instance.
@@ -957,7 +975,7 @@ class TaskInstance(Base, LoggingMixin):
     def get_previous_execution_date(
         self,
         state: Optional[str] = None,
-        session: Session = None,
+        session: Session = NEW_SESSION,
     ) -> Optional[pendulum.DateTime]:
         """
         The execution date from property previous_ti_success.
@@ -971,7 +989,7 @@ class TaskInstance(Base, LoggingMixin):
 
     @provide_session
     def get_previous_start_date(
-        self, state: Optional[str] = None, session: Session = None
+        self, state: Optional[str] = None, session: Session = NEW_SESSION
     ) -> Optional[pendulum.DateTime]:
         """
         The start date from property previous_ti_success.
@@ -1001,7 +1019,7 @@ class TaskInstance(Base, LoggingMixin):
         return self.get_previous_start_date(state=State.SUCCESS)
 
     @provide_session
-    def are_dependencies_met(self, dep_context=None, session=None, verbose=False):
+    def are_dependencies_met(self, dep_context=None, session=NEW_SESSION, verbose=False):
         """
         Returns whether or not all the conditions are met for this task instance to be run
         given the context for the dependencies (e.g. a task instance being force run from
@@ -1036,7 +1054,7 @@ class TaskInstance(Base, LoggingMixin):
         return True
 
     @provide_session
-    def get_failed_dep_statuses(self, dep_context=None, session=None):
+    def get_failed_dep_statuses(self, dep_context=None, session=NEW_SESSION):
         """Get failed Dependencies"""
         dep_context = dep_context or DepContext()
         for dep in dep_context.deps | self.task.deps:
@@ -1103,7 +1121,7 @@ class TaskInstance(Base, LoggingMixin):
         return self.state == State.UP_FOR_RETRY and self.next_retry_datetime() < timezone.utcnow()
 
     @provide_session
-    def get_dagrun(self, session: Session = None):
+    def get_dagrun(self, session: Session = NEW_SESSION):
         """
         Returns the DagRun for this TaskInstance
 
@@ -1136,7 +1154,7 @@ class TaskInstance(Base, LoggingMixin):
         job_id: Optional[str] = None,
         pool: Optional[str] = None,
         external_executor_id: Optional[str] = None,
-        session=None,
+        session=NEW_SESSION,
     ) -> bool:
         """
         Checks dependencies and then sets state to RUNNING if they are met. Returns
@@ -1295,7 +1313,7 @@ class TaskInstance(Base, LoggingMixin):
         job_id: Optional[str] = None,
         pool: Optional[str] = None,
         error_file: Optional[str] = None,
-        session=None,
+        session=NEW_SESSION,
     ) -> None:
         """
         Immediately runs the task (without checking or changing db state
@@ -1462,7 +1480,7 @@ class TaskInstance(Base, LoggingMixin):
         Stats.incr('ti_successes')
 
     @provide_session
-    def _update_ti_state_for_sensing(self, session=None):
+    def _update_ti_state_for_sensing(self, session=NEW_SESSION):
         self.log.info('Submitting %s to sensor service', self)
         self.state = State.SENSING
         self.start_date = timezone.utcnow()
@@ -1606,7 +1624,7 @@ class TaskInstance(Base, LoggingMixin):
         test_mode: bool = False,
         job_id: Optional[str] = None,
         pool: Optional[str] = None,
-        session=None,
+        session=NEW_SESSION,
     ) -> None:
         """Run TaskInstance"""
         res = self.check_and_change_state_before_execution(
@@ -1649,7 +1667,9 @@ class TaskInstance(Base, LoggingMixin):
         task_copy.dry_run()
 
     @provide_session
-    def _handle_reschedule(self, actual_start_date, reschedule_exception, test_mode=False, session=None):
+    def _handle_reschedule(
+        self, actual_start_date, reschedule_exception, test_mode=False, session=NEW_SESSION
+    ):
         # Don't record reschedule request in test mode
         if test_mode:
             return
@@ -1690,7 +1710,7 @@ class TaskInstance(Base, LoggingMixin):
         test_mode: Optional[bool] = None,
         force_fail: bool = False,
         error_file: Optional[str] = None,
-        session=None,
+        session=NEW_SESSION,
     ) -> None:
         """Handle Failure for the TaskInstance"""
         if test_mode is None:
@@ -1761,7 +1781,7 @@ class TaskInstance(Base, LoggingMixin):
         error: Union[str, Exception],
         test_mode: Optional[bool] = None,
         force_fail: bool = False,
-        session=None,
+        session=NEW_SESSION,
     ) -> None:
         self.handle_failure(error=error, test_mode=test_mode, force_fail=force_fail, session=session)
         self._run_finished_callback(error=error)
@@ -1775,7 +1795,9 @@ class TaskInstance(Base, LoggingMixin):
 
         return self.task.retries and self.try_number <= self.max_tries
 
-    def get_template_context(self, session: Session = None, ignore_param_exceptions: bool = True) -> Context:
+    def get_template_context(
+        self, session: Session = NEW_SESSION, ignore_param_exceptions: bool = True
+    ) -> Context:
         """Return TI Context"""
         # Do not use provide_session here -- it expunges everything on exit!
         if not session:
@@ -1798,7 +1820,7 @@ class TaskInstance(Base, LoggingMixin):
             params.update(task.params)
         if conf.getboolean('core', 'dag_run_conf_overrides_params'):
             self.overwrite_params_with_dag_run_conf(params=params, dag_run=dag_run)
-        task.params = params.validate()
+        validated_params = task.params = params.validate()
 
         logical_date = timezone.coerce_datetime(self.execution_date)
         ds = logical_date.strftime('%Y-%m-%d')
@@ -1914,7 +1936,7 @@ class TaskInstance(Base, LoggingMixin):
             'next_ds_nodash': get_next_ds_nodash(),
             'next_execution_date': get_next_execution_date(),
             'outlets': task.outlets,
-            'params': task.params,
+            'params': validated_params,
             'prev_data_interval_start_success': get_prev_data_interval_start_success(),
             'prev_data_interval_end_success': get_prev_data_interval_end_success(),
             'prev_ds': get_prev_ds(),
@@ -1947,7 +1969,7 @@ class TaskInstance(Base, LoggingMixin):
         return Context(context)
 
     @provide_session
-    def get_rendered_template_fields(self, session=None):
+    def get_rendered_template_fields(self, session=NEW_SESSION):
         """Fetch rendered template fields from DB"""
         from airflow.models.renderedtifields import RenderedTaskInstanceFields
 
@@ -1967,7 +1989,7 @@ class TaskInstance(Base, LoggingMixin):
                 ) from e
 
     @provide_session
-    def get_rendered_k8s_spec(self, session=None):
+    def get_rendered_k8s_spec(self, session=NEW_SESSION):
         """Fetch rendered template fields from DB"""
         from airflow.models.renderedtifields import RenderedTaskInstanceFields
 
@@ -2006,7 +2028,7 @@ class TaskInstance(Base, LoggingMixin):
             date=self.execution_date,
             args=self.command_as_list(),
             pod_override_object=PodGenerator.from_obj(self.executor_config),
-            scheduler_job_id="worker-config",
+            scheduler_job_id=0,
             namespace=kube_config.executor_namespace,
             base_worker_pod=PodGenerator.deserialize_model_file(kube_config.pod_template_file),
         )
@@ -2109,7 +2131,7 @@ class TaskInstance(Base, LoggingMixin):
         key: str,
         value: Any,
         execution_date: Optional[datetime] = None,
-        session: Session = None,
+        session: Session = NEW_SESSION,
     ) -> None:
         """
         Make an XCom available for tasks to pull.
@@ -2149,7 +2171,7 @@ class TaskInstance(Base, LoggingMixin):
         dag_id: Optional[str] = None,
         key: str = XCOM_RETURN_KEY,
         include_prior_dates: bool = False,
-        session: Session = None,
+        session: Session = NEW_SESSION,
     ) -> Any:
         """
         Pull XComs that optionally meet certain criteria.
@@ -2199,7 +2221,11 @@ class TaskInstance(Base, LoggingMixin):
         # Since we're only fetching the values field, and not the
         # whole class, the @recreate annotation does not kick in.
         # Therefore we need to deserialize the fields by ourselves.
-        if is_container(task_ids):
+        if task_ids is None or isinstance(task_ids, str):
+            xcom = query.with_entities(XCom.value).first()
+            if xcom:
+                return XCom.deserialize_value(xcom)
+        else:
             vals_kv = {
                 result.task_id: XCom.deserialize_value(result)
                 for result in query.with_entities(XCom.task_id, XCom.value)
@@ -2207,10 +2233,6 @@ class TaskInstance(Base, LoggingMixin):
 
             values_ordered_by_id = [vals_kv.get(task_id) for task_id in task_ids]
             return values_ordered_by_id
-        else:
-            xcom = query.with_entities(XCom.value).first()
-            if xcom:
-                return XCom.deserialize_value(xcom)
 
     @provide_session
     def get_num_running_task_instances(self, session):
@@ -2261,7 +2283,7 @@ class TaskInstance(Base, LoggingMixin):
                 TaskInstance.task_id == first_task_id,
             )
 
-        if settings.Session.bind.dialect.name == 'mssql':
+        if settings.engine.dialect.name == 'mssql':
             return or_(
                 and_(
                     TaskInstance.dag_id == ti.dag_id,
@@ -2291,7 +2313,7 @@ class SimpleTaskInstance:
     def __init__(self, ti: TaskInstance):
         self._dag_id: str = ti.dag_id
         self._task_id: str = ti.task_id
-        self._run_id: datetime = ti.run_id
+        self._run_id: str = ti.run_id
         self._start_date: datetime = ti.start_date
         self._end_date: datetime = ti.end_date
         self._try_number: int = ti.try_number
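
The repeated ``session=None`` -> ``session=NEW_SESSION`` changes in this file give
mypy a ``Session``-typed default while ``@provide_session`` keeps injecting a real
session at runtime when the caller passed nothing. A minimal sketch of the pattern,
assuming a ``cast(Session, None)``-style sentinel like the one in
``airflow.utils.session`` (the factory and the keyword-only handling below are
simplifications, not the real implementation):

.. code-block:: python

    import functools
    from typing import Any, Callable, TypeVar, cast

    from sqlalchemy.orm import Session, sessionmaker

    RT = TypeVar("RT")

    # Typed as Session so defaults type-check, but still "nothing passed" at runtime.
    NEW_SESSION: Session = cast(Session, None)
    SessionFactory = sessionmaker()  # Airflow binds this to its engine

    def provide_session(func: Callable[..., RT]) -> Callable[..., RT]:
        """Open (and later close) a session if the caller did not supply one."""

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> RT:
            if kwargs.get("session") is not None:
                return func(*args, **kwargs)
            kwargs["session"] = SessionFactory()
            try:
                return func(*args, **kwargs)
            finally:
                kwargs["session"].close()

        return wrapper
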
diff --git a/airflow/models/variable.py b/airflow/models/variable.py
index 00edb55..b5c3921 100644
--- a/airflow/models/variable.py
+++ b/airflow/models/variable.py
@@ -153,7 +153,7 @@ class Variable(Base, LoggingMixin):
         cls,
         key: str,
         value: Any,
-        description: str = None,
+        description: Optional[str] = None,
         serialize_json: bool = False,
         session: Session = None,
     ):
diff --git a/airflow/operators/subdag.py b/airflow/operators/subdag.py
index 35223cb..599284c 100644
--- a/airflow/operators/subdag.py
+++ b/airflow/operators/subdag.py
@@ -34,7 +34,7 @@ from airflow.models.dag import DAG, DagContext
 from airflow.models.pool import Pool
 from airflow.models.taskinstance import TaskInstance
 from airflow.sensors.base import BaseSensorOperator
-from airflow.utils.session import create_session, provide_session
+from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.utils.state import State
 from airflow.utils.types import DagRunType
 
@@ -69,12 +69,14 @@ class SubDagOperator(BaseSensorOperator):
     ui_color = '#555'
     ui_fgcolor = '#fff'
 
+    subdag: "DAG"
+
     @provide_session
     def __init__(
         self,
         *,
         subdag: DAG,
-        session: Optional[Session] = None,
+        session: Session = NEW_SESSION,
         conf: Optional[Dict] = None,
         propagate_skipped_state: Optional[SkippedStatePropagationOptions] = None,
         **kwargs,
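
The bare class-level ``subdag: "DAG"`` annotation (with no assignment) re-declares,
for mypy's benefit, an attribute the base operator only knows as optional:
``SubDagOperator`` always sets it in ``__init__``, so callers need no None checks.
The idiom in isolation, with hypothetical class names:

.. code-block:: python

    from typing import Optional

    class Base:
        attachment: Optional[str] = None  # may be absent on the base class

    class AlwaysAttached(Base):
        attachment: str  # narrowed: every instance assigns it in __init__

        def __init__(self, attachment: str) -> None:
            self.attachment = attachment

    def shout(obj: AlwaysAttached) -> str:
        return obj.attachment.upper()  # passes mypy without an Optional check
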
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 2559f1e..07437e0 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -918,7 +918,6 @@ class SerializedDAG(DAG, BaseSerialization):
 
             if serializable_task.subdag is not None:
                 setattr(serializable_task.subdag, 'parent_dag', dag)
-                serializable_task.subdag.is_subdag = True
 
             for task_id in serializable_task.downstream_task_ids:
                 # Bypass set_upstream etc here - it does more than we want
diff --git a/airflow/settings.py b/airflow/settings.py
index 9cfed37..a4b76d5 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -79,7 +79,7 @@ LOGGING_CLASS_PATH: Optional[str] = None
 DONOT_MODIFY_HANDLERS: Optional[bool] = None
 DAGS_FOLDER: str = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))
 
-engine: Optional[Engine] = None
+engine: Engine
 Session: Callable[..., SASession]
 
 # The JSON library to use for DAG Serialization and De-Serialization
@@ -378,6 +378,8 @@ def configure_adapters():
 
 def validate_session():
     """Validate ORM Session"""
+    global engine
+
     worker_precheck = conf.getboolean('celery', 'worker_precheck', fallback=False)
     if not worker_precheck:
         return True
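
Dropping ``Optional`` from the ``engine`` declaration removes None checks at every
use site: an annotation without an assignment tells mypy the module attribute is
always initialised by the time it is read, and functions that rebind it need
``global engine``, hence the addition in ``validate_session``. In miniature (the
``configure`` helper here is illustrative, not Airflow's real ``configure_orm``):

.. code-block:: python

    from sqlalchemy import create_engine
    from sqlalchemy.engine import Engine

    engine: Engine  # declared but not assigned: set once during configuration

    def configure(url: str = "sqlite://") -> None:
        global engine  # without this, assignment would create a local variable
        engine = create_engine(url)

    def dialect_name() -> str:
        return engine.dialect.name  # typed Engine; no Optional narrowing needed
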
diff --git a/airflow/timetables/base.py b/airflow/timetables/base.py
index e97f253..850bc47 100644
--- a/airflow/timetables/base.py
+++ b/airflow/timetables/base.py
@@ -33,7 +33,7 @@ class DataInterval(NamedTuple):
     end: DateTime
 
     @classmethod
-    def exact(cls, at: DateTime) -> "DagRunInfo":
+    def exact(cls, at: DateTime) -> "DataInterval":
         """Represent an "interval" containing only an exact time."""
         return cls(start=at, end=at)
 
diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi
index 0921d79..1249112 100644
--- a/airflow/utils/context.pyi
+++ b/airflow/utils/context.pyi
@@ -41,6 +41,13 @@ class _VariableAccessors(TypedDict):
     json: Any
     value: Any
 
+class VariableAccessor:
+    def __init__(self, *, deserialize_json: bool) -> None: ...
+    def get(self, key, default: Any = ...) -> Any: ...
+
+class ConnectionAccessor:
+    def get(self, key: str, default_conn: Any = None) -> Any: ...
+
 class Context(TypedDict, total=False):
     conf: AirflowConfigParser
     conn: Any
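
With ``VariableAccessor`` and ``ConnectionAccessor`` declared in the stub, task
code that reaches into the runtime context can type-check. A hypothetical task
body (the variable key and connection ID are examples only):

.. code-block:: python

    def my_python_task(**context):
        # context["var"] maps "value" and "json" to VariableAccessor instances
        api_key = context["var"]["value"].get("my_api_key", default=None)
        # context["conn"] resolves connection IDs via ConnectionAccessor
        pg = context["conn"].get("my_postgres", default_conn=None)
        return api_key, pg
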
diff --git a/airflow/utils/file.py b/airflow/utils/file.py
index a940a60..a7f45e9 100644
--- a/airflow/utils/file.py
+++ b/airflow/utils/file.py
@@ -21,7 +21,7 @@ import os
 import re
 import zipfile
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Pattern, Union
+from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Pattern, Union, overload
 
 from airflow.configuration import conf
 
@@ -68,7 +68,17 @@ def mkdirs(path, mode):
 ZIP_REGEX = re.compile(fr'((.*\.zip){re.escape(os.sep)})?(.*)')
 
 
-def correct_maybe_zipped(fileloc):
+@overload
+def correct_maybe_zipped(fileloc: None) -> None:
+    ...
+
+
+@overload
+def correct_maybe_zipped(fileloc: Union[str, Path]) -> Union[str, Path]:
+    ...
+
+
+def correct_maybe_zipped(fileloc: Union[None, str, Path]) -> Union[None, str, Path]:
     """
     If the path contains a folder with a .zip suffix, then
     the folder is treated as a zip archive and path to zip is returned.
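
The two ``@overload`` declarations preserve the None-passthrough behaviour in the
types: ``correct_maybe_zipped(None)`` is ``None``, anything else keeps its path
type, and results are no longer inferred as ``Optional`` everywhere. The idiom
reduced to a toy function:

.. code-block:: python

    from typing import Optional, overload

    @overload
    def double(x: None) -> None:
        ...

    @overload
    def double(x: int) -> int:
        ...

    def double(x: Optional[int]) -> Optional[int]:
        # One runtime implementation; the overloads exist purely for mypy.
        return None if x is None else 2 * x

    y = double(21)    # mypy infers int, not Optional[int]
    z = double(None)  # mypy infers None
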
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index 745cdcd..a3e9c96 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -32,6 +32,7 @@ class TaskInstanceState(str, Enum):
 
     # Set by the scheduler
     # None - Task is created but should not run yet
+    NONE = None
     REMOVED = "removed"  # Task vanished from DAG before it ran
     SCHEDULED = "scheduled"  # Task should run and will be handed to executor soon
 
diff --git a/airflow/utils/timezone.py b/airflow/utils/timezone.py
index e5245d9..1051ee1 100644
--- a/airflow/utils/timezone.py
+++ b/airflow/utils/timezone.py
@@ -17,7 +17,7 @@
 # under the License.
 #
 import datetime as dt
-from typing import Optional, Union
+from typing import TYPE_CHECKING, Optional, overload
 
 import pendulum
 from pendulum.datetime import DateTime
@@ -27,6 +27,9 @@ from airflow.settings import TIMEZONE
 # UTC time zone as a tzinfo instance.
 utc = pendulum.tz.timezone('UTC')
 
+if TYPE_CHECKING:
+    from pendulum.tz.timezone import Timezone
+
 
 def is_localized(value):
     """
@@ -97,7 +100,17 @@ def convert_to_utc(value):
     return value.astimezone(utc)
 
 
-def make_aware(value, timezone=None):
+@overload
+def make_aware(value: None, timezone: Optional["Timezone"] = None) -> None:
+    ...
+
+
+@overload
+def make_aware(value: dt.datetime, timezone: Optional["Timezone"] = None) -> dt.datetime:
+    ...
+
+
+def make_aware(value: Optional[dt.datetime], timezone: Optional["Timezone"] = None) -> Optional[dt.datetime]:
     """
     Make a naive datetime.datetime in a given time zone aware.
 
@@ -175,7 +188,17 @@ def parse(string: str, timezone=None) -> DateTime:
     return pendulum.parse(string, tz=timezone or TIMEZONE, strict=False)  # type: ignore
 
 
-def coerce_datetime(v: Union[None, dt.datetime, DateTime]) -> Optional[DateTime]:
+@overload
+def coerce_datetime(v: None) -> None:
+    ...
+
+
+@overload
+def coerce_datetime(v: dt.datetime) -> DateTime:
+    ...
+
+
+def coerce_datetime(v: Optional[dt.datetime]) -> Optional[DateTime]:
     """Convert whatever is passed in to an timezone-aware ``pendulum.DateTime``."""
     if v is None:
         return None
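
The same overload treatment on ``coerce_datetime`` is what lets a call site such
as ``logical_date = timezone.coerce_datetime(self.execution_date)`` in
``taskinstance.py`` come out typed as ``DateTime``, so the subsequent
``logical_date.strftime(...)`` passes mypy without a None check. A quick sketch:

.. code-block:: python

    import datetime as dt

    from airflow.utils import timezone

    stamp = dt.datetime(2021, 12, 7, 15, 28)
    as_dt = timezone.coerce_datetime(stamp)  # typed pendulum.DateTime
    gone = timezone.coerce_datetime(None)    # typed None, no cast required
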
diff --git a/airflow/utils/types.py b/airflow/utils/types.py
index 9f3c559..04688a7 100644
--- a/airflow/utils/types.py
+++ b/airflow/utils/types.py
@@ -20,6 +20,25 @@ from typing import Optional
 from airflow.typing_compat import TypedDict
 
 
+class ArgNotSet:
+    """Sentinel type for annotations, useful when None is not viable.
+
+    Use like this::
+
+        def is_arg_passed(arg: Union[ArgNotSet, None] = NOTSET) -> bool:
+            if arg is NOTSET:
+                return False
+            return True
+
+        is_arg_passed()  # False.
+        is_arg_passed(None)  # True.
+    """
+
+
+NOTSET = ArgNotSet()
+"""Sentinel value for argument default. See ``ArgNotSet``."""
+
+
 class DagRunType(str, enum.Enum):
     """Class with DagRun types"""
 
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 4921e1e..2ec863d 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -866,7 +866,6 @@ class TestDag(unittest.TestCase):
             )
             # parent_dag and is_subdag were set by DagBag. We don't use DagBag, so these values are not set.
             subdag.parent_dag = dag
-            subdag.is_subdag = True
             SubDagOperator(task_id='subtask', owner='owner2', subdag=subdag)
         session = settings.Session()
         dag.sync_to_db(session=session)
@@ -932,7 +931,6 @@ class TestDag(unittest.TestCase):
 
         # parent_dag and is_subdag were set by DagBag. We don't use DagBag, so these values are not set.
         subdag.parent_dag = dag
-        subdag.is_subdag = True
 
         session.query(DagModel).filter(DagModel.dag_id.in_([subdag_id, dag_id])).delete(
             synchronize_session=False
@@ -1427,7 +1425,6 @@ class TestDag(unittest.TestCase):
         SubDagOperator(task_id='test', subdag=subdag, dag=dag)
         t_2 = DummyOperator(task_id='task', dag=subdag)
         subdag.parent_dag = dag
-        subdag.is_subdag = True
 
         dag.sync_to_db()
 
@@ -1806,7 +1803,6 @@ class TestDag(unittest.TestCase):
         subdag = section_1.subdag
         # parent_dag and is_subdag were set by DagBag. We don't use DagBag, so these values are not set.
         subdag.parent_dag = dag
-        subdag.is_subdag = True
 
         next_parent_info = dag.next_dagrun_info(None)
         assert next_parent_info.logical_date == timezone.datetime(2019, 1, 1, 0, 0)
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index b967e33..6dc2d4c 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -357,7 +357,7 @@ class TestTaskInstance:
         test that trying to create a task with pool_slots less than 1 fails
         """
 
-        with pytest.raises(AirflowException):
+        with pytest.raises(ValueError, match="pool slots .* cannot be less than 1"):
             dag = models.DAG(dag_id='test_run_pooling_task')
             DummyOperator(
                 task_id='test_run_pooling_task_op',
@@ -1926,7 +1926,7 @@ class TestTaskInstance:
                     'try_number': '1',
                 },
                 'labels': {
-                    'airflow-worker': 'worker-config',
+                    'airflow-worker': '0',
                     'airflow_version': version,
                     'dag_id': 'test_render_k8s_pod_yaml',
                     'execution_date': '2016-01-01T00_00_00_plus_00_00',
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 577a1df..82c848a 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1094,9 +1094,8 @@ class TestStringifiedDAGs:
         """
         base_operator = BaseOperator(task_id="10")
         fields = base_operator.__dict__
-        assert {
+        assert fields == {
             '_BaseOperator__instantiated': True,
-            '_dag': None,
             '_downstream_task_ids': set(),
             '_inlets': [],
             '_log': base_operator.log,
@@ -1139,12 +1138,11 @@ class TestStringifiedDAGs:
             'run_as_user': None,
             'sla': None,
             'start_date': None,
-            'subdag': None,
             'task_id': '10',
             'trigger_rule': 'all_success',
             'wait_for_downstream': False,
             'weight_rule': 'downstream',
-        } == fields, """
+        }, """
 !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
 
      ACTION NEEDED! PLEASE READ THIS CAREFULLY AND CORRECT TESTS CAREFULLY