You are viewing a plain-text version of this content. The canonical link to it (originally attached to the word "here") was not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by je...@apache.org on 2022/01/28 21:25:23 UTC

[airflow] 10/17: Type-annotate SkipMixin and BaseXCom (#20011)

This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 0cc934ce12c36cbcf2572dc50675b9da77859eb9
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Tue Dec 7 17:55:00 2021 +0800

    Type-annotate SkipMixin and BaseXCom (#20011)
    
    (cherry picked from commit 6dd0a0df7e6a2f025e9234bdbf97b41e9b8f6257)
---
 airflow/models/skipmixin.py |  15 +-
 airflow/models/xcom.py      | 335 ++++++++++++++++++++++++++++++--------------
 2 files changed, 232 insertions(+), 118 deletions(-)

diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index 5cd50a3..765a947 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -17,7 +17,7 @@
 # under the License.
 
 import warnings
-from typing import TYPE_CHECKING, Iterable, Union
+from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union
 
 from airflow.models.taskinstance import TaskInstance
 from airflow.utils import timezone
@@ -26,6 +26,7 @@ from airflow.utils.session import create_session, provide_session
 from airflow.utils.state import State
 
 if TYPE_CHECKING:
+    from pendulum import DateTime
     from sqlalchemy import Session
 
     from airflow.models import DagRun
@@ -66,9 +67,9 @@ class SkipMixin(LoggingMixin):
     def skip(
         self,
         dag_run: "DagRun",
-        execution_date: "timezone.DateTime",
-        tasks: "Iterable[BaseOperator]",
-        session: "Session" = None,
+        execution_date: "DateTime",
+        tasks: Sequence["BaseOperator"],
+        session: "Session",
     ):
         """
         Sets tasks instances to skipped from the same dag run.
@@ -114,11 +115,7 @@ class SkipMixin(LoggingMixin):
         session.commit()
 
         # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
-        try:
-            task_id = self.task_id
-        except AttributeError:
-            task_id = None
-
+        task_id: Optional[str] = getattr(self, "task_id", None)
         if task_id is not None:
             from airflow.models.xcom import XCom
 
diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index 99c2b9a..4bb9689 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -16,10 +16,11 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import datetime
 import json
 import logging
 import pickle
-from typing import Any, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Any, Iterable, Optional, Type, Union, cast, overload
 
 import pendulum
 from sqlalchemy import Column, LargeBinary, String
@@ -79,14 +80,60 @@ class BaseXCom(Base, LoggingMixin):
     def __repr__(self):
         return f'<XCom "{self.key}" ({self.task_id} @ {self.execution_date})>'
 
+    @overload
     @classmethod
-    @provide_session
-    def set(cls, key, value, task_id, dag_id, execution_date=None, run_id=None, session=None):
+    def set(
+        cls,
+        key: str,
+        value: Any,
+        *,
+        dag_id: str,
+        task_id: str,
+        run_id: str,
+        session: Optional[Session] = None,
+    ) -> None:
+        """Store an XCom value.
+
+        A deprecated form of this function accepts ``execution_date`` instead of
+        ``run_id``. The two arguments are mutually exclusive.
+
+        :param key: Key to store the XCom.
+        :param value: XCom value to store.
+        :param dag_id: DAG ID.
+        :param task_id: Task ID.
+        :param run_id: DAG run ID for the task.
+        :param session: Database session. If not given, a new session will be
+            created for this function.
+        :type session: sqlalchemy.orm.session.Session
         """
-        Store an XCom value.
 
-        :return: None
-        """
+    @overload
+    @classmethod
+    def set(
+        cls,
+        key: str,
+        value: Any,
+        task_id: str,
+        dag_id: str,
+        execution_date: datetime.datetime,
+        session: Optional[Session] = None,
+    ) -> None:
+        """:sphinx-autoapi-skip:"""
+
+    @classmethod
+    @provide_session
+    def set(
+        cls,
+        key: str,
+        value: Any,
+        task_id: str,
+        dag_id: str,
+        execution_date: Optional[datetime.datetime] = None,
+        session: Session = None,
+        *,
+        run_id: Optional[str] = None,
+    ) -> None:
+        """:sphinx-autoapi-skip:"""
         if not (execution_date is None) ^ (run_id is None):
             raise ValueError("Exactly one of execution_date or run_id must be passed")
 
@@ -94,70 +141,95 @@ class BaseXCom(Base, LoggingMixin):
             from airflow.models.dagrun import DagRun
 
             dag_run = session.query(DagRun).filter_by(dag_id=dag_id, run_id=run_id).one()
-
             execution_date = dag_run.execution_date
 
-        value = XCom.serialize_value(value)
-
-        # remove any duplicate XComs
+        # Remove duplicate XComs and insert a new one.
         session.query(cls).filter(
-            cls.key == key, cls.execution_date == execution_date, cls.task_id == task_id, cls.dag_id == dag_id
+            cls.key == key,
+            cls.execution_date == execution_date,
+            cls.task_id == task_id,
+            cls.dag_id == dag_id,
         ).delete()
-
+        new = cast(Any, cls)(  # Work around Mypy complaining model not defining '__init__'.
+            key=key,
+            value=cls.serialize_value(value),
+            execution_date=execution_date,
+            task_id=task_id,
+            dag_id=dag_id,
+        )
+        session.add(new)
         session.flush()
 
-        # insert new XCom
-        session.add(XCom(key=key, value=value, execution_date=execution_date, task_id=task_id, dag_id=dag_id))
+    @overload
+    @classmethod
+    def get_one(
+        cls,
+        *,
+        run_id: str,
+        key: Optional[str] = None,
+        task_id: Optional[str] = None,
+        dag_id: Optional[str] = None,
+        include_prior_dates: bool = False,
+        session: Optional[Session] = None,
+    ) -> Optional[Any]:
+        """Retrieve an XCom value, optionally meeting certain criteria.
+
+        This method returns "full" XCom values (i.e. uses ``deserialize_value``
+        from the XCom backend). Use :meth:`get_many` if you want the "shortened"
+        value via ``orm_deserialize_value``.
+
+        If there are no results, *None* is returned.
+
+        A deprecated form of this function accepts ``execution_date`` instead of
+        ``run_id``. The two arguments are mutually exclusive.
+
+        :param run_id: DAG run ID for the task.
+        :param key: A key for the XCom. If provided, only XCom with matching
+            keys will be returned. Pass *None* (default) to remove the filter.
+        :param task_id: Only XCom from task with matching ID will be pulled.
+            Pass *None* (default) to remove the filter.
+        :param dag_id: Only pull XCom from this DAG. If *None* (default), the
+            DAG of the calling task is used.
+        :param include_prior_dates: If *False* (default), only XCom from the
+            specified DAG run is returned. If *True*, the latest matching XCom is
+            returned regardless of the run it belongs to.
+        :param session: Database session. If not given, a new session will be
+            created for this function.
+        :type session: sqlalchemy.orm.session.Session
+        """
 
-        session.flush()
+    @overload
+    @classmethod
+    def get_one(
+        cls,
+        execution_date: pendulum.DateTime,
+        key: Optional[str] = None,
+        task_id: Optional[str] = None,
+        dag_id: Optional[str] = None,
+        include_prior_dates: bool = False,
+        session: Optional[Session] = None,
+    ) -> Optional[Any]:
+        """:sphinx-autoapi-skip:"""
 
     @classmethod
     @provide_session
     def get_one(
         cls,
         execution_date: Optional[pendulum.DateTime] = None,
-        run_id: Optional[str] = None,
         key: Optional[str] = None,
         task_id: Optional[Union[str, Iterable[str]]] = None,
         dag_id: Optional[Union[str, Iterable[str]]] = None,
         include_prior_dates: bool = False,
         session: Session = None,
+        *,
+        run_id: Optional[str] = None,
     ) -> Optional[Any]:
-        """
-        Retrieve an XCom value, optionally meeting certain criteria. Returns None
-        of there are no results.
-
-        ``run_id`` and ``execution_date`` are mutually exclusive.
-
-        This method returns "full" XCom values (i.e. it uses ``deserialize_value`` from the XCom backend).
-        Please use :meth:`get_many` if you want the "shortened" value via ``orm_deserialize_value``
-
-        :param execution_date: Execution date for the task
-        :type execution_date: pendulum.datetime
-        :param run_id: Dag run id for the task
-        :type run_id: str
-        :param key: A key for the XCom. If provided, only XComs with matching
-            keys will be returned. To remove the filter, pass key=None.
-        :type key: str
-        :param task_id: Only XComs from task with matching id will be
-            pulled. Can pass None to remove the filter.
-        :type task_id: str
-        :param dag_id: If provided, only pulls XCom from this DAG.
-            If None (default), the DAG of the calling task is used.
-        :type dag_id: str
-        :param include_prior_dates: If False, only XCom from the current
-            execution_date are returned. If True, XCom from previous dates
-            are returned as well.
-        :type include_prior_dates: bool
-        :param session: database session
-        :type session: sqlalchemy.orm.session.Session
-        """
+        """:sphinx-autoapi-skip:"""
         if not (execution_date is None) ^ (run_id is None):
             raise ValueError("Exactly one of execution_date or run_id must be passed")
 
-        result = (
-            cls.get_many(
-                execution_date=execution_date,
+        if run_id is not None:
+            query = cls.get_many(
                 run_id=run_id,
                 key=key,
                 task_ids=task_id,
@@ -165,58 +237,88 @@ class BaseXCom(Base, LoggingMixin):
                 include_prior_dates=include_prior_dates,
                 session=session,
             )
-            .with_entities(cls.value)
-            .first()
-        )
+        elif execution_date is not None:
+            query = cls.get_many(
+                execution_date=execution_date,
+                key=key,
+                task_ids=task_id,
+                dag_ids=dag_id,
+                include_prior_dates=include_prior_dates,
+                session=session,
+            )
+        else:
+            raise RuntimeError("Should not happen?")
+
+        result = query.with_entities(cls.value).first()
         if result:
             return cls.deserialize_value(result)
         return None
 
+    @overload
+    @classmethod
+    def get_many(
+        cls,
+        *,
+        run_id: str,
+        key: Optional[str] = None,
+        task_ids: Union[str, Iterable[str], None] = None,
+        dag_ids: Union[str, Iterable[str], None] = None,
+        include_prior_dates: bool = False,
+        limit: Optional[int] = None,
+        session: Optional[Session] = None,
+    ) -> Query:
+        """Composes a query to get one or more XCom entries.
+
+        This function returns an SQLAlchemy query of full XCom objects. If you
+        just want one stored value, use :meth:`get_one` instead.
+
+        A deprecated form of this function accepts ``execution_date`` instead of
+        ``run_id``. The two arguments are mutually exclusive.
+
+        :param run_id: DAG run ID for the task.
+        :param key: A key for the XComs. If provided, only XComs with matching
+            keys will be returned. Pass *None* (default) to remove the filter.
+        :param task_ids: Only XComs from task with matching IDs will be pulled.
+            Pass *None* (default) to remove the filter.
+        :param dag_ids: Only pulls XComs from these DAGs. If *None* (default), the
+            DAG of the calling task is used.
+        :param include_prior_dates: If *False* (default), only XComs from the
+            specified DAG run are returned. If *True*, all matching XComs are
+            returned regardless of the run they belong to.
+        :param session: Database session. If not given, a new session will be
+            created for this function.
+        :type session: sqlalchemy.orm.session.Session
+        """
+
+    @overload
+    @classmethod
+    def get_many(
+        cls,
+        execution_date: pendulum.DateTime,
+        key: Optional[str] = None,
+        task_ids: Union[str, Iterable[str], None] = None,
+        dag_ids: Union[str, Iterable[str], None] = None,
+        include_prior_dates: bool = False,
+        limit: Optional[int] = None,
+        session: Optional[Session] = None,
+    ) -> Query:
+        """:sphinx-autoapi-skip:"""
+
     @classmethod
     @provide_session
     def get_many(
         cls,
         execution_date: Optional[pendulum.DateTime] = None,
-        run_id: Optional[str] = None,
         key: Optional[str] = None,
         task_ids: Optional[Union[str, Iterable[str]]] = None,
         dag_ids: Optional[Union[str, Iterable[str]]] = None,
         include_prior_dates: bool = False,
         limit: Optional[int] = None,
         session: Session = None,
+        *,
+        run_id: Optional[str] = None,
     ) -> Query:
-        """
-        Composes a query to get one or more values from the xcom table.
-
-        ``run_id`` and ``execution_date`` are mutually exclusive.
-
-        This function returns an SQLAlchemy query of full XCom objects. If you just want one stored value then
-        use :meth:`get_one`.
-
-        :param execution_date: Execution date for the task
-        :type execution_date: pendulum.datetime
-        :param run_id: Dag run id for the task
-        :type run_id: str
-        :param key: A key for the XCom. If provided, only XComs with matching
-            keys will be returned. To remove the filter, pass key=None.
-        :type key: str
-        :param task_ids: Only XComs from tasks with matching ids will be
-            pulled. Can pass None to remove the filter.
-        :type task_ids: str or iterable of strings (representing task_ids)
-        :param dag_ids: If provided, only pulls XComs from this DAG.
-            If None (default), the DAG of the calling task is used.
-        :type dag_ids: str
-        :param include_prior_dates: If False, only XComs from the current
-            execution_date are returned. If True, XComs from previous dates
-            are returned as well.
-        :type include_prior_dates: bool
-        :param limit: If required, limit the number of returned objects.
-            XCom objects can be quite big and you might want to limit the
-            number of rows.
-        :type limit: int
-        :param session: database session
-        :type session: sqlalchemy.orm.session.Session
-        """
+        """:sphinx-autoapi-skip:"""
         if not (execution_date is None) ^ (run_id is None):
             raise ValueError("Exactly one of execution_date or run_id must be passed")
 
@@ -262,8 +364,8 @@ class BaseXCom(Base, LoggingMixin):
 
     @classmethod
     @provide_session
-    def delete(cls, xcoms, session=None):
-        """Delete Xcom"""
+    def delete(cls, xcoms: Union["XCom", Iterable["XCom"]], session: Session) -> None:
+        """Delete one or multiple XCom entries."""
         if isinstance(xcoms, XCom):
             xcoms = [xcoms]
         for xcom in xcoms:
@@ -272,37 +374,49 @@ class BaseXCom(Base, LoggingMixin):
             session.delete(xcom)
         session.commit()
 
+    @overload
+    @classmethod
+    def clear(cls, *, dag_id: str, task_id: str, run_id: str, session: Optional[Session] = None) -> None:
+        """Clear all XCom data from the database for the given task instance.
+
+        A deprecated form of this function accepts ``execution_date`` instead of
+        ``run_id``. The two arguments are mutually exclusive.
+
+        :param dag_id: ID of DAG to clear the XCom for.
+        :param task_id: ID of task to clear the XCom for.
+        :param run_id: ID of DAG run to clear the XCom for.
+        :param session: Database session. If not given, a new session will be
+            created for this function.
+        :type session: sqlalchemy.orm.session.Session
+        """
+
+    @overload
+    @classmethod
+    def clear(
+        cls,
+        execution_date: pendulum.DateTime,
+        dag_id: str,
+        task_id: str,
+        session: Optional[Session] = None,
+    ) -> None:
+        """:sphinx-autoapi-skip:"""
+
     @classmethod
     @provide_session
     def clear(
         cls,
         execution_date: Optional[pendulum.DateTime] = None,
-        dag_id: str = None,
-        task_id: str = None,
-        run_id: str = None,
+        dag_id: Optional[str] = None,
+        task_id: Optional[str] = None,
+        run_id: Optional[str] = None,
         session: Session = None,
     ) -> None:
-        """
-        Clears all XCom data from the database for the task instance
-
-        ``run_id`` and ``execution_date`` are mutually exclusive.
-
-        :param execution_date: Execution date for the task
-        :type execution_date: pendulum.datetime or None
-        :param dag_id: ID of DAG to clear the XCom for.
-        :type dag_id: str
-        :param task_id: Only XComs from task with matching id will be cleared.
-        :type task_id: str
-        :param run_id: Dag run id for the task
-        :type run_id: str or None
-        :param session: database session
-        :type session: sqlalchemy.orm.session.Session
-        """
+        """:sphinx-autoapi-skip:"""
         # Given the historic order of this function (execution_date was first argument) to add a new optional
         # param we need to add default values for everything :(
-        if not dag_id:
+        if dag_id is None:
             raise TypeError("clear() missing required argument: dag_id")
-        if not task_id:
+        if task_id is None:
             raise TypeError("clear() missing required argument: task_id")
 
         if not (execution_date is None) ^ (run_id is None):
@@ -364,7 +478,7 @@ class BaseXCom(Base, LoggingMixin):
         return BaseXCom.deserialize_value(self)
 
 
-def resolve_xcom_backend():
+def resolve_xcom_backend() -> Type[BaseXCom]:
     """Resolves custom XCom class"""
     clazz = conf.getimport("core", "xcom_backend", fallback=f"airflow.models.xcom.{BaseXCom.__name__}")
     if clazz:
@@ -376,4 +490,7 @@ def resolve_xcom_backend():
     return BaseXCom
 
 
-XCom = resolve_xcom_backend()
+if TYPE_CHECKING:
+    XCom = BaseXCom  # Hack to avoid Mypy "Variable 'XCom' is not valid as a type".
+else:
+    XCom = resolve_xcom_backend()