You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by je...@apache.org on 2022/01/28 21:25:23 UTC
[airflow] 10/17: Type-annotate SkipMixin and BaseXCom (#20011)
This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 0cc934ce12c36cbcf2572dc50675b9da77859eb9
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Tue Dec 7 17:55:00 2021 +0800
Type-annotate SkipMixin and BaseXCom (#20011)
(cherry picked from commit 6dd0a0df7e6a2f025e9234bdbf97b41e9b8f6257)
---
airflow/models/skipmixin.py | 15 +-
airflow/models/xcom.py | 335 ++++++++++++++++++++++++++++++--------------
2 files changed, 232 insertions(+), 118 deletions(-)
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index 5cd50a3..765a947 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -17,7 +17,7 @@
# under the License.
import warnings
-from typing import TYPE_CHECKING, Iterable, Union
+from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
@@ -26,6 +26,7 @@ from airflow.utils.session import create_session, provide_session
from airflow.utils.state import State
if TYPE_CHECKING:
+ from pendulum import DateTime
from sqlalchemy import Session
from airflow.models import DagRun
@@ -66,9 +67,9 @@ class SkipMixin(LoggingMixin):
def skip(
self,
dag_run: "DagRun",
- execution_date: "timezone.DateTime",
- tasks: "Iterable[BaseOperator]",
- session: "Session" = None,
+ execution_date: "DateTime",
+ tasks: Sequence["BaseOperator"],
+ session: "Session",
):
"""
Sets tasks instances to skipped from the same dag run.
@@ -114,11 +115,7 @@ class SkipMixin(LoggingMixin):
session.commit()
# SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
- try:
- task_id = self.task_id
- except AttributeError:
- task_id = None
-
+ task_id: Optional[str] = getattr(self, "task_id", None)
if task_id is not None:
from airflow.models.xcom import XCom
diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index 99c2b9a..4bb9689 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -16,10 +16,11 @@
# specific language governing permissions and limitations
# under the License.
+import datetime
import json
import logging
import pickle
-from typing import Any, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Any, Iterable, Optional, Type, Union, cast, overload
import pendulum
from sqlalchemy import Column, LargeBinary, String
@@ -79,14 +80,60 @@ class BaseXCom(Base, LoggingMixin):
def __repr__(self):
return f'<XCom "{self.key}" ({self.task_id} @ {self.execution_date})>'
+ @overload
@classmethod
- @provide_session
- def set(cls, key, value, task_id, dag_id, execution_date=None, run_id=None, session=None):
+ def set(
+ cls,
+ key: str,
+ value: Any,
+ *,
+ dag_id: str,
+ task_id: str,
+ run_id: str,
+ session: Optional[Session] = None,
+ ) -> None:
+ """Store an XCom value.
+
+ A deprecated form of this function accepts ``execution_date`` instead of
+ ``run_id``. The two arguments are mutually exclusive.
+
+ :param key: Key to store the XCom.
+ :param value: XCom value to store.
+ :param dag_id: DAG ID.
+ :param task_id: Task ID.
+ :param run_id: DAG run ID for the task.
+ :param session: Database session. If not given, a new session will be
+ created for this function.
+ :type session: sqlalchemy.orm.session.Session
"""
- Store an XCom value.
- :return: None
- """
+ @overload
+ @classmethod
+ def set(
+ cls,
+ key: str,
+ value: Any,
+ task_id: str,
+ dag_id: str,
+ execution_date: datetime.datetime,
+ session: Optional[Session] = None,
+ ) -> None:
+ """:sphinx-autoapi-skip:"""
+
+ @classmethod
+ @provide_session
+ def set(
+ cls,
+ key: str,
+ value: Any,
+ task_id: str,
+ dag_id: str,
+ execution_date: Optional[datetime.datetime] = None,
+ session: Session = None,
+ *,
+ run_id: Optional[str] = None,
+ ) -> None:
+ """:sphinx-autoapi-skip:"""
if not (execution_date is None) ^ (run_id is None):
raise ValueError("Exactly one of execution_date or run_id must be passed")
@@ -94,70 +141,95 @@ class BaseXCom(Base, LoggingMixin):
from airflow.models.dagrun import DagRun
dag_run = session.query(DagRun).filter_by(dag_id=dag_id, run_id=run_id).one()
-
execution_date = dag_run.execution_date
- value = XCom.serialize_value(value)
-
- # remove any duplicate XComs
+ # Remove duplicate XComs and insert a new one.
session.query(cls).filter(
- cls.key == key, cls.execution_date == execution_date, cls.task_id == task_id, cls.dag_id == dag_id
+ cls.key == key,
+ cls.execution_date == execution_date,
+ cls.task_id == task_id,
+ cls.dag_id == dag_id,
).delete()
-
+ new = cast(Any, cls)( # Work around Mypy complaining model not defining '__init__'.
+ key=key,
+ value=cls.serialize_value(value),
+ execution_date=execution_date,
+ task_id=task_id,
+ dag_id=dag_id,
+ )
+ session.add(new)
session.flush()
- # insert new XCom
- session.add(XCom(key=key, value=value, execution_date=execution_date, task_id=task_id, dag_id=dag_id))
+ @overload
+ @classmethod
+ def get_one(
+ cls,
+ *,
+ run_id: str,
+ key: Optional[str] = None,
+ task_id: Optional[str] = None,
+ dag_id: Optional[str] = None,
+ include_prior_dates: bool = False,
+ session: Optional[Session] = None,
+ ) -> Optional[Any]:
+ """Retrieve an XCom value, optionally meeting certain criteria.
+
+ This method returns "full" XCom values (i.e. uses ``deserialize_value``
+ from the XCom backend). Use :meth:`get_many` if you want the "shortened"
+ value via ``orm_deserialize_value``.
+
+ If there are no results, *None* is returned.
+
+ A deprecated form of this function accepts ``execution_date`` instead of
+ ``run_id``. The two arguments are mutually exclusive.
+
+ :param run_id: DAG run ID for the task.
+ :param key: A key for the XCom. If provided, only XCom with matching
+ keys will be returned. Pass *None* (default) to remove the filter.
+ :param task_id: Only XCom from task with matching ID will be pulled.
+ Pass *None* (default) to remove the filter.
+ :param dag_id: Only pull XCom from this DAG. If *None* (default), the
+ DAG of the calling task is used.
+ :param include_prior_dates: If *False* (default), only XCom from the
+ specified DAG run is returned. If *True*, the latest matching XCom is
+ returned regardless of the run it belongs to.
+ :param session: Database session. If not given, a new session will be
+ created for this function.
+ :type session: sqlalchemy.orm.session.Session
+ """
- session.flush()
+ @overload
+ @classmethod
+ def get_one(
+ cls,
+ execution_date: pendulum.DateTime,
+ key: Optional[str] = None,
+ task_id: Optional[str] = None,
+ dag_id: Optional[str] = None,
+ include_prior_dates: bool = False,
+ session: Optional[Session] = None,
+ ) -> Optional[Any]:
+ """:sphinx-autoapi-skip:"""
@classmethod
@provide_session
def get_one(
cls,
execution_date: Optional[pendulum.DateTime] = None,
- run_id: Optional[str] = None,
key: Optional[str] = None,
task_id: Optional[Union[str, Iterable[str]]] = None,
dag_id: Optional[Union[str, Iterable[str]]] = None,
include_prior_dates: bool = False,
session: Session = None,
+ *,
+ run_id: Optional[str] = None,
) -> Optional[Any]:
- """
- Retrieve an XCom value, optionally meeting certain criteria. Returns None
- of there are no results.
-
- ``run_id`` and ``execution_date`` are mutually exclusive.
-
- This method returns "full" XCom values (i.e. it uses ``deserialize_value`` from the XCom backend).
- Please use :meth:`get_many` if you want the "shortened" value via ``orm_deserialize_value``
-
- :param execution_date: Execution date for the task
- :type execution_date: pendulum.datetime
- :param run_id: Dag run id for the task
- :type run_id: str
- :param key: A key for the XCom. If provided, only XComs with matching
- keys will be returned. To remove the filter, pass key=None.
- :type key: str
- :param task_id: Only XComs from task with matching id will be
- pulled. Can pass None to remove the filter.
- :type task_id: str
- :param dag_id: If provided, only pulls XCom from this DAG.
- If None (default), the DAG of the calling task is used.
- :type dag_id: str
- :param include_prior_dates: If False, only XCom from the current
- execution_date are returned. If True, XCom from previous dates
- are returned as well.
- :type include_prior_dates: bool
- :param session: database session
- :type session: sqlalchemy.orm.session.Session
- """
+ """:sphinx-autoapi-skip:"""
if not (execution_date is None) ^ (run_id is None):
raise ValueError("Exactly one of execution_date or run_id must be passed")
- result = (
- cls.get_many(
- execution_date=execution_date,
+ if run_id is not None:
+ query = cls.get_many(
run_id=run_id,
key=key,
task_ids=task_id,
@@ -165,58 +237,88 @@ class BaseXCom(Base, LoggingMixin):
include_prior_dates=include_prior_dates,
session=session,
)
- .with_entities(cls.value)
- .first()
- )
+ elif execution_date is not None:
+ query = cls.get_many(
+ execution_date=execution_date,
+ key=key,
+ task_ids=task_id,
+ dag_ids=dag_id,
+ include_prior_dates=include_prior_dates,
+ session=session,
+ )
+ else:
+ raise RuntimeError("Should not happen?")
+
+ result = query.with_entities(cls.value).first()
if result:
return cls.deserialize_value(result)
return None
+ @overload
+ @classmethod
+ def get_many(
+ cls,
+ *,
+ run_id: str,
+ key: Optional[str] = None,
+ task_ids: Union[str, Iterable[str], None] = None,
+ dag_ids: Union[str, Iterable[str], None] = None,
+ include_prior_dates: bool = False,
+ limit: Optional[int] = None,
+ session: Optional[Session] = None,
+ ) -> Query:
+ """Composes a query to get one or more XCom entries.
+
+ This function returns an SQLAlchemy query of full XCom objects. If you
+ just want one stored value, use :meth:`get_one` instead.
+
+ A deprecated form of this function accepts ``execution_date`` instead of
+ ``run_id``. The two arguments are mutually exclusive.
+
+ :param run_id: DAG run ID for the task.
+ :param key: A key for the XComs. If provided, only XComs with matching
+ keys will be returned. Pass *None* (default) to remove the filter.
+ :param task_ids: Only XComs from task with matching IDs will be pulled.
+ Pass *None* (default) to remove the filter.
+ :param dag_ids: Only pulls XComs from this DAG. If *None* (default), the
+ DAG of the calling task is used.
+ :param include_prior_dates: If *False* (default), only XComs from the
+ specified DAG run are returned. If *True*, all matching XComs are
+ returned regardless of the run they belong to.
+ :param session: Database session. If not given, a new session will be
+ created for this function.
+ :type session: sqlalchemy.orm.session.Session
+ """
+
+ @overload
+ @classmethod
+ def get_many(
+ cls,
+ execution_date: pendulum.DateTime,
+ key: Optional[str] = None,
+ task_ids: Union[str, Iterable[str], None] = None,
+ dag_ids: Union[str, Iterable[str], None] = None,
+ include_prior_dates: bool = False,
+ limit: Optional[int] = None,
+ session: Optional[Session] = None,
+ ) -> Query:
+ """:sphinx-autoapi-skip:"""
+
@classmethod
@provide_session
def get_many(
cls,
execution_date: Optional[pendulum.DateTime] = None,
- run_id: Optional[str] = None,
key: Optional[str] = None,
task_ids: Optional[Union[str, Iterable[str]]] = None,
dag_ids: Optional[Union[str, Iterable[str]]] = None,
include_prior_dates: bool = False,
limit: Optional[int] = None,
session: Session = None,
+ *,
+ run_id: Optional[str] = None,
) -> Query:
- """
- Composes a query to get one or more values from the xcom table.
-
- ``run_id`` and ``execution_date`` are mutually exclusive.
-
- This function returns an SQLAlchemy query of full XCom objects. If you just want one stored value then
- use :meth:`get_one`.
-
- :param execution_date: Execution date for the task
- :type execution_date: pendulum.datetime
- :param run_id: Dag run id for the task
- :type run_id: str
- :param key: A key for the XCom. If provided, only XComs with matching
- keys will be returned. To remove the filter, pass key=None.
- :type key: str
- :param task_ids: Only XComs from tasks with matching ids will be
- pulled. Can pass None to remove the filter.
- :type task_ids: str or iterable of strings (representing task_ids)
- :param dag_ids: If provided, only pulls XComs from this DAG.
- If None (default), the DAG of the calling task is used.
- :type dag_ids: str
- :param include_prior_dates: If False, only XComs from the current
- execution_date are returned. If True, XComs from previous dates
- are returned as well.
- :type include_prior_dates: bool
- :param limit: If required, limit the number of returned objects.
- XCom objects can be quite big and you might want to limit the
- number of rows.
- :type limit: int
- :param session: database session
- :type session: sqlalchemy.orm.session.Session
- """
+ """:sphinx-autoapi-skip:"""
if not (execution_date is None) ^ (run_id is None):
raise ValueError("Exactly one of execution_date or run_id must be passed")
@@ -262,8 +364,8 @@ class BaseXCom(Base, LoggingMixin):
@classmethod
@provide_session
- def delete(cls, xcoms, session=None):
- """Delete Xcom"""
+ def delete(cls, xcoms: Union["XCom", Iterable["XCom"]], session: Session) -> None:
+ """Delete one or multiple XCom entries."""
if isinstance(xcoms, XCom):
xcoms = [xcoms]
for xcom in xcoms:
@@ -272,37 +374,49 @@ class BaseXCom(Base, LoggingMixin):
session.delete(xcom)
session.commit()
+ @overload
+ @classmethod
+ def clear(cls, *, dag_id: str, task_id: str, run_id: str, session: Optional[Session] = None) -> None:
+ """Clear all XCom data from the database for the given task instance.
+
+ A deprecated form of this function accepts ``execution_date`` instead of
+ ``run_id``. The two arguments are mutually exclusive.
+
+ :param dag_id: ID of DAG to clear the XCom for.
+ :param task_id: ID of task to clear the XCom for.
+ :param run_id: ID of DAG run to clear the XCom for.
+ :param session: Database session. If not given, a new session will be
+ created for this function.
+ :type session: sqlalchemy.orm.session.Session
+ """
+
+ @overload
+ @classmethod
+ def clear(
+ cls,
+ execution_date: pendulum.DateTime,
+ dag_id: str,
+ task_id: str,
+ session: Optional[Session] = None,
+ ) -> None:
+ """:sphinx-autoapi-skip:"""
+
@classmethod
@provide_session
def clear(
cls,
execution_date: Optional[pendulum.DateTime] = None,
- dag_id: str = None,
- task_id: str = None,
- run_id: str = None,
+ dag_id: Optional[str] = None,
+ task_id: Optional[str] = None,
+ run_id: Optional[str] = None,
session: Session = None,
) -> None:
- """
- Clears all XCom data from the database for the task instance
-
- ``run_id`` and ``execution_date`` are mutually exclusive.
-
- :param execution_date: Execution date for the task
- :type execution_date: pendulum.datetime or None
- :param dag_id: ID of DAG to clear the XCom for.
- :type dag_id: str
- :param task_id: Only XComs from task with matching id will be cleared.
- :type task_id: str
- :param run_id: Dag run id for the task
- :type run_id: str or None
- :param session: database session
- :type session: sqlalchemy.orm.session.Session
- """
+ """:sphinx-autoapi-skip:"""
# Given the historic order of this function (execution_date was first argument) to add a new optional
# param we need to add default values for everything :(
- if not dag_id:
+ if dag_id is None:
raise TypeError("clear() missing required argument: dag_id")
- if not task_id:
+ if task_id is None:
raise TypeError("clear() missing required argument: task_id")
if not (execution_date is None) ^ (run_id is None):
@@ -364,7 +478,7 @@ class BaseXCom(Base, LoggingMixin):
return BaseXCom.deserialize_value(self)
-def resolve_xcom_backend():
+def resolve_xcom_backend() -> Type[BaseXCom]:
"""Resolves custom XCom class"""
clazz = conf.getimport("core", "xcom_backend", fallback=f"airflow.models.xcom.{BaseXCom.__name__}")
if clazz:
@@ -376,4 +490,7 @@ def resolve_xcom_backend():
return BaseXCom
-XCom = resolve_xcom_backend()
+if TYPE_CHECKING:
+ XCom = BaseXCom # Hack to avoid Mypy "Variable 'XCom' is not valid as a type".
+else:
+ XCom = resolve_xcom_backend()