You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2021/12/07 15:28:32 UTC
[airflow] branch main updated: Improve handling edge-cases in airflow.models by applying mypy (#20000)
This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7d8e3b8 Improve handling edge-cases in airflow.models by applying mypy (#20000)
7d8e3b8 is described below
commit 7d8e3b828af0ac90261c341f5cb0e57da75e6a83
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Tue Dec 7 15:28:01 2021 +0000
Improve handling edge-cases in airflow.models by applying mypy (#20000)
* Fix many of the mypy typing issues in airflow.models.dag
And to fix these, I needed to fix a few other mistakes that are
used/called by DAG's methods
* Fix timetable-related typing errors in dag.py
Also moved the sentinel value implementation to a utils module. This
should be useful when fixing typing issues in other modules.
* Add note about assert allowed inside a TYPE_CHECKING conditional
* Fix docs build of airflow.models.dagrun
* Apply NEW_SESSION to dag, dagrun, ti and operator.subdag
Co-authored-by: Tzu-ping Chung <tp...@astronomer.io>
---
CONTRIBUTING.rst | 8 +
airflow/models/base.py | 2 +-
airflow/models/baseoperator.py | 42 ++--
airflow/models/dag.py | 282 ++++++++++++++------------
airflow/models/dagbag.py | 12 +-
airflow/models/dagrun.py | 49 ++---
airflow/models/serialized_dag.py | 2 +-
airflow/models/taskinstance.py | 108 ++++++----
airflow/models/variable.py | 2 +-
airflow/operators/subdag.py | 6 +-
airflow/serialization/serialized_objects.py | 1 -
airflow/settings.py | 4 +-
airflow/timetables/base.py | 2 +-
airflow/utils/context.pyi | 7 +
airflow/utils/file.py | 14 +-
airflow/utils/state.py | 1 +
airflow/utils/timezone.py | 29 ++-
airflow/utils/types.py | 19 ++
tests/models/test_dag.py | 4 -
tests/models/test_taskinstance.py | 4 +-
tests/serialization/test_dag_serialization.py | 6 +-
21 files changed, 363 insertions(+), 241 deletions(-)
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index 1f00945..886ff27 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -936,6 +936,14 @@ you should do:
if not some_predicate():
handle_the_case()
+The one exception to this is if you need to make an assert for typechecking (which should be almost a last resort) you can do this:
+
+.. code-block:: python
+
+ if TYPE_CHECKING:
+ assert isinstance(x, MyClass)
+
+
Database Session Handling
-------------------------
diff --git a/airflow/models/base.py b/airflow/models/base.py
index 29a5320..439308d 100644
--- a/airflow/models/base.py
+++ b/airflow/models/base.py
@@ -63,4 +63,4 @@ def get_id_collation_args():
COLLATION_ARGS = get_id_collation_args()
-StringID: Type[String] = functools.partial(String, length=ID_LEN, **COLLATION_ARGS)
+StringID: Type[String] = functools.partial(String, length=ID_LEN, **COLLATION_ARGS) # type: ignore
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 86ec47e..88b1590 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -25,6 +25,7 @@ import warnings
from abc import ABCMeta, abstractmethod
from datetime import datetime, timedelta
from inspect import signature
+from types import FunctionType
from typing import (
TYPE_CHECKING,
Any,
@@ -46,6 +47,7 @@ from typing import (
import attr
import jinja2
+import pendulum
from dateutil.relativedelta import relativedelta
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import NoResultFound
@@ -87,7 +89,7 @@ TaskStateChangeCallback = Callable[[Context], None]
TaskPreExecuteHook = Callable[[Context], None]
TaskPostExecuteHook = Callable[[Context, Any], None]
-T = TypeVar('T', bound=Callable)
+T = TypeVar('T', bound=FunctionType)
class BaseOperatorMeta(abc.ABCMeta):
@@ -483,6 +485,12 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
# Set to True before calling execute method
_lock_for_execution = False
+ _dag: Optional["DAG"] = None
+
+ # subdag parameter is only set for SubDagOperator.
+ # Setting it to None by default as other Operators do not have that field
+ subdag: Optional["DAG"] = None
+
def __init__(
self,
task_id: str,
@@ -612,7 +620,8 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
self.pool = Pool.DEFAULT_POOL_NAME if pool is None else pool
self.pool_slots = pool_slots
if self.pool_slots < 1:
- raise AirflowException(f"pool slots for {self.task_id} in dag {dag.dag_id} cannot be less than 1")
+ dag_str = f" in dag {dag.dag_id}" if dag else ""
+ raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot be less than 1")
self.sla = sla
self.execution_timeout = execution_timeout
self.on_execute_callback = on_execute_callback
@@ -636,7 +645,8 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
self.log.debug("max_retry_delay isn't a timedelta object, assuming secs")
self.max_retry_delay = timedelta(seconds=max_retry_delay)
- self.params = ParamsDict(params)
+ # At execution_time this becomes a normal dict
+ self.params: Union[ParamsDict, dict] = ParamsDict(params)
if priority_weight is not None and not isinstance(priority_weight, int):
raise AirflowException(
f"`priority_weight` for task '{self.task_id}' only accepts integers, "
@@ -673,15 +683,10 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
# Private attributes
self._upstream_task_ids: Set[str] = set()
self._downstream_task_ids: Set[str] = set()
- self._dag = None
-
- self.dag = dag or DagContext.get_current_dag()
-
- # subdag parameter is only set for SubDagOperator.
- # Setting it to None by default as other Operators do not have that field
- from airflow.models.dag import DAG
- self.subdag: Optional[DAG] = None
+ dag = dag or DagContext.get_current_dag()
+ if dag:
+ self.dag = dag
self._log = logging.getLogger("airflow.task.operators")
@@ -811,7 +816,7 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
@property
def dag(self) -> 'DAG':
"""Returns the Operator's DAG if set, otherwise raises an error"""
- if self.has_dag():
+ if self._dag:
return self._dag
else:
raise AirflowException(f'Operator {self} has not been assigned to a DAG yet')
@@ -840,7 +845,7 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
def has_dag(self):
"""Returns True if the Operator has been assigned to a DAG."""
- return getattr(self, '_dag', None) is not None
+ return self._dag is not None
@property
def dag_id(self) -> str:
@@ -1301,8 +1306,13 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
from airflow.models import DagRun
from airflow.utils.types import DagRunType
- start_date = start_date or self.start_date
- end_date = end_date or self.end_date or timezone.utcnow()
+ # Assertions for typing -- we need a dag, for this function, and when we have a DAG we are
+ # _guaranteed_ to have start_date (else we couldn't have been added to a DAG)
+ if TYPE_CHECKING:
+ assert self.start_date
+
+ start_date = pendulum.instance(start_date or self.start_date)
+ end_date = pendulum.instance(end_date or self.end_date or timezone.utcnow())
for info in self.dag.iter_dagrun_infos_between(start_date, end_date, align=False):
ignore_depends_on_past = info.logical_date == start_date and ignore_first_depends_on_past
@@ -1325,7 +1335,7 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
execution_date=info.logical_date,
data_interval=info.data_interval,
)
- ti = TaskInstance(self, run_id=None)
+ ti = TaskInstance(self, run_id=dr.run_id)
ti.dag_run = dr
session.add(dr)
session.flush()
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 52d73b8..5d35960 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -27,7 +27,7 @@ import sys
import traceback
import warnings
from collections import OrderedDict
-from datetime import datetime, timedelta, tzinfo
+from datetime import datetime, timedelta
from inspect import signature
from typing import (
TYPE_CHECKING,
@@ -51,8 +51,10 @@ import jinja2
import pendulum
from dateutil.relativedelta import relativedelta
from jinja2.nativetypes import NativeEnvironment
+from pendulum.tz.timezone import Timezone
from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, or_
from sqlalchemy.orm import backref, joinedload, relationship
+from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import expression
@@ -81,10 +83,10 @@ from airflow.utils.dates import cron_presets, date_range as utils_date_range
from airflow.utils.file import correct_maybe_zipped
from airflow.utils.helpers import validate_key
from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, with_row_locks
-from airflow.utils.state import DagRunState, State
-from airflow.utils.types import DagRunType, EdgeInfoType
+from airflow.utils.state import DagRunState, State, TaskInstanceState
+from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType
if TYPE_CHECKING:
from airflow.utils.task_group import TaskGroup
@@ -95,11 +97,14 @@ log = logging.getLogger(__name__)
DEFAULT_VIEW_PRESETS = ['tree', 'graph', 'duration', 'gantt', 'landing_times']
ORIENTATION_PRESETS = ['LR', 'TB', 'RL', 'BT']
-ScheduleIntervalArgNotSet = type("ScheduleIntervalArgNotSet", (), {})
DagStateChangeCallback = Callable[[Context], None]
-ScheduleInterval = Union[str, timedelta, relativedelta]
-ScheduleIntervalArg = Union[ScheduleInterval, None, Type[ScheduleIntervalArgNotSet]]
+ScheduleInterval = Union[None, str, timedelta, relativedelta]
+
+# FIXME: Ideally this should be Union[Literal[NOTSET], ScheduleInterval],
+# but Mypy cannot handle that right now. Track progress of PEP 661 for progress.
+# See also: https://discuss.python.org/t/9126/7
+ScheduleIntervalArg = Union[ArgNotSet, ScheduleInterval]
# Backward compatibility: If neither schedule_interval nor timetable is
@@ -145,9 +150,9 @@ def _get_model_data_interval(
return DataInterval(start, end)
-def create_timetable(interval: ScheduleIntervalArg, timezone: tzinfo) -> Timetable:
+def create_timetable(interval: ScheduleIntervalArg, timezone: Timezone) -> Timetable:
"""Create a Timetable instance from a ``schedule_interval`` argument."""
- if interval is ScheduleIntervalArgNotSet:
+ if interval is NOTSET:
return DeltaDataIntervalTimetable(DEFAULT_SCHEDULE_INTERVAL)
if interval is None:
return NullTimetable()
@@ -319,11 +324,13 @@ class DAG(LoggingMixin):
from a ZIP file or other DAG distribution format.
"""
+ parent_dag: Optional["DAG"] = None # Gets set when DAGs are loaded
+
def __init__(
self,
dag_id: str,
description: Optional[str] = None,
- schedule_interval: ScheduleIntervalArg = ScheduleIntervalArgNotSet,
+ schedule_interval: ScheduleIntervalArg = NOTSET,
timetable: Optional[Timetable] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
@@ -356,15 +363,15 @@ class DAG(LoggingMixin):
self.user_defined_macros = user_defined_macros
self.user_defined_filters = user_defined_filters
self.default_args = copy.deepcopy(default_args or {})
- self.params = params or {}
+ params = params or {}
# merging potentially conflicting default_args['params'] into params
if 'params' in self.default_args:
- self.params.update(self.default_args['params'])
+ params.update(self.default_args['params'])
del self.default_args['params']
# check self.params and convert them into ParamsDict
- self.params = ParamsDict(self.params)
+ self.params = ParamsDict(params)
if full_filepath:
warnings.warn(
@@ -394,15 +401,19 @@ class DAG(LoggingMixin):
self.task_dict: Dict[str, BaseOperator] = {}
# set timezone from start_date
+ tz = None
if start_date and start_date.tzinfo:
- self.timezone = start_date.tzinfo
+ tzinfo = None if start_date.tzinfo else settings.TIMEZONE
+ tz = pendulum.instance(start_date, tz=tzinfo).timezone
elif 'start_date' in self.default_args and self.default_args['start_date']:
- if isinstance(self.default_args['start_date'], str):
- self.default_args['start_date'] = timezone.parse(self.default_args['start_date'])
- self.timezone = self.default_args['start_date'].tzinfo
+ date = self.default_args['start_date']
+ if not isinstance(date, datetime):
+ date = timezone.parse(date)
+ self.default_args['start_date'] = date
- if not hasattr(self, 'timezone') or not self.timezone:
- self.timezone = settings.TIMEZONE
+ tzinfo = None if date.tzinfo else settings.TIMEZONE
+ tz = pendulum.instance(date, tz=tzinfo).timezone
+ self.timezone = tz or settings.TIMEZONE
# Apply the timezone we settled on to end_date if it wasn't supplied
if 'end_date' in self.default_args and self.default_args['end_date']:
@@ -423,10 +434,10 @@ class DAG(LoggingMixin):
# Calculate the DAG's timetable.
if timetable is None:
self.timetable = create_timetable(schedule_interval, self.timezone)
- if schedule_interval is ScheduleIntervalArgNotSet:
+ if isinstance(schedule_interval, ArgNotSet):
schedule_interval = DEFAULT_SCHEDULE_INTERVAL
self.schedule_interval: ScheduleInterval = schedule_interval
- elif schedule_interval is ScheduleIntervalArgNotSet:
+ elif isinstance(schedule_interval, ArgNotSet):
self.timetable = timetable
self.schedule_interval = self.timetable.summary
else:
@@ -436,7 +447,6 @@ class DAG(LoggingMixin):
template_searchpath = [template_searchpath]
self.template_searchpath = template_searchpath
self.template_undefined = template_undefined
- self.parent_dag: Optional[DAG] = None # Gets set when DAGs are loaded
self.last_loaded = timezone.utcnow()
self.safe_dag_id = dag_id.replace('.', '__dot__')
self.max_active_runs = max_active_runs
@@ -457,7 +467,6 @@ class DAG(LoggingMixin):
f'{ORIENTATION_PRESETS}, but get {orientation}'
)
self.catchup = catchup
- self.is_subdag = False # DagBag.bag_dag() will set this to True if appropriate
self.partial = False
self.on_success_callback = on_success_callback
@@ -480,7 +489,7 @@ class DAG(LoggingMixin):
self.jinja_environment_kwargs = jinja_environment_kwargs
self.render_template_as_native_obj = render_template_as_native_obj
- self.tags = tags
+ self.tags = tags or []
self._task_group = TaskGroup.create_root(self)
self.validate_schedule_and_params()
@@ -554,21 +563,22 @@ class DAG(LoggingMixin):
def date_range(
self,
- start_date: datetime,
+ start_date: pendulum.DateTime,
num: Optional[int] = None,
- end_date: Optional[datetime] = timezone.utcnow(),
+ end_date: Optional[datetime] = None,
) -> List[datetime]:
message = "`DAG.date_range()` is deprecated."
if num is not None:
- result = utils_date_range(start_date=start_date, num=num)
- else:
- message += " Please use `DAG.iter_dagrun_infos_between(..., align=False)` instead."
- result = [
- info.logical_date
- for info in self.iter_dagrun_infos_between(start_date, end_date, align=False)
- ]
+ warnings.warn(message, category=DeprecationWarning, stacklevel=2)
+ return utils_date_range(start_date=start_date, num=num)
+ message += " Please use `DAG.iter_dagrun_infos_between(..., align=False)` instead."
warnings.warn(message, category=DeprecationWarning, stacklevel=2)
- return result
+ if end_date is None:
+ coerced_end_date = timezone.utcnow()
+ else:
+ coerced_end_date = end_date
+ it = self.iter_dagrun_infos_between(start_date, pendulum.instance(coerced_end_date), align=False)
+ return [info.logical_date for info in it]
def is_fixed_time_schedule(self):
warnings.warn(
@@ -706,6 +716,8 @@ class DAG(LoggingMixin):
# Never schedule a subdag. It will be scheduled by its parent dag.
if self.is_subdag:
return None
+
+ data_interval = None
if isinstance(last_automated_dagrun, datetime):
warnings.warn(
"Passing a datetime to DAG.next_dagrun_info is deprecated. Use a DataInterval instead.",
@@ -755,17 +767,15 @@ class DAG(LoggingMixin):
start_dates = [t.start_date for t in self.tasks if t.start_date]
if self.start_date is not None:
start_dates.append(self.start_date)
+ earliest = None
if start_dates:
earliest = timezone.coerce_datetime(min(start_dates))
- else:
- earliest = None
end_dates = [t.end_date for t in self.tasks if t.end_date]
if self.end_date is not None:
end_dates.append(self.end_date)
+ latest = None
if end_dates:
latest = timezone.coerce_datetime(max(end_dates))
- else:
- latest = None
return TimeRestriction(earliest, latest, self.catchup)
def iter_dagrun_infos_between(
@@ -793,6 +803,8 @@ class DAG(LoggingMixin):
"""
if earliest is None:
earliest = self._time_restriction.earliest
+ if earliest is None:
+ raise ValueError("earliest was None and we had no value in time_restriction to fallback on")
earliest = timezone.coerce_datetime(earliest)
latest = timezone.coerce_datetime(latest)
@@ -843,7 +855,7 @@ class DAG(LoggingMixin):
except Exception:
self.log.exception(
"Failed to fetch run info after data interval %s for DAG %r",
- info.data_interval,
+ info.data_interval if info else "<NONE>",
self.dag_id,
)
break
@@ -891,13 +903,13 @@ class DAG(LoggingMixin):
return dttm
@provide_session
- def get_last_dagrun(self, session=None, include_externally_triggered=False):
+ def get_last_dagrun(self, session=NEW_SESSION, include_externally_triggered=False):
return get_last_dagrun(
self.dag_id, session=session, include_externally_triggered=include_externally_triggered
)
@provide_session
- def has_dag_runs(self, session=None, include_externally_triggered=True) -> bool:
+ def has_dag_runs(self, session=NEW_SESSION, include_externally_triggered=True) -> bool:
return (
get_last_dagrun(
self.dag_id, session=session, include_externally_triggered=include_externally_triggered
@@ -914,6 +926,10 @@ class DAG(LoggingMixin):
self._dag_id = value
@property
+ def is_subdag(self) -> bool:
+ return self.parent_dag is not None
+
+ @property
def full_filepath(self) -> str:
""":meta private:"""
warnings.warn(
@@ -1042,7 +1058,7 @@ class DAG(LoggingMixin):
return settings.ALLOW_FUTURE_EXEC_DATES and not self.timetable.can_run
@provide_session
- def get_concurrency_reached(self, session=None) -> bool:
+ def get_concurrency_reached(self, session=NEW_SESSION) -> bool:
"""
Returns a boolean indicating whether the max_active_tasks limit for this DAG
has been reached
@@ -1065,13 +1081,13 @@ class DAG(LoggingMixin):
return self.get_concurrency_reached()
@provide_session
- def get_is_active(self, session=None) -> Optional[None]:
+ def get_is_active(self, session=NEW_SESSION) -> Optional[None]:
"""Returns a boolean indicating whether this DAG is active"""
qry = session.query(DagModel).filter(DagModel.dag_id == self.dag_id)
return qry.value(DagModel.is_active)
@provide_session
- def get_is_paused(self, session=None) -> Optional[None]:
+ def get_is_paused(self, session=NEW_SESSION) -> Optional[None]:
"""Returns a boolean indicating whether this DAG is paused"""
qry = session.query(DagModel).filter(DagModel.dag_id == self.dag_id)
return qry.value(DagModel.is_paused)
@@ -1087,14 +1103,14 @@ class DAG(LoggingMixin):
return self.get_is_paused()
@property
- def normalized_schedule_interval(self) -> Optional[ScheduleInterval]:
+ def normalized_schedule_interval(self) -> ScheduleInterval:
warnings.warn(
"DAG.normalized_schedule_interval() is deprecated.",
category=DeprecationWarning,
stacklevel=2,
)
if isinstance(self.schedule_interval, str) and self.schedule_interval in cron_presets:
- _schedule_interval = cron_presets.get(self.schedule_interval) # type: Optional[ScheduleInterval]
+ _schedule_interval: ScheduleInterval = cron_presets.get(self.schedule_interval)
elif self.schedule_interval == '@once':
_schedule_interval = None
else:
@@ -1102,7 +1118,7 @@ class DAG(LoggingMixin):
return _schedule_interval
@provide_session
- def handle_callback(self, dagrun, success=True, reason=None, session=None):
+ def handle_callback(self, dagrun, success=True, reason=None, session=NEW_SESSION):
"""
Triggers the appropriate callback depending on the value of success, namely the
on_failure_callback or on_success_callback. This method gets the context of a
@@ -1146,7 +1162,7 @@ class DAG(LoggingMixin):
return active_dates
@provide_session
- def get_num_active_runs(self, external_trigger=None, only_running=True, session=None):
+ def get_num_active_runs(self, external_trigger=None, only_running=True, session=NEW_SESSION):
"""
Returns the number of active "running" dag runs
@@ -1174,7 +1190,7 @@ class DAG(LoggingMixin):
self,
execution_date: Optional[str] = None,
run_id: Optional[str] = None,
- session: Optional[Session] = None,
+ session: Session = NEW_SESSION,
):
"""
Returns the dag run for a given execution date or run_id if it exists, otherwise
@@ -1195,7 +1211,7 @@ class DAG(LoggingMixin):
return query.first()
@provide_session
- def get_dagruns_between(self, start_date, end_date, session=None):
+ def get_dagruns_between(self, start_date, end_date, session=NEW_SESSION):
"""
Returns the list of dag runs between start_date (inclusive) and end_date (inclusive).
@@ -1223,9 +1239,9 @@ class DAG(LoggingMixin):
@property
def latest_execution_date(self):
- """This attribute is deprecated. Please use `airflow.models.DAG.get_latest_execution_date` method."""
+ """This attribute is deprecated. Please use `airflow.models.DAG.get_latest_execution_date`."""
warnings.warn(
- "This attribute is deprecated. Please use `airflow.models.DAG.get_latest_execution_date` method.",
+ "This attribute is deprecated. Please use `airflow.models.DAG.get_latest_execution_date`.",
DeprecationWarning,
stacklevel=2,
)
@@ -1297,7 +1313,7 @@ class DAG(LoggingMixin):
base_date: datetime,
num: int,
*,
- session: Session,
+ session: Session = NEW_SESSION,
) -> List[TaskInstance]:
"""Get ``num`` task instances before (including) ``base_date``.
@@ -1324,25 +1340,35 @@ class DAG(LoggingMixin):
@provide_session
def get_task_instances(
- self, start_date=None, end_date=None, state=None, session=None
+ self,
+ start_date: Optional[datetime] = None,
+ end_date: Optional[datetime] = None,
+ state: Optional[List[TaskInstanceState]] = None,
+ session: Session = NEW_SESSION,
) -> List[TaskInstance]:
if not start_date:
- start_date = (timezone.utcnow() - timedelta(30)).date()
- start_date = timezone.make_aware(datetime.combine(start_date, datetime.min.time()))
+ start_date = (timezone.utcnow() - timedelta(30)).replace(
+ hour=0, minute=0, second=0, microsecond=0
+ )
+
+ if state is None:
+ state = []
return (
- self._get_task_instances(
- task_ids=None,
- start_date=start_date,
- end_date=end_date,
- run_id=None,
- state=state,
- include_subdags=False,
- include_parentdag=False,
- include_dependent_dags=False,
- exclude_task_ids=[],
- as_pk_tuple=False,
- session=session,
+ cast(
+ Query,
+ self._get_task_instances(
+ task_ids=None,
+ start_date=start_date,
+ end_date=end_date,
+ run_id=None,
+ state=state,
+ include_subdags=False,
+ include_parentdag=False,
+ include_dependent_dags=False,
+ exclude_task_ids=cast(List[str], []),
+ session=session,
+ ),
)
.join(TaskInstance.dag_run)
.order_by(DagRun.execution_date)
@@ -1356,19 +1382,15 @@ class DAG(LoggingMixin):
task_ids,
start_date: Optional[datetime],
end_date: Optional[datetime],
- run_id: None,
- state: Union[str, List[str]],
+ run_id: Optional[str],
+ state: Union[TaskInstanceState, List[TaskInstanceState]],
include_subdags: bool,
include_parentdag: bool,
include_dependent_dags: bool,
exclude_task_ids: Collection[str],
- as_pk_tuple: Literal[True],
session: Session,
- dag_bag: "DagBag" = None,
- recursion_depth: int = 0,
- max_recursion_depth: int = None,
- visited_external_tis: Set[Tuple[str, str, datetime]] = None,
- ) -> Set["TaskInstanceKey"]:
+ dag_bag: Optional["DagBag"] = ...,
+ ) -> Iterable[TaskInstance]:
... # pragma: no cover
@overload
@@ -1376,41 +1398,41 @@ class DAG(LoggingMixin):
self,
*,
task_ids,
+ as_pk_tuple: Literal[True],
start_date: Optional[datetime],
end_date: Optional[datetime],
run_id: Optional[str],
- state: Union[str, List[str]],
+ state: Union[TaskInstanceState, List[TaskInstanceState]],
include_subdags: bool,
include_parentdag: bool,
include_dependent_dags: bool,
- as_pk_tuple: Literal[False],
exclude_task_ids: Collection[str],
session: Session,
- dag_bag: "DagBag" = None,
- recursion_depth: int = 0,
- max_recursion_depth: int = None,
- visited_external_tis: Set[Tuple[str, str, datetime]] = None,
- ) -> Iterable[TaskInstance]:
+ dag_bag: Optional["DagBag"] = ...,
+ recursion_depth: int = ...,
+ max_recursion_depth: int = ...,
+ visited_external_tis: Set[TaskInstanceKey] = ...,
+ ) -> Set["TaskInstanceKey"]:
... # pragma: no cover
def _get_task_instances(
self,
*,
task_ids,
+ as_pk_tuple: Literal[True, None] = None,
start_date: Optional[datetime],
end_date: Optional[datetime],
run_id: Optional[str],
- state: Union[str, List[str]],
+ state: Union[TaskInstanceState, List[TaskInstanceState]],
include_subdags: bool,
include_parentdag: bool,
include_dependent_dags: bool,
- as_pk_tuple: bool,
exclude_task_ids: Collection[str],
session: Session,
- dag_bag: "DagBag" = None,
+ dag_bag: Optional["DagBag"] = None,
recursion_depth: int = 0,
- max_recursion_depth: int = None,
- visited_external_tis: Set[Tuple[str, str, datetime]] = None,
+ max_recursion_depth: Optional[int] = None,
+ visited_external_tis: Optional[Set[TaskInstanceKey]] = None,
) -> Union[Iterable[TaskInstance], Set[TaskInstanceKey]]:
TI = TaskInstance
@@ -1452,7 +1474,7 @@ class DAG(LoggingMixin):
tis = tis.filter(DagRun.execution_date <= end_date)
if state:
- if isinstance(state, str):
+ if isinstance(state, (str, TaskInstanceState)):
tis = tis.filter(TaskInstance.state == state)
elif len(state) == 1:
tis = tis.filter(TaskInstance.state == state[0])
@@ -1470,7 +1492,11 @@ class DAG(LoggingMixin):
tis = tis.filter(TaskInstance.state.in_(state))
# Next, get any of them from our parent DAG (if there is one)
- if include_parentdag and self.is_subdag and self.parent_dag is not None:
+ if include_parentdag and self.parent_dag is not None:
+
+ if visited_external_tis is None:
+ visited_external_tis = set()
+
p_dag = self.parent_dag.partial_subset(
task_ids_or_regex=r"^{}$".format(self.dag_id.split('.')[1]),
include_upstream=False,
@@ -1611,7 +1637,7 @@ class DAG(LoggingMixin):
future: Optional[bool] = False,
past: Optional[bool] = False,
commit: Optional[bool] = True,
- session=None,
+ session=NEW_SESSION,
) -> List[TaskInstance]:
"""
Set the state of a TaskInstance to the given state, and clear its downstream tasks that are
@@ -1747,10 +1773,10 @@ class DAG(LoggingMixin):
def set_dag_runs_state(
self,
state: str = State.RUNNING,
- session: Session = None,
+ session: Session = NEW_SESSION,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
- dag_ids: List[str] = None,
+ dag_ids: List[str] = [],
) -> None:
warnings.warn(
"This method is deprecated and will be removed in a future version.",
@@ -1769,22 +1795,22 @@ class DAG(LoggingMixin):
def clear(
self,
task_ids=None,
- start_date=None,
- end_date=None,
- only_failed=False,
- only_running=False,
- confirm_prompt=False,
- include_subdags=True,
- include_parentdag=True,
+ start_date: Optional[datetime] = None,
+ end_date: Optional[datetime] = None,
+ only_failed: bool = False,
+ only_running: bool = False,
+ confirm_prompt: bool = False,
+ include_subdags: bool = True,
+ include_parentdag: bool = True,
dag_run_state: DagRunState = DagRunState.QUEUED,
- dry_run=False,
- session=None,
- get_tis=False,
- recursion_depth=0,
- max_recursion_depth=None,
- dag_bag=None,
+ dry_run: bool = False,
+ session: Session = NEW_SESSION,
+ get_tis: bool = False,
+ recursion_depth: int = 0,
+ max_recursion_depth: Optional[int] = None,
+ dag_bag: Optional["DagBag"] = None,
exclude_task_ids: FrozenSet[str] = frozenset({}),
- ):
+ ) -> Union[int, Iterable[TaskInstance]]:
"""
Clears a set of task instances associated with the current dag for
a specified date range.
@@ -1841,11 +1867,9 @@ class DAG(LoggingMixin):
state = []
if only_failed:
state += [State.FAILED, State.UPSTREAM_FAILED]
- only_failed = None
if only_running:
# Yes, having `+=` doesn't make sense, but this was the existing behaviour
state += [State.RUNNING]
- only_running = None
tis = self._get_task_instances(
task_ids=task_ids,
@@ -1856,7 +1880,6 @@ class DAG(LoggingMixin):
include_subdags=include_subdags,
include_parentdag=include_parentdag,
include_dependent_dags=include_subdags, # compat, yes this is not a typo
- as_pk_tuple=False,
session=session,
dag_bag=dag_bag,
exclude_task_ids=exclude_task_ids,
@@ -1865,7 +1888,7 @@ class DAG(LoggingMixin):
if dry_run:
return tis
- tis = tis.all()
+ tis = list(tis)
count = len(tis)
do_it = True
@@ -2095,7 +2118,7 @@ class DAG(LoggingMixin):
return d
@provide_session
- def pickle(self, session=None) -> DagPickle:
+ def pickle(self, session=NEW_SESSION) -> DagPickle:
dag = session.query(DagModel).filter(DagModel.dag_id == self.dag_id).first()
dp = None
if dag and dag.pickle_id:
@@ -2278,7 +2301,7 @@ class DAG(LoggingMixin):
external_trigger: Optional[bool] = False,
conf: Optional[dict] = None,
run_type: Optional[DagRunType] = None,
- session=None,
+ session=NEW_SESSION,
dag_hash: Optional[str] = None,
creating_job_id: Optional[int] = None,
data_interval: Optional[Tuple[datetime, datetime]] = None,
@@ -2367,7 +2390,7 @@ class DAG(LoggingMixin):
@classmethod
@provide_session
- def bulk_sync_to_db(cls, dags: Collection["DAG"], session=None):
+ def bulk_sync_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION):
"""This method is deprecated in favor of bulk_write_to_db"""
warnings.warn(
"This method is deprecated and will be removed in a future version. Please use bulk_write_to_db",
@@ -2378,7 +2401,7 @@ class DAG(LoggingMixin):
@classmethod
@provide_session
- def bulk_write_to_db(cls, dags: Collection["DAG"], session=None):
+ def bulk_write_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION):
"""
Ensure the DagModel rows for the given dags are up-to-date in the dag table in the DB, including
calculated fields.
@@ -2491,7 +2514,7 @@ class DAG(LoggingMixin):
cls.bulk_write_to_db(dag.subdags, session=session)
@provide_session
- def sync_to_db(self, session=None):
+ def sync_to_db(self, session=NEW_SESSION):
"""
Save attributes about this DAG to the DB. Note that this method
can be called for both DAGs and SubDAGs. A SubDag is actually a
@@ -2510,7 +2533,7 @@ class DAG(LoggingMixin):
@staticmethod
@provide_session
- def deactivate_unknown_dags(active_dag_ids, session=None):
+ def deactivate_unknown_dags(active_dag_ids, session=NEW_SESSION):
"""
Given a list of known DAGs, deactivate any other DAGs that are
marked as active in the ORM
@@ -2528,7 +2551,7 @@ class DAG(LoggingMixin):
@staticmethod
@provide_session
- def deactivate_stale_dags(expiration_date, session=None):
+ def deactivate_stale_dags(expiration_date, session=NEW_SESSION):
"""
Deactivate any DAGs that were last touched by the scheduler before
the expiration date. These DAGs were likely deleted.
@@ -2554,7 +2577,7 @@ class DAG(LoggingMixin):
@staticmethod
@provide_session
- def get_num_task_instances(dag_id, task_ids=None, states=None, session=None):
+ def get_num_task_instances(dag_id, task_ids=None, states=None, session=NEW_SESSION):
"""
Returns the number of task instances in the given DAG.
@@ -2604,7 +2627,6 @@ class DAG(LoggingMixin):
'params',
'_pickle_id',
'_log',
- 'is_subdag',
'task_dict',
'template_searchpath',
'sla_miss_callback',
@@ -2621,13 +2643,14 @@ class DAG(LoggingMixin):
def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType:
"""
Returns edge information for the given pair of tasks if present, and
- None if there is no information.
+ an empty edge if there is no information.
"""
# Note - older serialized DAGs may not have edge_info being a dict at all
+ empty = cast(EdgeInfoType, {})
if self.edge_info:
- return self.edge_info.get(upstream_task_id, {}).get(downstream_task_id, {})
+ return self.edge_info.get(upstream_task_id, {}).get(downstream_task_id, empty)
else:
- return {}
+ return empty
def set_edge_info(self, upstream_task_id: str, downstream_task_id: str, info: EdgeInfoType):
"""
@@ -2778,23 +2801,23 @@ class DagModel(Base):
@staticmethod
@provide_session
- def get_dagmodel(dag_id, session=None):
+ def get_dagmodel(dag_id, session=NEW_SESSION):
return session.query(DagModel).options(joinedload(DagModel.parent_dag)).get(dag_id)
@classmethod
@provide_session
- def get_current(cls, dag_id, session=None):
+ def get_current(cls, dag_id, session=NEW_SESSION):
return session.query(cls).filter(cls.dag_id == dag_id).first()
@provide_session
- def get_last_dagrun(self, session=None, include_externally_triggered=False):
+ def get_last_dagrun(self, session=NEW_SESSION, include_externally_triggered=False):
return get_last_dagrun(
self.dag_id, session=session, include_externally_triggered=include_externally_triggered
)
@staticmethod
@provide_session
- def get_paused_dag_ids(dag_ids: List[str], session: Session = None) -> Set[str]:
+ def get_paused_dag_ids(dag_ids: List[str], session: Session = NEW_SESSION) -> Set[str]:
"""
Given a list of dag_ids, get a set of Paused Dag Ids
@@ -2837,7 +2860,7 @@ class DagModel(Base):
return path
@provide_session
- def set_is_paused(self, is_paused: bool, including_subdags: bool = True, session=None) -> None:
+ def set_is_paused(self, is_paused: bool, including_subdags: bool = True, session=NEW_SESSION) -> None:
"""
Pause/Un-pause a DAG.
@@ -2857,7 +2880,7 @@ class DagModel(Base):
@classmethod
@provide_session
- def deactivate_deleted_dags(cls, alive_dag_filelocs: List[str], session=None):
+ def deactivate_deleted_dags(cls, alive_dag_filelocs: List[str], session=NEW_SESSION):
"""
Set ``is_active=False`` on the DAGs for which the DAG files have been removed.
@@ -2913,6 +2936,7 @@ class DagModel(Base):
:param most_recent_dag_run: DataInterval (or datetime) of most recent run of this dag, or none
if not yet scheduled.
"""
+ most_recent_data_interval: Optional[DataInterval]
if isinstance(most_recent_dag_run, datetime):
warnings.warn(
"Passing a datetime to `DagModel.calculate_dagrun_date_fields` is deprecated. "
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index 9f5a135..2c737e2 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -218,8 +218,8 @@ class DagBag(LoggingMixin):
root_dag_id = dag_id
if dag_id in self.dags:
dag = self.dags[dag_id]
- if dag.is_subdag:
- root_dag_id = dag.parent_dag.dag_id # type: ignore
+ if dag.parent_dag:
+ root_dag_id = dag.parent_dag.dag_id
# If DAG Model is absent, we can't check last_expired property. Is the DAG not yet synchronized?
orm_dag = DagModel.get_current(root_dag_id, session=session)
@@ -234,7 +234,7 @@ class DagBag(LoggingMixin):
self.dags = {
key: dag
for key, dag in self.dags.items()
- if root_dag_id != key and not (dag.is_subdag and root_dag_id == dag.parent_dag.dag_id)
+ if root_dag_id != key and not (dag.parent_dag and root_dag_id == dag.parent_dag.dag_id)
}
if is_missing or is_expired:
# Reprocess source file.
@@ -397,7 +397,6 @@ class DagBag(LoggingMixin):
for (dag, mod) in top_level_dags:
dag.fileloc = mod.__file__
try:
- dag.is_subdag = False
dag.timetable.validate()
self.bag_dag(dag=dag, root_dag=dag)
found_dags.append(dag)
@@ -451,7 +450,6 @@ class DagBag(LoggingMixin):
for subdag in subdags:
subdag.fileloc = dag.fileloc
subdag.parent_dag = dag
- subdag.is_subdag = True
self._bag_dag(dag=subdag, root_dag=root_dag, recursive=False)
prev_dag = self.dags.get(dag.dag_id)
@@ -572,7 +570,7 @@ class DagBag(LoggingMixin):
return report
@provide_session
- def sync_to_db(self, session: Optional[Session] = None):
+ def sync_to_db(self, session: Session = None):
"""Save attributes about list of DAG to the DB."""
# To avoid circular import - airflow.models.dagbag -> airflow.models.dag -> airflow.models.dagbag
from airflow.models.dag import DAG
@@ -628,7 +626,7 @@ class DagBag(LoggingMixin):
self.import_errors.update(dict(serialize_errors))
@provide_session
- def _sync_perm_for_dag(self, dag, session: Optional[Session] = None):
+ def _sync_perm_for_dag(self, dag, session: Session = None):
"""Sync DAG specific permissions, if necessary"""
from airflow.security.permissions import DAG_ACTIONS, resource_name_for_dag
from airflow.www.fab_security.sqla.models import Action, Permission, Resource
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 20ec7cd..f9ced3a 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import os
import warnings
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
@@ -48,10 +49,10 @@ from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES
from airflow.utils import callback_requests, timezone
from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, skip_locked, with_row_locks
from airflow.utils.state import DagRunState, State, TaskInstanceState
-from airflow.utils.types import DagRunType
+from airflow.utils.types import NOTSET, ArgNotSet, DagRunType
if TYPE_CHECKING:
from airflow.models.dag import DAG
@@ -75,8 +76,6 @@ class DagRun(Base, LoggingMixin):
__tablename__ = "dag_run"
- __NO_VALUE = object()
-
id = Column(Integer, primary_key=True)
dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
queued_at = Column(UtcDateTime)
@@ -96,7 +95,11 @@ class DagRun(Base, LoggingMixin):
last_scheduling_decision = Column(UtcDateTime)
dag_hash = Column(String(32))
- dag = None
+ # Remove this `if` after upgrading Sphinx-AutoAPI
+ if not TYPE_CHECKING and "BUILDING_AIRFLOW_DOCS" in os.environ:
+ dag: "Optional[DAG]"
+ else:
+ dag: "Optional[DAG]" = None
__table_args__ = (
Index('dag_id_state', dag_id, _state),
@@ -138,7 +141,7 @@ class DagRun(Base, LoggingMixin):
self,
dag_id: Optional[str] = None,
run_id: Optional[str] = None,
- queued_at: Optional[datetime] = __NO_VALUE,
+ queued_at: Union[datetime, None, ArgNotSet] = NOTSET, # type: ignore
execution_date: Optional[datetime] = None,
start_date: Optional[datetime] = None,
external_trigger: Optional[bool] = None,
@@ -163,7 +166,7 @@ class DagRun(Base, LoggingMixin):
self.conf = conf or {}
if state is not None:
self.state = state
- if queued_at is self.__NO_VALUE:
+ if queued_at is NOTSET:
self.queued_at = timezone.utcnow() if state == State.QUEUED else None
else:
self.queued_at = queued_at
@@ -203,7 +206,7 @@ class DagRun(Base, LoggingMixin):
return synonym('_state', descriptor=property(self.get_state, self.set_state))
@provide_session
- def refresh_from_db(self, session: Session = None):
+ def refresh_from_db(self, session: Session = NEW_SESSION) -> None:
"""
Reloads the current dagrun from the database
@@ -299,7 +302,7 @@ class DagRun(Base, LoggingMixin):
external_trigger: Optional[bool] = None,
no_backfills: bool = False,
run_type: Optional[DagRunType] = None,
- session: Session = None,
+ session: Session = NEW_SESSION,
execution_start_date: Optional[datetime] = None,
execution_end_date: Optional[datetime] = None,
) -> List["DagRun"]:
@@ -363,7 +366,7 @@ class DagRun(Base, LoggingMixin):
dag_id: str,
run_id: str,
execution_date: datetime,
- session: Session = None,
+ session: Session = NEW_SESSION,
) -> Optional['DagRun']:
"""
Return an existing run for the DAG with a specific run_id or execution_date.
@@ -412,7 +415,7 @@ class DagRun(Base, LoggingMixin):
tis = tis.filter(TI.state == state)
else:
# this is required to deal with NULL values
- if None in state:
+ if TaskInstanceState.NONE in state:
if all(x is None for x in state):
tis = tis.filter(TI.state.is_(None))
else:
@@ -426,7 +429,7 @@ class DagRun(Base, LoggingMixin):
return tis.all()
@provide_session
- def get_task_instance(self, task_id: str, session: Session = None) -> Optional[TI]:
+ def get_task_instance(self, task_id: str, session: Session = NEW_SESSION) -> Optional[TI]:
"""
Returns the task instance specified by task_id for this dag run
@@ -454,7 +457,7 @@ class DagRun(Base, LoggingMixin):
@provide_session
def get_previous_dagrun(
- self, state: Optional[DagRunState] = None, session: Session = None
+ self, state: Optional[DagRunState] = None, session: Session = NEW_SESSION
) -> Optional['DagRun']:
"""The previous DagRun, if there is one"""
filters = [
@@ -466,7 +469,7 @@ class DagRun(Base, LoggingMixin):
return session.query(DagRun).filter(*filters).order_by(DagRun.execution_date.desc()).first()
@provide_session
- def get_previous_scheduled_dagrun(self, session: Session = None) -> Optional['DagRun']:
+ def get_previous_scheduled_dagrun(self, session: Session = NEW_SESSION) -> Optional['DagRun']:
"""The previous, SCHEDULED DagRun, if there is one"""
return (
session.query(DagRun)
@@ -481,7 +484,7 @@ class DagRun(Base, LoggingMixin):
@provide_session
def update_state(
- self, session: Session = None, execute_callbacks: bool = True
+ self, session: Session = NEW_SESSION, execute_callbacks: bool = True
) -> Tuple[List[TI], Optional[callback_requests.DagCallbackRequest]]:
"""
Determines the overall state of the DagRun based on the state
@@ -528,7 +531,7 @@ class DagRun(Base, LoggingMixin):
# if all roots finished and at least one failed, the run failed
if not unfinished_tasks and any(leaf_ti.state in State.failed_states for leaf_ti in leaf_tis):
self.log.error('Marking run %s failed', self)
- self.set_state(State.FAILED)
+ self.set_state(DagRunState.FAILED)
if execute_callbacks:
dag.handle_callback(self, success=False, reason='task_failure', session=session)
elif dag.has_on_failure_callback:
@@ -543,7 +546,7 @@ class DagRun(Base, LoggingMixin):
# if all leaves succeeded and no unfinished tasks, the run succeeded
elif not unfinished_tasks and all(leaf_ti.state in State.success_states for leaf_ti in leaf_tis):
self.log.info('Marking run %s successful', self)
- self.set_state(State.SUCCESS)
+ self.set_state(DagRunState.SUCCESS)
if execute_callbacks:
dag.handle_callback(self, success=True, reason='success', session=session)
elif dag.has_on_success_callback:
@@ -564,7 +567,7 @@ class DagRun(Base, LoggingMixin):
and not are_runnable_tasks
):
self.log.error('Deadlock; marking run %s failed', self)
- self.set_state(State.FAILED)
+ self.set_state(DagRunState.FAILED)
if execute_callbacks:
dag.handle_callback(self, success=False, reason='all_tasks_deadlocked', session=session)
elif dag.has_on_failure_callback:
@@ -578,9 +581,9 @@ class DagRun(Base, LoggingMixin):
# finally, if the roots aren't done, the dag is still running
else:
- self.set_state(State.RUNNING)
+ self.set_state(DagRunState.RUNNING)
- if self._state == State.FAILED or self._state == State.SUCCESS:
+ if self._state == DagRunState.FAILED or self._state == DagRunState.SUCCESS:
msg = (
"DagRun Finished: dag_id=%s, execution_date=%s, run_id=%s, "
"run_start_date=%s, run_end_date=%s, run_duration=%s, "
@@ -613,7 +616,7 @@ class DagRun(Base, LoggingMixin):
return schedulable_tis, callback
@provide_session
- def task_instance_scheduling_decisions(self, session: Session = None) -> TISchedulingDecision:
+ def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) -> TISchedulingDecision:
schedulable_tis: List[TI] = []
changed_tis = False
@@ -759,7 +762,7 @@ class DagRun(Base, LoggingMixin):
Stats.timing(f'dagrun.duration.failed.{self.dag_id}', duration)
@provide_session
- def verify_integrity(self, session: Session = None):
+ def verify_integrity(self, session: Session = NEW_SESSION):
"""
Verifies the DagRun by checking for removed tasks or tasks that are not in the
database yet. It will set state to removed or add the task if required.
@@ -869,7 +872,7 @@ class DagRun(Base, LoggingMixin):
)
@provide_session
- def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = None) -> int:
+ def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = NEW_SESSION) -> int:
"""
Set the given task instances in to the scheduled state.
diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py
index a9e359e..0a68587 100644
--- a/airflow/models/serialized_dag.py
+++ b/airflow/models/serialized_dag.py
@@ -73,7 +73,7 @@ class SerializedDagModel(Base):
dag_runs = relationship(
DagRun,
- primaryjoin=dag_id == foreign(DagRun.dag_id),
+ primaryjoin=dag_id == foreign(DagRun.dag_id), # type: ignore
backref=backref('serialized_dag', uselist=False, innerjoin=True),
)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 2638739..8d7bd36 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -27,7 +27,20 @@ from collections import defaultdict
from datetime import datetime, timedelta
from functools import partial
from tempfile import NamedTemporaryFile
-from typing import IO, TYPE_CHECKING, Any, Iterable, List, NamedTuple, Optional, Tuple, Union
+from typing import (
+ IO,
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ NamedTuple,
+ Optional,
+ Set,
+ Tuple,
+ Union,
+)
from urllib.parse import quote
import dill
@@ -86,13 +99,12 @@ from airflow.typing_compat import Literal
from airflow.utils import timezone
from airflow.utils.context import ConnectionAccessor, Context, VariableAccessor
from airflow.utils.email import send_email
-from airflow.utils.helpers import is_container
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.operator_helpers import context_to_airflow_vars
from airflow.utils.platform import getuser
from airflow.utils.retries import run_with_db_retries
-from airflow.utils.session import create_session, provide_session
+from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime
from airflow.utils.state import DagRunState, State
from airflow.utils.timeout import timeout
@@ -117,7 +129,7 @@ if TYPE_CHECKING:
@contextlib.contextmanager
-def set_current_context(context: Context) -> None:
+def set_current_context(context: Context) -> Iterator[Context]:
"""
Sets the current execution context to the provided context object.
This method should be called once per Task execution, before calling operator.execute.
@@ -179,7 +191,9 @@ def clear_task_instances(
:param activate_dag_runs: Deprecated parameter, do not pass
"""
job_ids = []
- task_id_by_key = defaultdict(lambda: defaultdict(lambda: defaultdict(set)))
+ task_id_by_key: Dict[str, Dict[str, Dict[int, Set[str]]]] = defaultdict(
+ lambda: defaultdict(lambda: defaultdict(set))
+ )
for ti in tis:
if ti.state == State.RUNNING:
if ti.job_id:
@@ -404,7 +418,11 @@ class TaskInstance(Base, LoggingMixin):
execution_date = association_proxy("dag_run", "execution_date")
def __init__(
- self, task, execution_date: Optional[datetime] = None, run_id: str = None, state: Optional[str] = None
+ self,
+ task: "BaseOperator",
+ execution_date: Optional[datetime] = None,
+ run_id: Optional[str] = None,
+ state: Optional[str] = None,
):
super().__init__()
self.dag_id = task.dag_id
@@ -562,7 +580,7 @@ class TaskInstance(Base, LoggingMixin):
def generate_command(
dag_id: str,
task_id: str,
- run_id: str = None,
+ run_id: str,
mark_success: bool = False,
ignore_all_deps: bool = False,
ignore_depends_on_past: bool = False,
@@ -666,7 +684,7 @@ class TaskInstance(Base, LoggingMixin):
)
@provide_session
- def current_state(self, session=None) -> str:
+ def current_state(self, session=NEW_SESSION) -> str:
"""
Get the very latest state from the database, if a session is passed,
we use and looking up the state becomes part of the session, otherwise
@@ -691,7 +709,7 @@ class TaskInstance(Base, LoggingMixin):
return state
@provide_session
- def error(self, session=None):
+ def error(self, session=NEW_SESSION):
"""
Forces the task instance's state to FAILED in the database.
@@ -704,7 +722,7 @@ class TaskInstance(Base, LoggingMixin):
session.commit()
@provide_session
- def refresh_from_db(self, session=None, lock_for_update=False) -> None:
+ def refresh_from_db(self, session=NEW_SESSION, lock_for_update=False) -> None:
"""
Refreshes the task instance from the database based on the primary key
@@ -760,7 +778,7 @@ class TaskInstance(Base, LoggingMixin):
self.log.debug("Refreshed TaskInstance %s", self)
- def refresh_from_task(self, task, pool_override=None):
+ def refresh_from_task(self, task: "BaseOperator", pool_override=None):
"""
Copy common attributes from the given task.
@@ -780,7 +798,7 @@ class TaskInstance(Base, LoggingMixin):
self.operator = task.task_type
@provide_session
- def clear_xcom_data(self, session=None):
+ def clear_xcom_data(self, session=NEW_SESSION):
"""
Clears all XCom data from the database for the task instance
@@ -802,7 +820,7 @@ class TaskInstance(Base, LoggingMixin):
return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number)
@provide_session
- def set_state(self, state: str, session=None):
+ def set_state(self, state: str, session=NEW_SESSION):
"""
Set TaskInstance state.
@@ -830,7 +848,7 @@ class TaskInstance(Base, LoggingMixin):
return self.state == State.UP_FOR_RETRY and not self.ready_for_retry()
@provide_session
- def are_dependents_done(self, session=None):
+ def are_dependents_done(self, session=NEW_SESSION):
"""
Checks whether the immediate dependents of this task instance have succeeded or have been skipped.
This is meant to be used by wait_for_downstream.
@@ -880,7 +898,7 @@ class TaskInstance(Base, LoggingMixin):
# XXX: This uses DAG internals, but as the outer comment
# said, the block is only reached for legacy reasons for
# development code, so that's OK-ish.
- schedule = dag.timetable._schedule
+ schedule = dag.timetable._schedule # type: ignore
except AttributeError:
return None
dt = pendulum.instance(self.execution_date)
@@ -908,7 +926,7 @@ class TaskInstance(Base, LoggingMixin):
@provide_session
def get_previous_ti(
- self, state: Optional[str] = None, session: Session = None
+ self, state: Optional[str] = None, session: Session = NEW_SESSION
) -> Optional['TaskInstance']:
"""
The task instance for the task that ran before this task instance.
@@ -957,7 +975,7 @@ class TaskInstance(Base, LoggingMixin):
def get_previous_execution_date(
self,
state: Optional[str] = None,
- session: Session = None,
+ session: Session = NEW_SESSION,
) -> Optional[pendulum.DateTime]:
"""
The execution date from property previous_ti_success.
@@ -971,7 +989,7 @@ class TaskInstance(Base, LoggingMixin):
@provide_session
def get_previous_start_date(
- self, state: Optional[str] = None, session: Session = None
+ self, state: Optional[str] = None, session: Session = NEW_SESSION
) -> Optional[pendulum.DateTime]:
"""
The start date from property previous_ti_success.
@@ -1001,7 +1019,7 @@ class TaskInstance(Base, LoggingMixin):
return self.get_previous_start_date(state=State.SUCCESS)
@provide_session
- def are_dependencies_met(self, dep_context=None, session=None, verbose=False):
+ def are_dependencies_met(self, dep_context=None, session=NEW_SESSION, verbose=False):
"""
Returns whether or not all the conditions are met for this task instance to be run
given the context for the dependencies (e.g. a task instance being force run from
@@ -1036,7 +1054,7 @@ class TaskInstance(Base, LoggingMixin):
return True
@provide_session
- def get_failed_dep_statuses(self, dep_context=None, session=None):
+ def get_failed_dep_statuses(self, dep_context=None, session=NEW_SESSION):
"""Get failed Dependencies"""
dep_context = dep_context or DepContext()
for dep in dep_context.deps | self.task.deps:
@@ -1103,7 +1121,7 @@ class TaskInstance(Base, LoggingMixin):
return self.state == State.UP_FOR_RETRY and self.next_retry_datetime() < timezone.utcnow()
@provide_session
- def get_dagrun(self, session: Session = None):
+ def get_dagrun(self, session: Session = NEW_SESSION):
"""
Returns the DagRun for this TaskInstance
@@ -1136,7 +1154,7 @@ class TaskInstance(Base, LoggingMixin):
job_id: Optional[str] = None,
pool: Optional[str] = None,
external_executor_id: Optional[str] = None,
- session=None,
+ session=NEW_SESSION,
) -> bool:
"""
Checks dependencies and then sets state to RUNNING if they are met. Returns
@@ -1295,7 +1313,7 @@ class TaskInstance(Base, LoggingMixin):
job_id: Optional[str] = None,
pool: Optional[str] = None,
error_file: Optional[str] = None,
- session=None,
+ session=NEW_SESSION,
) -> None:
"""
Immediately runs the task (without checking or changing db state
@@ -1462,7 +1480,7 @@ class TaskInstance(Base, LoggingMixin):
Stats.incr('ti_successes')
@provide_session
- def _update_ti_state_for_sensing(self, session=None):
+ def _update_ti_state_for_sensing(self, session=NEW_SESSION):
self.log.info('Submitting %s to sensor service', self)
self.state = State.SENSING
self.start_date = timezone.utcnow()
@@ -1606,7 +1624,7 @@ class TaskInstance(Base, LoggingMixin):
test_mode: bool = False,
job_id: Optional[str] = None,
pool: Optional[str] = None,
- session=None,
+ session=NEW_SESSION,
) -> None:
"""Run TaskInstance"""
res = self.check_and_change_state_before_execution(
@@ -1649,7 +1667,9 @@ class TaskInstance(Base, LoggingMixin):
task_copy.dry_run()
@provide_session
- def _handle_reschedule(self, actual_start_date, reschedule_exception, test_mode=False, session=None):
+ def _handle_reschedule(
+ self, actual_start_date, reschedule_exception, test_mode=False, session=NEW_SESSION
+ ):
# Don't record reschedule request in test mode
if test_mode:
return
@@ -1690,7 +1710,7 @@ class TaskInstance(Base, LoggingMixin):
test_mode: Optional[bool] = None,
force_fail: bool = False,
error_file: Optional[str] = None,
- session=None,
+ session=NEW_SESSION,
) -> None:
"""Handle Failure for the TaskInstance"""
if test_mode is None:
@@ -1761,7 +1781,7 @@ class TaskInstance(Base, LoggingMixin):
error: Union[str, Exception],
test_mode: Optional[bool] = None,
force_fail: bool = False,
- session=None,
+ session=NEW_SESSION,
) -> None:
self.handle_failure(error=error, test_mode=test_mode, force_fail=force_fail, session=session)
self._run_finished_callback(error=error)
@@ -1775,7 +1795,9 @@ class TaskInstance(Base, LoggingMixin):
return self.task.retries and self.try_number <= self.max_tries
- def get_template_context(self, session: Session = None, ignore_param_exceptions: bool = True) -> Context:
+ def get_template_context(
+ self, session: Session = NEW_SESSION, ignore_param_exceptions: bool = True
+ ) -> Context:
"""Return TI Context"""
# Do not use provide_session here -- it expunges everything on exit!
if not session:
@@ -1798,7 +1820,7 @@ class TaskInstance(Base, LoggingMixin):
params.update(task.params)
if conf.getboolean('core', 'dag_run_conf_overrides_params'):
self.overwrite_params_with_dag_run_conf(params=params, dag_run=dag_run)
- task.params = params.validate()
+ validated_params = task.params = params.validate()
logical_date = timezone.coerce_datetime(self.execution_date)
ds = logical_date.strftime('%Y-%m-%d')
@@ -1914,7 +1936,7 @@ class TaskInstance(Base, LoggingMixin):
'next_ds_nodash': get_next_ds_nodash(),
'next_execution_date': get_next_execution_date(),
'outlets': task.outlets,
- 'params': task.params,
+ 'params': validated_params,
'prev_data_interval_start_success': get_prev_data_interval_start_success(),
'prev_data_interval_end_success': get_prev_data_interval_end_success(),
'prev_ds': get_prev_ds(),
@@ -1947,7 +1969,7 @@ class TaskInstance(Base, LoggingMixin):
return Context(context)
@provide_session
- def get_rendered_template_fields(self, session=None):
+ def get_rendered_template_fields(self, session=NEW_SESSION):
"""Fetch rendered template fields from DB"""
from airflow.models.renderedtifields import RenderedTaskInstanceFields
@@ -1967,7 +1989,7 @@ class TaskInstance(Base, LoggingMixin):
) from e
@provide_session
- def get_rendered_k8s_spec(self, session=None):
+ def get_rendered_k8s_spec(self, session=NEW_SESSION):
"""Fetch rendered template fields from DB"""
from airflow.models.renderedtifields import RenderedTaskInstanceFields
@@ -2006,7 +2028,7 @@ class TaskInstance(Base, LoggingMixin):
date=self.execution_date,
args=self.command_as_list(),
pod_override_object=PodGenerator.from_obj(self.executor_config),
- scheduler_job_id="worker-config",
+ scheduler_job_id=0,
namespace=kube_config.executor_namespace,
base_worker_pod=PodGenerator.deserialize_model_file(kube_config.pod_template_file),
)
@@ -2109,7 +2131,7 @@ class TaskInstance(Base, LoggingMixin):
key: str,
value: Any,
execution_date: Optional[datetime] = None,
- session: Session = None,
+ session: Session = NEW_SESSION,
) -> None:
"""
Make an XCom available for tasks to pull.
@@ -2149,7 +2171,7 @@ class TaskInstance(Base, LoggingMixin):
dag_id: Optional[str] = None,
key: str = XCOM_RETURN_KEY,
include_prior_dates: bool = False,
- session: Session = None,
+ session: Session = NEW_SESSION,
) -> Any:
"""
Pull XComs that optionally meet certain criteria.
@@ -2199,7 +2221,11 @@ class TaskInstance(Base, LoggingMixin):
# Since we're only fetching the values field, and not the
# whole class, the @recreate annotation does not kick in.
# Therefore we need to deserialize the fields by ourselves.
- if is_container(task_ids):
+ if task_ids is None or isinstance(task_ids, str):
+ xcom = query.with_entities(XCom.value).first()
+ if xcom:
+ return XCom.deserialize_value(xcom)
+ else:
vals_kv = {
result.task_id: XCom.deserialize_value(result)
for result in query.with_entities(XCom.task_id, XCom.value)
@@ -2207,10 +2233,6 @@ class TaskInstance(Base, LoggingMixin):
values_ordered_by_id = [vals_kv.get(task_id) for task_id in task_ids]
return values_ordered_by_id
- else:
- xcom = query.with_entities(XCom.value).first()
- if xcom:
- return XCom.deserialize_value(xcom)
@provide_session
def get_num_running_task_instances(self, session):
@@ -2261,7 +2283,7 @@ class TaskInstance(Base, LoggingMixin):
TaskInstance.task_id == first_task_id,
)
- if settings.Session.bind.dialect.name == 'mssql':
+ if settings.engine.dialect.name == 'mssql':
return or_(
and_(
TaskInstance.dag_id == ti.dag_id,
@@ -2291,7 +2313,7 @@ class SimpleTaskInstance:
def __init__(self, ti: TaskInstance):
self._dag_id: str = ti.dag_id
self._task_id: str = ti.task_id
- self._run_id: datetime = ti.run_id
+ self._run_id: str = ti.run_id
self._start_date: datetime = ti.start_date
self._end_date: datetime = ti.end_date
self._try_number: int = ti.try_number
diff --git a/airflow/models/variable.py b/airflow/models/variable.py
index 00edb55..b5c3921 100644
--- a/airflow/models/variable.py
+++ b/airflow/models/variable.py
@@ -153,7 +153,7 @@ class Variable(Base, LoggingMixin):
cls,
key: str,
value: Any,
- description: str = None,
+ description: Optional[str] = None,
serialize_json: bool = False,
session: Session = None,
):
diff --git a/airflow/operators/subdag.py b/airflow/operators/subdag.py
index 35223cb..599284c 100644
--- a/airflow/operators/subdag.py
+++ b/airflow/operators/subdag.py
@@ -34,7 +34,7 @@ from airflow.models.dag import DAG, DagContext
from airflow.models.pool import Pool
from airflow.models.taskinstance import TaskInstance
from airflow.sensors.base import BaseSensorOperator
-from airflow.utils.session import create_session, provide_session
+from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.state import State
from airflow.utils.types import DagRunType
@@ -69,12 +69,14 @@ class SubDagOperator(BaseSensorOperator):
ui_color = '#555'
ui_fgcolor = '#fff'
+ subdag: "DAG"
+
@provide_session
def __init__(
self,
*,
subdag: DAG,
- session: Optional[Session] = None,
+ session: Session = NEW_SESSION,
conf: Optional[Dict] = None,
propagate_skipped_state: Optional[SkippedStatePropagationOptions] = None,
**kwargs,
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 2559f1e..07437e0 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -918,7 +918,6 @@ class SerializedDAG(DAG, BaseSerialization):
if serializable_task.subdag is not None:
setattr(serializable_task.subdag, 'parent_dag', dag)
- serializable_task.subdag.is_subdag = True
for task_id in serializable_task.downstream_task_ids:
# Bypass set_upstream etc here - it does more than we want
diff --git a/airflow/settings.py b/airflow/settings.py
index 9cfed37..a4b76d5 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -79,7 +79,7 @@ LOGGING_CLASS_PATH: Optional[str] = None
DONOT_MODIFY_HANDLERS: Optional[bool] = None
DAGS_FOLDER: str = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))
-engine: Optional[Engine] = None
+engine: Engine
Session: Callable[..., SASession]
# The JSON library to use for DAG Serialization and De-Serialization
@@ -378,6 +378,8 @@ def configure_adapters():
def validate_session():
"""Validate ORM Session"""
+ global engine
+
worker_precheck = conf.getboolean('celery', 'worker_precheck', fallback=False)
if not worker_precheck:
return True
diff --git a/airflow/timetables/base.py b/airflow/timetables/base.py
index e97f253..850bc47 100644
--- a/airflow/timetables/base.py
+++ b/airflow/timetables/base.py
@@ -33,7 +33,7 @@ class DataInterval(NamedTuple):
end: DateTime
@classmethod
- def exact(cls, at: DateTime) -> "DagRunInfo":
+ def exact(cls, at: DateTime) -> "DataInterval":
"""Represent an "interval" containing only an exact time."""
return cls(start=at, end=at)
diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi
index 0921d79..1249112 100644
--- a/airflow/utils/context.pyi
+++ b/airflow/utils/context.pyi
@@ -41,6 +41,13 @@ class _VariableAccessors(TypedDict):
json: Any
value: Any
+class VariableAccessor:
+ def __init__(self, *, deserialize_json: bool) -> None: ...
+ def get(self, key, default: Any = ...) -> Any: ...
+
+class ConnectionAccessor:
+ def get(self, key: str, default_conn: Any = None) -> Any: ...
+
class Context(TypedDict, total=False):
conf: AirflowConfigParser
conn: Any
diff --git a/airflow/utils/file.py b/airflow/utils/file.py
index a940a60..a7f45e9 100644
--- a/airflow/utils/file.py
+++ b/airflow/utils/file.py
@@ -21,7 +21,7 @@ import os
import re
import zipfile
from pathlib import Path
-from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Pattern, Union
+from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Pattern, Union, overload
from airflow.configuration import conf
@@ -68,7 +68,17 @@ def mkdirs(path, mode):
ZIP_REGEX = re.compile(fr'((.*\.zip){re.escape(os.sep)})?(.*)')
-def correct_maybe_zipped(fileloc):
+@overload
+def correct_maybe_zipped(fileloc: None) -> None:
+ ...
+
+
+@overload
+def correct_maybe_zipped(fileloc: Union[str, Path]) -> Union[str, Path]:
+ ...
+
+
+def correct_maybe_zipped(fileloc: Union[None, str, Path]) -> Union[None, str, Path]:
"""
If the path contains a folder with a .zip suffix, then
the folder is treated as a zip archive and path to zip is returned.
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index 745cdcd..a3e9c96 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -32,6 +32,7 @@ class TaskInstanceState(str, Enum):
# Set by the scheduler
# None - Task is created but should not run yet
+ NONE = None
REMOVED = "removed" # Task vanished from DAG before it ran
SCHEDULED = "scheduled" # Task should run and will be handed to executor soon
diff --git a/airflow/utils/timezone.py b/airflow/utils/timezone.py
index e5245d9..1051ee1 100644
--- a/airflow/utils/timezone.py
+++ b/airflow/utils/timezone.py
@@ -17,7 +17,7 @@
# under the License.
#
import datetime as dt
-from typing import Optional, Union
+from typing import TYPE_CHECKING, Optional, overload
import pendulum
from pendulum.datetime import DateTime
@@ -27,6 +27,9 @@ from airflow.settings import TIMEZONE
# UTC time zone as a tzinfo instance.
utc = pendulum.tz.timezone('UTC')
+if TYPE_CHECKING:
+ from pendulum.tz.timezone import Timezone
+
def is_localized(value):
"""
@@ -97,7 +100,17 @@ def convert_to_utc(value):
return value.astimezone(utc)
-def make_aware(value, timezone=None):
+@overload
+def make_aware(v: None, timezone: Optional["Timezone"] = None) -> None:
+ ...
+
+
+@overload
+def make_aware(v: dt.datetime, timezone: Optional["Timezone"] = None) -> dt.datetime:
+ ...
+
+
+def make_aware(value: Optional[dt.datetime], timezone: Optional["Timezone"] = None) -> Optional[dt.datetime]:
"""
Make a naive datetime.datetime in a given time zone aware.
@@ -175,7 +188,17 @@ def parse(string: str, timezone=None) -> DateTime:
return pendulum.parse(string, tz=timezone or TIMEZONE, strict=False) # type: ignore
-def coerce_datetime(v: Union[None, dt.datetime, DateTime]) -> Optional[DateTime]:
+@overload
+def coerce_datetime(v: None) -> None:
+ ...
+
+
+@overload
+def coerce_datetime(v: dt.datetime) -> DateTime:
+ ...
+
+
+def coerce_datetime(v: Optional[dt.datetime]) -> Optional[DateTime]:
"""Convert whatever is passed in to an timezone-aware ``pendulum.DateTime``."""
if v is None:
return None
diff --git a/airflow/utils/types.py b/airflow/utils/types.py
index 9f3c559..04688a7 100644
--- a/airflow/utils/types.py
+++ b/airflow/utils/types.py
@@ -20,6 +20,25 @@ from typing import Optional
from airflow.typing_compat import TypedDict
+class ArgNotSet:
+ """Sentinel type for annotations, useful when None is not viable.
+
+ Use like this::
+
+ def is_arg_passed(arg: Union[ArgNotSet, None] = NOTSET) -> bool:
+ if arg is NOTSET:
+ return False
+ return True
+
+ is_arg_passed() # False.
+ is_arg_passed(None) # True.
+ """
+
+
+NOTSET = ArgNotSet()
+"""Sentinel value for argument default. See ``ArgNotSet``."""
+
+
class DagRunType(str, enum.Enum):
"""Class with DagRun types"""
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 4921e1e..2ec863d 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -866,7 +866,6 @@ class TestDag(unittest.TestCase):
)
# parent_dag and is_subdag were set by DagBag. We don't use DagBag, so this value is not set.
subdag.parent_dag = dag
- subdag.is_subdag = True
SubDagOperator(task_id='subtask', owner='owner2', subdag=subdag)
session = settings.Session()
dag.sync_to_db(session=session)
@@ -932,7 +931,6 @@ class TestDag(unittest.TestCase):
# parent_dag and is_subdag were set by DagBag. We don't use DagBag, so this value is not set.
subdag.parent_dag = dag
- subdag.is_subdag = True
session.query(DagModel).filter(DagModel.dag_id.in_([subdag_id, dag_id])).delete(
synchronize_session=False
@@ -1427,7 +1425,6 @@ class TestDag(unittest.TestCase):
SubDagOperator(task_id='test', subdag=subdag, dag=dag)
t_2 = DummyOperator(task_id='task', dag=subdag)
subdag.parent_dag = dag
- subdag.is_subdag = True
dag.sync_to_db()
@@ -1806,7 +1803,6 @@ class TestDag(unittest.TestCase):
subdag = section_1.subdag
# parent_dag and is_subdag were set by DagBag. We don't use DagBag, so this value is not set.
subdag.parent_dag = dag
- subdag.is_subdag = True
next_parent_info = dag.next_dagrun_info(None)
assert next_parent_info.logical_date == timezone.datetime(2019, 1, 1, 0, 0)
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index b967e33..6dc2d4c 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -357,7 +357,7 @@ class TestTaskInstance:
test that try to create a task with pool_slots less than 1
"""
- with pytest.raises(AirflowException):
+ with pytest.raises(ValueError, match="pool slots .* cannot be less than 1"):
dag = models.DAG(dag_id='test_run_pooling_task')
DummyOperator(
task_id='test_run_pooling_task_op',
@@ -1926,7 +1926,7 @@ class TestTaskInstance:
'try_number': '1',
},
'labels': {
- 'airflow-worker': 'worker-config',
+ 'airflow-worker': '0',
'airflow_version': version,
'dag_id': 'test_render_k8s_pod_yaml',
'execution_date': '2016-01-01T00_00_00_plus_00_00',
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 577a1df..82c848a 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1094,9 +1094,8 @@ class TestStringifiedDAGs:
"""
base_operator = BaseOperator(task_id="10")
fields = base_operator.__dict__
- assert {
+ assert fields == {
'_BaseOperator__instantiated': True,
- '_dag': None,
'_downstream_task_ids': set(),
'_inlets': [],
'_log': base_operator.log,
@@ -1139,12 +1138,11 @@ class TestStringifiedDAGs:
'run_as_user': None,
'sla': None,
'start_date': None,
- 'subdag': None,
'task_id': '10',
'trigger_rule': 'all_success',
'wait_for_downstream': False,
'weight_rule': 'downstream',
- } == fields, """
+ }, """
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
ACTION NEEDED! PLEASE READ THIS CAREFULLY AND CORRECT TESTS CAREFULLY