You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by je...@apache.org on 2022/01/28 21:25:24 UTC
[airflow] 11/17: Helper for provide_session-decorated functions (#20104)
This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit d5721619bc970e42abdc474362fcf9c151bc33a8
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Tue Dec 7 21:50:34 2021 +0800
Helper for provide_session-decorated functions (#20104)
* Helper for provide_session-decorated functions
* Apply NEW_SESSION trick on XCom
(cherry picked from commit a80ac1eecc0ea187de7984510b4ef6f981b97196)
---
airflow/models/xcom.py | 24 ++++++++++++------------
airflow/settings.py | 10 ++++++----
airflow/utils/session.py | 11 +++++++++--
3 files changed, 27 insertions(+), 18 deletions(-)
diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index 4bb9689..5efaa0a 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -32,7 +32,7 @@ from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
from airflow.utils import timezone
from airflow.utils.helpers import is_container
from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime
log = logging.getLogger(__name__)
@@ -90,7 +90,7 @@ class BaseXCom(Base, LoggingMixin):
dag_id: str,
task_id: str,
run_id: str,
- session: Optional[Session] = None,
+ session: Session = NEW_SESSION,
) -> None:
"""Store an XCom value.
@@ -116,7 +116,7 @@ class BaseXCom(Base, LoggingMixin):
task_id: str,
dag_id: str,
execution_date: datetime.datetime,
- session: Optional[Session] = None,
+ session: Session = NEW_SESSION,
) -> None:
""":sphinx-autoapi-skip:"""
@@ -129,7 +129,7 @@ class BaseXCom(Base, LoggingMixin):
task_id: str,
dag_id: str,
execution_date: Optional[datetime.datetime] = None,
- session: Session = None,
+ session: Session = NEW_SESSION,
*,
run_id: Optional[str] = None,
) -> None:
@@ -170,7 +170,7 @@ class BaseXCom(Base, LoggingMixin):
task_id: Optional[str] = None,
dag_id: Optional[str] = None,
include_prior_dates: bool = False,
- session: Optional[Session] = None,
+ session: Session = NEW_SESSION,
) -> Optional[Any]:
"""Retrieve an XCom value, optionally meeting certain criteria.
@@ -207,7 +207,7 @@ class BaseXCom(Base, LoggingMixin):
task_id: Optional[str] = None,
dag_id: Optional[str] = None,
include_prior_dates: bool = False,
- session: Optional[Session] = None,
+ session: Session = NEW_SESSION,
) -> Optional[Any]:
""":sphinx-autoapi-skip:"""
@@ -220,7 +220,7 @@ class BaseXCom(Base, LoggingMixin):
task_id: Optional[Union[str, Iterable[str]]] = None,
dag_id: Optional[Union[str, Iterable[str]]] = None,
include_prior_dates: bool = False,
- session: Session = None,
+ session: Session = NEW_SESSION,
*,
run_id: Optional[str] = None,
) -> Optional[Any]:
@@ -265,7 +265,7 @@ class BaseXCom(Base, LoggingMixin):
dag_ids: Union[str, Iterable[str], None] = None,
include_prior_dates: bool = False,
limit: Optional[int] = None,
- session: Optional[Session] = None,
+ session: Session = NEW_SESSION,
) -> Query:
"""Composes a query to get one or more XCom entries.
@@ -300,7 +300,7 @@ class BaseXCom(Base, LoggingMixin):
dag_ids: Union[str, Iterable[str], None] = None,
include_prior_dates: bool = False,
limit: Optional[int] = None,
- session: Optional[Session] = None,
+ session: Session = NEW_SESSION,
) -> Query:
""":sphinx-autoapi-skip:"""
@@ -314,7 +314,7 @@ class BaseXCom(Base, LoggingMixin):
dag_ids: Optional[Union[str, Iterable[str]]] = None,
include_prior_dates: bool = False,
limit: Optional[int] = None,
- session: Session = None,
+ session: Session = NEW_SESSION,
*,
run_id: Optional[str] = None,
) -> Query:
@@ -397,7 +397,7 @@ class BaseXCom(Base, LoggingMixin):
execution_date: pendulum.DateTime,
dag_id: str,
task_id: str,
- session: Optional[Session] = None,
+ session: Session = NEW_SESSION,
) -> None:
""":sphinx-autoapi-skip:"""
@@ -409,7 +409,7 @@ class BaseXCom(Base, LoggingMixin):
dag_id: Optional[str] = None,
task_id: Optional[str] = None,
run_id: Optional[str] = None,
- session: Session = None,
+ session: Session = NEW_SESSION,
) -> None:
""":sphinx-autoapi-skip:"""
# Given the historic order of this function (execution_date was first argument) to add a new optional
diff --git a/airflow/settings.py b/airflow/settings.py
index f9b97a2..139d6a4 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -22,7 +22,7 @@ import logging
import os
import sys
import warnings
-from typing import Optional
+from typing import TYPE_CHECKING, Callable, List, Optional
import pendulum
import sqlalchemy
@@ -37,6 +37,9 @@ from airflow.executors import executor_constants
from airflow.logging_config import configure_logging
from airflow.utils.orm_event_handlers import setup_event_handlers
+if TYPE_CHECKING:
+ from airflow.www.utils import UIAlert
+
log = logging.getLogger(__name__)
@@ -77,7 +80,7 @@ DONOT_MODIFY_HANDLERS: Optional[bool] = None
DAGS_FOLDER: str = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))
engine: Optional[Engine] = None
-Session: Optional[SASession] = None
+Session: Callable[..., SASession]
# The JSON library to use for DAG Serialization and De-Serialization
json = json
@@ -563,8 +566,7 @@ MASK_SECRETS_IN_LOGS = False
# UIAlert('Visit <a href="http://airflow.apache.org">airflow.apache.org</a>', html=True),
# ]
#
-# DASHBOARD_UIALERTS: List["UIAlert"]
-DASHBOARD_UIALERTS = []
+DASHBOARD_UIALERTS: List["UIAlert"] = []
# Prefix used to identify tables holding data moved during migration.
AIRFLOW_MOVED_TABLE_PREFIX = "_airflow_moved"
diff --git a/airflow/utils/session.py b/airflow/utils/session.py
index 9636fc4..f0c3168 100644
--- a/airflow/utils/session.py
+++ b/airflow/utils/session.py
@@ -18,7 +18,7 @@
import contextlib
from functools import wraps
from inspect import signature
-from typing import Callable, Iterator, TypeVar
+from typing import Callable, Iterator, TypeVar, cast
from airflow import settings
@@ -26,7 +26,7 @@ from airflow import settings
@contextlib.contextmanager
def create_session() -> Iterator[settings.SASession]:
"""Contextmanager that will create and teardown a session."""
- session: settings.SASession = settings.Session()
+ session = settings.Session()
try:
yield session
session.commit()
@@ -105,3 +105,10 @@ def create_global_lock(session=None, pg_lock_id=1, lock_name='init', mysql_lock_
if dialect.name == 'mssql':
# TODO: make locking work for MSSQL
pass
+
+
+# A fake session to use in functions decorated by provide_session. This allows
+# the 'session' argument to be of type Session instead of Optional[Session],
+# making it easier to type hint the function body without dealing with the None
+# case that can never happen at runtime.
+NEW_SESSION: settings.SASession = cast(settings.SASession, None)