Posted to commits@airflow.apache.org by je...@apache.org on 2022/01/28 21:25:24 UTC

[airflow] 11/17: Helper for provide_session-decorated functions (#20104)

This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit d5721619bc970e42abdc474362fcf9c151bc33a8
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Tue Dec 7 21:50:34 2021 +0800

    Helper for provide_session-decorated functions (#20104)
    
    * Helper for provide_session-decorated functions
    
    * Apply NEW_SESSION trick on XCom
    
    (cherry picked from commit a80ac1eecc0ea187de7984510b4ef6f981b97196)
---
 airflow/models/xcom.py   | 24 ++++++++++++------------
 airflow/settings.py      | 10 ++++++----
 airflow/utils/session.py | 11 +++++++++--
 3 files changed, 27 insertions(+), 18 deletions(-)
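
For context, provide_session wraps a function so that a real SQLAlchemy
session is created, committed, and closed around the call whenever the
caller does not supply one; only an explicitly passed session bypasses
that. The sketch below is illustrative only, not the exact Airflow
implementation (the in-memory engine and SessionFactory name are
assumptions for the example), but it shows the pattern the NEW_SESSION
sentinel introduced by this commit is designed to pair with:

    import contextlib
    from functools import wraps
    from inspect import signature
    from typing import Callable, Iterator, TypeVar, cast

    from sqlalchemy import create_engine
    from sqlalchemy.orm import Session, sessionmaker

    engine = create_engine("sqlite://")  # illustrative in-memory database
    SessionFactory = sessionmaker(bind=engine)

    RT = TypeVar("RT")


    @contextlib.contextmanager
    def create_session() -> Iterator[Session]:
        """Create a session, commit on success, roll back on failure."""
        session = SessionFactory()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()


    def provide_session(func: Callable[..., RT]) -> Callable[..., RT]:
        """Inject a fresh session unless the caller already passed one."""
        session_pos = list(signature(func).parameters).index("session")

        @wraps(func)
        def wrapper(*args, **kwargs) -> RT:
            # A session passed positionally or by keyword is used as-is;
            # otherwise a new one is created for the duration of the call.
            if "session" in kwargs or session_pos < len(args):
                return func(*args, **kwargs)
            with create_session() as session:
                return func(*args, session=session, **kwargs)

        return wrapper


    # The sentinel this commit adds: typed as Session for the benefit of the
    # function body, but never a real session at runtime.
    NEW_SESSION: Session = cast(Session, None)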

diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index 4bb9689..5efaa0a 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -32,7 +32,7 @@ from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
 from airflow.utils import timezone
 from airflow.utils.helpers import is_container
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import UtcDateTime
 
 log = logging.getLogger(__name__)
@@ -90,7 +90,7 @@ class BaseXCom(Base, LoggingMixin):
         dag_id: str,
         task_id: str,
         run_id: str,
-        session: Optional[Session] = None,
+        session: Session = NEW_SESSION,
     ) -> None:
         """Store an XCom value.
 
@@ -116,7 +116,7 @@ class BaseXCom(Base, LoggingMixin):
         task_id: str,
         dag_id: str,
         execution_date: datetime.datetime,
-        session: Optional[Session] = None,
+        session: Session = NEW_SESSION,
     ) -> None:
         """:sphinx-autoapi-skip:"""
 
@@ -129,7 +129,7 @@ class BaseXCom(Base, LoggingMixin):
         task_id: str,
         dag_id: str,
         execution_date: Optional[datetime.datetime] = None,
-        session: Session = None,
+        session: Session = NEW_SESSION,
         *,
         run_id: Optional[str] = None,
     ) -> None:
@@ -170,7 +170,7 @@ class BaseXCom(Base, LoggingMixin):
         task_id: Optional[str] = None,
         dag_id: Optional[str] = None,
         include_prior_dates: bool = False,
-        session: Optional[Session] = None,
+        session: Session = NEW_SESSION,
     ) -> Optional[Any]:
         """Retrieve an XCom value, optionally meeting certain criteria.
 
@@ -207,7 +207,7 @@ class BaseXCom(Base, LoggingMixin):
         task_id: Optional[str] = None,
         dag_id: Optional[str] = None,
         include_prior_dates: bool = False,
-        session: Optional[Session] = None,
+        session: Session = NEW_SESSION,
     ) -> Optional[Any]:
         """:sphinx-autoapi-skip:"""
 
@@ -220,7 +220,7 @@ class BaseXCom(Base, LoggingMixin):
         task_id: Optional[Union[str, Iterable[str]]] = None,
         dag_id: Optional[Union[str, Iterable[str]]] = None,
         include_prior_dates: bool = False,
-        session: Session = None,
+        session: Session = NEW_SESSION,
         *,
         run_id: Optional[str] = None,
     ) -> Optional[Any]:
@@ -265,7 +265,7 @@ class BaseXCom(Base, LoggingMixin):
         dag_ids: Union[str, Iterable[str], None] = None,
         include_prior_dates: bool = False,
         limit: Optional[int] = None,
-        session: Optional[Session] = None,
+        session: Session = NEW_SESSION,
     ) -> Query:
         """Composes a query to get one or more XCom entries.
 
@@ -300,7 +300,7 @@ class BaseXCom(Base, LoggingMixin):
         dag_ids: Union[str, Iterable[str], None] = None,
         include_prior_dates: bool = False,
         limit: Optional[int] = None,
-        session: Optional[Session] = None,
+        session: Session = NEW_SESSION,
     ) -> Query:
         """:sphinx-autoapi-skip:"""
 
@@ -314,7 +314,7 @@ class BaseXCom(Base, LoggingMixin):
         dag_ids: Optional[Union[str, Iterable[str]]] = None,
         include_prior_dates: bool = False,
         limit: Optional[int] = None,
-        session: Session = None,
+        session: Session = NEW_SESSION,
         *,
         run_id: Optional[str] = None,
     ) -> Query:
@@ -397,7 +397,7 @@ class BaseXCom(Base, LoggingMixin):
         execution_date: pendulum.DateTime,
         dag_id: str,
         task_id: str,
-        session: Optional[Session] = None,
+        session: Session = NEW_SESSION,
     ) -> None:
         """:sphinx-autoapi-skip:"""
 
@@ -409,7 +409,7 @@ class BaseXCom(Base, LoggingMixin):
         dag_id: Optional[str] = None,
         task_id: Optional[str] = None,
         run_id: Optional[str] = None,
-        session: Session = None,
+        session: Session = NEW_SESSION,
     ) -> None:
         """:sphinx-autoapi-skip:"""
         # Given the historic order of this function (execution_date was first argument) to add a new optional
diff --git a/airflow/settings.py b/airflow/settings.py
index f9b97a2..139d6a4 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -22,7 +22,7 @@ import logging
 import os
 import sys
 import warnings
-from typing import Optional
+from typing import TYPE_CHECKING, Callable, List, Optional
 
 import pendulum
 import sqlalchemy
@@ -37,6 +37,9 @@ from airflow.executors import executor_constants
 from airflow.logging_config import configure_logging
 from airflow.utils.orm_event_handlers import setup_event_handlers
 
+if TYPE_CHECKING:
+    from airflow.www.utils import UIAlert
+
 log = logging.getLogger(__name__)
 
 
@@ -77,7 +80,7 @@ DONOT_MODIFY_HANDLERS: Optional[bool] = None
 DAGS_FOLDER: str = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))
 
 engine: Optional[Engine] = None
-Session: Optional[SASession] = None
+Session: Callable[..., SASession]
 
 # The JSON library to use for DAG Serialization and De-Serialization
 json = json
@@ -563,8 +566,7 @@ MASK_SECRETS_IN_LOGS = False
 #       UIAlert('Visit <a href="http://airflow.apache.org">airflow.apache.org</a>', html=True),
 #   ]
 #
-# DASHBOARD_UIALERTS: List["UIAlert"]
-DASHBOARD_UIALERTS = []
+DASHBOARD_UIALERTS: List["UIAlert"] = []
 
 # Prefix used to identify tables holding data moved during migration.
 AIRFLOW_MOVED_TABLE_PREFIX = "_airflow_moved"
diff --git a/airflow/utils/session.py b/airflow/utils/session.py
index 9636fc4..f0c3168 100644
--- a/airflow/utils/session.py
+++ b/airflow/utils/session.py
@@ -18,7 +18,7 @@
 import contextlib
 from functools import wraps
 from inspect import signature
-from typing import Callable, Iterator, TypeVar
+from typing import Callable, Iterator, TypeVar, cast
 
 from airflow import settings
 
@@ -26,7 +26,7 @@ from airflow import settings
 @contextlib.contextmanager
 def create_session() -> Iterator[settings.SASession]:
     """Contextmanager that will create and teardown a session."""
-    session: settings.SASession = settings.Session()
+    session = settings.Session()
     try:
         yield session
         session.commit()
@@ -105,3 +105,10 @@ def create_global_lock(session=None, pg_lock_id=1, lock_name='init', mysql_lock_
         if dialect.name == 'mssql':
             # TODO: make locking works for MSSQL
             pass
+
+
+# A fake session to use in functions decorated by provide_session. This allows
+# the 'session' argument to be of type Session instead of Optional[Session],
+# making it easier to type hint the function body without dealing with the None
+# case that can never happen at runtime.
+NEW_SESSION: settings.SASession = cast(settings.SASession, None)
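
With the sentinel in place, a provide_session-decorated function can
annotate 'session' as a plain Session and use it unconditionally, instead
of declaring Optional[Session] and guarding against a None that the
decorator never actually lets through. A short illustration follows; the
function name and query are hypothetical, not part of this commit:

    from sqlalchemy.orm import Session

    from airflow.models.xcom import XCom
    from airflow.utils.session import NEW_SESSION, create_session, provide_session


    @provide_session
    def count_xcoms_for_dag(dag_id: str, session: Session = NEW_SESSION) -> int:
        # 'session' is a real Session here: the decorator substitutes one
        # whenever the caller omits the argument, so no None check is needed.
        return session.query(XCom).filter(XCom.dag_id == dag_id).count()


    # Callers can rely on the decorator to manage the session...
    total = count_xcoms_for_dag("example_dag")

    # ...or pass an existing session to join an ongoing transaction.
    with create_session() as session:
        total = count_xcoms_for_dag("example_dag", session=session)

Note that NEW_SESSION is cast from None, so it is only meaningful as a
default value; passing it explicitly would hand the function body a None
and defeat the point of the sentinel.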